Source code for jaclearn.rl.envs.gym

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : gym.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/17/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import threading
import numpy as np
import collections

import jacinle.io as jacio

from ..env import SimpleRLEnvBase, ProxyRLEnvBase
from ..space import DiscreteActionSpace, ContinuousActionSpace

try:
    import gym
    import gym.wrappers
except ImportError:
    gym = None

_ENV_LOCK = threading.Lock()


[docs] def get_env_lock(): return _ENV_LOCK
__all__ = ['GymRLEnv', 'GymAtariRLEnv', 'GymPreventStuckProxy']
[docs] class GymRLEnv(SimpleRLEnvBase):
[docs] def __init__(self, name, dump_dir=None, force_dump=False, state_mode='DEFAULT'): super().__init__() with get_env_lock(): self._gym = self._make_env(name) if dump_dir: jacio.mkdir(dump_dir) self._gym = gym.wrappers.Monitor(self._gym, dump_dir, force=force_dump) assert state_mode in ('DEFAULT', 'RENDER', 'BOTH') self._state_mode = state_mode
def _make_env(self, name): return gym.make(name) @property def gym(self): return self._gym
[docs] def render(self, mode='human', close=False): return self._gym.render(mode=mode, close=close)
def _set_current_state(self, o): if self._state_mode == 'DEFAULT': pass else: rendered = self.render('rgb_array') if self._state_mode == 'RENDER': o = rendered else: o = (o, rendered) super()._set_current_state(o) def _get_action_space(self): spc = self._gym.action_space if isinstance(spc, gym.spaces.discrete.Discrete): try: action_meanings = self._gym.unwrapped.get_action_meanings() except AttributeError: if 'Atari' in self._gym.unwrapped.__class__.__name__: from gym.envs.atari.atari_env import ACTION_MEANING action_meanings = [ACTION_MEANING[i] for i in range(spc.n)] else: action_meanings = ['unknown{}'.format(i) for i in range(spc.n)] return DiscreteActionSpace(spc.n, action_meanings=action_meanings) elif isinstance(spc, gym.spaces.box.Box): return ContinuousActionSpace(spc.low, spc.high, spc.shape) else: raise ValueError('Unknown gym space spec: {}.'.format(spc)) def _action(self, action): # hack for continuous control if type(action) in (tuple, list): action = np.array(action) o, r, is_over, info = self._gym.step(action) self._set_current_state(o) return r, is_over def _restart(self): o = self._gym.reset() self._set_current_state(o) def _finish(self): self._gym.close()
[docs] class GymAtariRLEnv(GymRLEnv):
[docs] def __init__(self, name, *args, live_lost_as_eoe=True, **kwargs): super().__init__(name, *args, **kwargs) self._live_lost_as_eoe = live_lost_as_eoe
def _action(self, action): old_lives = self._gym.unwrapped.ale.lives() r, is_over = super()._action(action) new_lives = self._gym.unwrapped.ale.lives() if self._live_lost_as_eoe and old_lives > new_lives: is_over = True return r, is_over
[docs] class GymPreventStuckProxy(ProxyRLEnvBase):
[docs] def __init__(self, other, max_repeat, action): super().__init__(other) self._action_list = collections.deque(maxlen=max_repeat) self._insert_action = action
def _action(self, action): self._action_list.append(action) if self._action_list.count(self._action_list[0]) == self._action_list.maxlen: action = self._insert_action r, is_over = self.proxy.action(action) if is_over: self._action_list.clear() return r, is_over def _restart(self, *args, **kwargs): super()._restart(*args, **kwargs) self._action_list.clear()