Source code for jaclearn.rl.env

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : env.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/17/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections

from jacinle.utils.cache import cached_property

__all__ = ['RLEnvBase', 'SimpleRLEnvBase', 'ProxyRLEnvBase']


class RLEnvBase(object):
    def __init__(self):
        self._stats = collections.defaultdict(list)

    @property
    def stats(self):
        return self._stats

    def append_stat(self, name, value):
        self._stats[name].append(value)
        return self

    def clear_stats(self):
        self._stats = collections.defaultdict(list)
        return self

    @cached_property
    def action_space(self):
        return self._get_action_space()

    @property
    def current_state(self):
        return self._get_current_state()

    def action(self, action):
        return self._action(action)

    def restart(self, *args, **kwargs):
        return self._restart(*args, **kwargs)

    def finish(self, *args, **kwargs):
        return self._finish(*args, **kwargs)

    def play_one_episode(self, func, ret_states=False, ret_actions=False,
                         restart_kwargs=None, finish_kwargs=None, max_steps=10000):
        # Drive one episode: restart, repeatedly query `func` for an action,
        # and call `finish` once the environment reports the episode is over.
        states = []
        actions = []
        self.restart(**(restart_kwargs or {}))
        for step in range(max_steps):
            state = self.current_state
            action = func(state)
            r, is_over = self.action(action)
            if ret_actions:
                actions.append(action)
            if ret_states:
                states.append(state)
            if is_over:
                self.finish(**(finish_kwargs or {}))
                break
        if ret_states:
            # Also record the final state, so `states` holds one more entry
            # than `actions`.
            states.append(self.current_state)

        returns = []
        if ret_states:
            returns.append(states)
        if ret_actions:
            returns.append(actions)
        return returns[0] if len(returns) == 1 else tuple(returns)

    # Hooks to be implemented or overridden by subclasses.
    def _get_action_space(self):
        return None

    def _get_current_state(self):
        return None

    def _action(self, action):
        raise NotImplementedError()

    def _restart(self, *args, **kwargs):
        raise NotImplementedError()

    def _finish(self, *args, **kwargs):
        pass

    @property
    def unwrapped(self):
        return self

    def evaluate_one_episode(self, func):
        self.play_one_episode(func)
        score = self.stats['score'][-1]
        self.clear_stats()
        return score
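
# Usage sketch for `play_one_episode` (`env` and `policy` are hypothetical
# placeholders, not defined in this module). The return value depends on the
# `ret_*` flags; when both are set, `states` holds one more entry than
# `actions` because the final state is also recorded:
#   states, actions = env.play_one_episode(policy, ret_states=True, ret_actions=True)
#   states = env.play_one_episode(policy, ret_states=True)  # a single list
#   env.play_one_episode(policy)                            # returns ()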
class SimpleRLEnvBase(RLEnvBase):
    _current_state = None

    def __init__(self):
        super().__init__()
        self._reward_history = []

    def _get_current_state(self):
        return self._current_state

    def _set_current_state(self, state):
        self._current_state = state

    def action(self, action):
        r, is_over = self._action(action)
        self._reward_history.append(r)
        return r, is_over

    def restart(self, *args, **kwargs):
        rc = self._restart(*args, **kwargs)
        self._reward_history = []
        return rc

    def finish(self, *args, **kwargs):
        rc = self._finish(*args, **kwargs)
        self.append_stat('score', sum(self._reward_history))
        self.append_stat('length', len(self._reward_history))
        return rc
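
# Example (a minimal sketch, not part of the original module): a toy
# environment built on SimpleRLEnvBase. `_ToyCountdownEnv` is a hypothetical
# name used only to illustrate the hook contract: `_restart` resets the
# episode and publishes a state via `_set_current_state`, and `_action`
# returns a `(reward, is_over)` pair.
class _ToyCountdownEnv(SimpleRLEnvBase):
    def _restart(self):
        self._nr_steps = 0
        self._set_current_state(0)

    def _action(self, action):
        # Reward 1 when the action matches the current state; the episode
        # ends after three steps.
        reward = float(action == self._get_current_state())
        self._nr_steps += 1
        self._set_current_state(self._nr_steps)
        return reward, self._nr_steps >= 3

# Usage sketch: SimpleRLEnvBase.finish records the 'score' and 'length'
# stats, so evaluate_one_episode can read the episode score back:
#   env = _ToyCountdownEnv()
#   score = env.evaluate_one_episode(lambda state: state)  # -> 3.0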
class ProxyRLEnvBase(RLEnvBase):
    def __init__(self, other):
        super().__init__()
        self.__proxy = other

    @property
    def proxy(self):
        return self.__proxy

    @property
    def stats(self):
        return self.__proxy.stats

    def append_stat(self, name, value):
        self.__proxy.append_stat(name, value)
        return self

    def clear_stats(self):
        self.__proxy.clear_stats()
        return self

    def _get_action_space(self):
        return self.__proxy.action_space

    def _get_current_state(self):
        return self.__proxy.current_state

    def _action(self, action):
        return self.__proxy.action(action)

    def _restart(self, *args, **kwargs):
        return self.__proxy.restart(*args, **kwargs)

    def _finish(self, *args, **kwargs):
        return self.__proxy.finish(*args, **kwargs)

    @property
    def unwrapped(self):
        return self.proxy.unwrapped
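
# Example (a sketch, not part of the original module): a hypothetical
# reward-scaling wrapper built on ProxyRLEnvBase. Only `_action` is
# overridden; all other calls are forwarded to the wrapped environment.
# Note the use of the public `proxy` property, since the name-mangled
# `__proxy` attribute is private to ProxyRLEnvBase.
class _RewardScaleWrapper(ProxyRLEnvBase):
    def __init__(self, other, scale=0.1):
        super().__init__(other)
        self._scale = scale

    def _action(self, action):
        r, is_over = self.proxy.action(action)
        return r * self._scale, is_over

# Usage sketch: wrappers compose, and `unwrapped` recovers the innermost
# environment:
#   env = _RewardScaleWrapper(_ToyCountdownEnv(), scale=0.5)
#   assert isinstance(env.unwrapped, _ToyCountdownEnv)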