Source code for jaclearn.rl.engines.mujoco.gym_recorder
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : gym_recorder.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 04/16/2020
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os.path as osp
import mujoco_py
from copy import deepcopy
import jacinle.io as io
import jaclearn.math.rotation as R
from jacinle.utils.enum import JacEnum
[docs]
class MujocoObjectType(JacEnum):
    """Categories of MuJoCo objects, identified by their lowercase string value."""

    BODY = 'body'
    GEOM = 'geom'
    SITE = 'site'
[docs]
class MujocoGymRecorder(object):
    """Record simulator states and per-object poses of a MuJoCo gym environment.

    A snapshot is taken every time :meth:`step` is called. Each snapshot
    consists of the value returned by ``sim.get_state()`` plus, for every
    named body / geom / site, its ``xpos``, ``xquat``, ``xvelp`` and ``xvelr``
    converted to plain Python lists. Free-form messages can be attached to the
    most recent snapshot with :meth:`log`, and everything can be written to a
    directory with :meth:`dump`.
    """

    def __init__(self, gym_environ):
        """Bind the recorder to ``gym_environ`` and take the first snapshot.

        Args:
            gym_environ: an environment object exposing a ``sim`` attribute.
                :meth:`hook` additionally requires ``reset`` and ``step``.
        """
        self.gym_environ = gym_environ
        self.sim = gym_environ.sim

        self.gym_states = list()
        self.object_names = dict()
        self.object_poses = list()
        self.logs = dict()

        self.reset()

    def _init_object_names(self):
        """Cache a copy of the model's name list for every object type."""
        for obj_type in MujocoObjectType.choice_objs():
            all_names = getattr(self.sim.model, obj_type.value + '_names')
            self.object_names[obj_type.value] = deepcopy(all_names)

    def reset(self):
        """Discard everything recorded so far and take a fresh first snapshot."""
        self._init_object_names()

        self.gym_states = list()
        self.object_poses = list()
        self.logs = dict()

        self.step()

    def log(self, key, message):
        """Attach ``message`` under ``key`` to the most recent snapshot.

        Messages are grouped by the index of the latest recorded state; logging
        the same key twice for the same snapshot overwrites the earlier value.
        """
        log_index = len(self.gym_states) - 1
        self.logs.setdefault(log_index, dict())[key] = message

    def step(self):
        """Append a snapshot of the current simulator state and object poses."""
        self.gym_states.append(self.sim.get_state())

        data = self.sim.data
        frame = dict()
        for obj_type, names in self.object_names.items():
            frame[obj_type] = per_object = {name: dict() for name in names}
            # Geoms and sites are read from the `*_xmat` array rather than a
            # `*_xquat` one; their quaternion is derived from the 3x3 matrix.
            orientation_from_matrix = obj_type in ('geom', 'site')
            for value_name in ('xpos', 'xquat', 'xvelp', 'xvelr'):
                use_matrix = orientation_from_matrix and value_name == 'xquat'
                source_name = 'xmat' if use_matrix else value_name
                values = getattr(data, obj_type + '_' + source_name)
                for name, value in zip(names, values):
                    if use_matrix:
                        value = R.mat2quat(value.reshape(3, 3))
                    per_object[name][value_name] = value.tolist()

        self.object_poses.append(frame)

    def dump(self, save_dir):
        """Write the recording into ``save_dir``, creating it via ``io.mkdir``.

        Three files are produced: ``mj_states.pkl`` (the simulator states),
        ``objects.json`` (object names and per-snapshot poses) and
        ``logs.json`` (messages stored with :meth:`log`).
        """
        io.mkdir(save_dir)
        io.dump(osp.join(save_dir, 'mj_states.pkl'), self.gym_states)
        io.dump(osp.join(save_dir, 'objects.json'), dict(
            names=self.object_names,
            poses=self.object_poses
        ))
        io.dump(osp.join(save_dir, 'logs.json'), self.logs)

    def hook(self):
        """Wrap the environment's ``reset`` and ``step`` so they record automatically.

        After hooking, ``gym_environ.reset(...)`` also calls :meth:`reset` and
        ``gym_environ.step(...)`` also calls :meth:`step`; the environment's
        own return values are passed through unchanged.

        Calling ``hook`` again while this recorder's wrappers are still
        installed is a no-op, so a repeated call does not make every
        environment step get recorded twice.

        Returns:
            The recorder itself, to allow ``recorder = Recorder(env).hook()``.
        """
        import functools

        env = self.gym_environ
        old_reset = env.reset
        old_step = env.step

        # Each wrapper is tagged with the recorder that installed it; a callable
        # already tagged with `self` must not be wrapped a second time.
        if getattr(old_reset, '_mujoco_gym_recorder', None) is not self:
            @functools.wraps(old_reset)
            def reset(*args, **kwargs):
                retval = old_reset(*args, **kwargs)
                self.reset()
                return retval

            reset._mujoco_gym_recorder = self
            env.reset = reset

        if getattr(old_step, '_mujoco_gym_recorder', None) is not self:
            @functools.wraps(old_step)
            def step(*args, **kwargs):
                retval = old_step(*args, **kwargs)
                self.step()
                return retval

            step._mujoco_gym_recorder = self
            env.step = step

        return self