Source code for jaclearn.rl.envs.maze.taxi
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : taxi.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/17/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import numpy as np
from .maze import MazeEnv, CustomLavaWorldEnv
from ...env import SimpleRLEnvBase
__all__ = ['CustomTaxiEnv', 'CustomLavaWorldTaxiEnv']
[docs]
class CustomTaxiEnv(SimpleRLEnvBase):
_start_point = None
_final_point1 = None
_final_point2 = None
[docs]
def __init__(self, maze_env=None, use_coord=False, *args, **kwargs):
super().__init__()
if maze_env is None:
maze_env = MazeEnv(*args, **kwargs)
assert isinstance(maze_env, MazeEnv)
self._maze_env = maze_env
self._use_coord = use_coord
self._phase = 0
@property
def maze_env(self):
return self._maze_env
@property
def phase(self):
return self._phase
def _get_action_space(self):
return self._maze_env.action_space
def _refresh_current_state(self):
state = self._maze_env.current_state
if not self._use_coord:
self._set_current_state(state)
else:
x = np.zeros(shape=(self._maze_env.map_size[1], ), dtype='uint8')
y = np.zeros(shape=(self._maze_env.map_size[0], ), dtype='uint8')
if self._phase == 2:
state = state.copy()
fp = self._maze_env.final_point
x[fp[1]], y[fp[0]] = 1, 1
if (state[fp[0]+1, fp[1]+1] == self._maze_env._colors[3]).all():
state[fp[0]+1, fp[1]+1] = self._maze_env._colors[0]
self._set_current_state((state, x, y))
[docs]
def restart(self, start_point=None, final_point1=None, final_point2=None):
self._start_point = start_point
self._final_point1 = final_point1
self._final_point2 = final_point2
self._maze_env.restart(start_point=start_point, final_point=final_point1)
self._phase = 1
self._refresh_current_state()
super().restart()
def _restart(self):
pass
def _action(self, action):
# Sanity phase check
assert self._phase in (1, 2), 'Invalid phase'
reward, is_over = self._maze_env.action(action)
if is_over:
if self._phase == 1:
self._phase = 2
self._enter_phase2()
is_over = False
self._refresh_current_state()
return reward, is_over
def _enter_phase2(self):
self._maze_env.restart(
obstacles=self._maze_env.obstacles,
start_point=self._maze_env.current_point,
final_point=self._final_point2
)
def _finish(self):
self._maze_env.finish()
[docs]
class CustomLavaWorldTaxiEnv(CustomTaxiEnv):
[docs]
def __init__(self, maze_env=None, use_coord=False, *args, **kwargs):
if maze_env is None:
maze_env = CustomLavaWorldEnv(*args, **kwargs)
assert isinstance(maze_env, CustomLavaWorldEnv)
super().__init__(maze_env, use_coord=use_coord)
def _enter_phase2(self):
self._maze_env.restart(start_point=self._maze_env.current_point, final_point=self._final_point2)
def _finish(self):
super()._finish()
if self._phase == 2 and self._maze_env._current_point == self._maze_env._final_point:
self.append_stat('success', 1)
else:
self.append_stat('success', 0)