diff --git a/assets/mujoco_envs/ant.xml b/assets/mujoco_envs/ant.xml
index 330d532..668cb8c 100644
--- a/assets/mujoco_envs/ant.xml
+++ b/assets/mujoco_envs/ant.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/assets/mujoco_envs/gap.xml b/assets/mujoco_envs/gap.xml
index fe21e11..5059ffe 100644
--- a/assets/mujoco_envs/gap.xml
+++ b/assets/mujoco_envs/gap.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/assets/mujoco_envs/hopper.xml b/assets/mujoco_envs/hopper.xml
index 6b127d3..9fbe3e7 100644
--- a/assets/mujoco_envs/hopper.xml
+++ b/assets/mujoco_envs/hopper.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/assets/mujoco_envs/swimmer.xml b/assets/mujoco_envs/swimmer.xml
index a46cef1..ae772f6 100644
--- a/assets/mujoco_envs/swimmer.xml
+++ b/assets/mujoco_envs/swimmer.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/assets/mujoco_envs/walker.xml b/assets/mujoco_envs/walker.xml
index c905262..0cacb43 100644
--- a/assets/mujoco_envs/walker.xml
+++ b/assets/mujoco_envs/walker.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/design_opt/agents/genesis_agent.py b/design_opt/agents/genesis_agent.py
index 442c72c..f4f33ad 100644
--- a/design_opt/agents/genesis_agent.py
+++ b/design_opt/agents/genesis_agent.py
@@ -1,3 +1,14 @@
+"""
+BodyGenAgent: Training agent for robot body generation using PPO.
+
+This agent implements the proximal policy optimization (PPO) algorithm
+for co-designing robot morphology and controller. It supports:
+- Multi-threaded trajectory sampling
+- Separated value estimation for design and control stages
+- Checkpoint loading/saving for training resumption
+- Visualization of trained policies
+"""
+
import math
import pickle
import time
@@ -16,7 +27,17 @@
import wandb
+
def tensorfy(np_list, device=torch.device('cpu')):
+ """Convert numpy arrays to tensors, handling nested structures.
+
+ Args:
+ np_list: List of numpy arrays or nested lists
+ device: Target torch device
+
+ Returns:
+ Nested list of torch tensors on specified device
+ """
if isinstance(np_list[0], list):
return [[torch.tensor(x).to(device) if i <= 1 or i == 4 or i >= 7 else x for i, x in enumerate(y)] for y in np_list]
else:
@@ -24,6 +45,7 @@ def tensorfy(np_list, device=torch.device('cpu')):
class BodyGenAgent(AgentPPO):
+ """PPO agent for robot body generation training."""
def __init__(self, cfg, dtype, device, seed, num_threads, training=True, checkpoint=0):
self.cfg = cfg
@@ -146,13 +168,15 @@ def sample_worker(self, pid, queue, min_batch_size, mean_action, render):
if pid > 0:
torch.manual_seed(torch.randint(0, 5000, (1,)) * pid)
if hasattr(self.env, 'np_random'):
- self.env.np_random.seed(self.env.np_random.randint(5000) * pid)
+ # NumPy 1.17+ uses Generator.choose() instead of seed()
+ seed = int(self.env.np_random.integers(0, 5000) * pid)
+ self.env.np_random = np.random.default_rng(seed)
memory = Memory()
logger = self.logger_cls(**self.logger_kwargs)
while logger.num_steps < min_batch_size:
- state = self.env.reset()
+ state, _ = self.env.reset()
logger.start_episode(self.env)
while True:
@@ -527,7 +551,7 @@ def visualize_agent(self, num_episode=1, mean_action=True, save_video=False, pau
self.obs_norm.to('cpu')
for _ in range(num_episode):
- state = env.reset()
+ state, _ = env.reset()
env._get_viewer('human')._paused = paused
env.render()
diff --git a/design_opt/envs/ant.py b/design_opt/envs/ant.py
index f739ffb..28f8bff 100644
--- a/design_opt/envs/ant.py
+++ b/design_opt/envs/ant.py
@@ -1,11 +1,11 @@
import numpy as np
-from gym import utils
+from gymnasium import utils
from khrylib.rl.envs.common.mujoco_env_gym import MujocoEnv
from khrylib.robot.xml_robot import Robot
from khrylib.utils import get_single_body_qposaddr, get_graph_fc_edges
from khrylib.utils.transformation import quaternion_matrix
from copy import deepcopy
-import mujoco_py
+import mujoco
import time
import os
@@ -104,9 +104,11 @@ def action_to_control(self, a):
assert a.shape[0] == len(self.robot.bodies)
for body, body_a in zip(self.robot.bodies[1:], a[1:]):
aname = body.get_actuator_name()
- if aname in self.model.actuator_names:
- aind = self.model.actuator_names.index(aname)
+ try:
+ aind = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, aname)
ctrl[aind] = body_a
+ except ValueError:
+ pass # actuator not found
return ctrl
def step(self, a):
@@ -205,8 +207,9 @@ def if_use_transform_action(self):
def get_sim_obs(self):
obs = []
if 'root_offset' in self.sim_specs:
- root_pos = self.data.body_xpos[self.model._body_name2id[self.robot.bodies[0].name]]
-
+ root_body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, self.robot.bodies[0].name)
+ root_pos = self.data.body(root_body_id).xpos
+
for i, body in enumerate(self.robot.bodies):
qvel = self.data.qvel.copy()
if self.clip_qvel:
@@ -218,11 +221,11 @@ def get_sim_obs(self):
if qe - qs >= 1:
assert qe - qs == 1
obs_i = [np.zeros(11), self.data.qpos[qs:qe], qvel[qs-1:qe-1]]
- # print(qs)
else:
obs_i = [np.zeros(13)]
if 'root_offset' in self.sim_specs:
- offset = self.data.body_xpos[self.model._body_name2id[body.name]][[0, 2]] - root_pos[[0, 2]]
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body.name)
+ offset = self.data.body(body_id).xpos[[0, 2]] - root_pos[[0, 2]]
obs_i.append(offset)
obs_i = np.concatenate(obs_i)
obs.append(obs_i)
@@ -342,9 +345,7 @@ def reset_model(self):
return self._get_obs()
def viewer_setup(self):
- # self.viewer.cam.trackbodyid = 2
self.viewer.cam.distance = 10
- # self.viewer.cam.lookat[2] = 1.15
- self.viewer.cam.lookat[:2] = self.data.qpos[:2]
+ self.viewer.cam.lookat[:2] = self.data.qpos[:2]
self.viewer.cam.elevation = -10
self.viewer.cam.azimuth = 110
\ No newline at end of file
diff --git a/design_opt/envs/gap.py b/design_opt/envs/gap.py
index dc7cfc6..bf01293 100644
--- a/design_opt/envs/gap.py
+++ b/design_opt/envs/gap.py
@@ -1,10 +1,11 @@
import numpy as np
import os
import os.path as osp
-from gym import utils
+from gymnasium import utils
from khrylib.rl.envs.common.mujoco_env_gym import MujocoEnv
from khrylib.robot.xml_robot import Robot
from khrylib.utils import get_single_body_qposaddr, get_graph_fc_edges
+import mujoco
import shutil
@@ -95,7 +96,6 @@ def set_design_params(self, in_design_params):
design_params = in_design_params
for params, body in zip(design_params, self.robot.bodies):
body.set_params(params, pad_zeros=True, map_params=True)
- # new_params = body.get_params([], pad_zeros=True, demap_params=True)
body.sync_node()
xml_str = self.robot.export_xml_string()
@@ -116,8 +116,11 @@ def action_to_control(self, a):
assert a.shape[0] == len(self.robot.bodies)
for body, body_a in zip(self.robot.bodies[1:], a[1:]):
aname = body.get_actuator_name()
- aind = self.model.actuator_names.index(aname)
- ctrl[aind] = body_a
+ try:
+ aind = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, aname)
+ ctrl[aind] = body_a
+ except ValueError:
+ pass
return ctrl
def step(self, a):
@@ -165,7 +168,7 @@ def step(self, a):
assert np.all(a[:, self.control_action_dim:] == 0)
control_a = a[:, :self.control_action_dim]
ctrl = self.action_to_control(control_a)
- posbefore = self.sim.data.qpos[0]
+ posbefore = self.data.qpos[0]
try:
self.do_simulation(ctrl, self.frame_skip)
@@ -173,11 +176,10 @@ def step(self, a):
print(self.cur_xml_str)
return self._get_obs(), 0, True, False, {'use_transform_action': False, 'stage': 'execution'}
- posafter, height, ang = self.sim.data.qpos[0:3]
+ posafter, height, ang = self.data.qpos[0:3]
alive_bonus = self.cfg.reward_specs.get('alive_bonus', 0.0)
reward = (posafter - posbefore) / self.dt
reward += alive_bonus
- # reward -= 1e-3 * np.square(a).sum()
scale = self.cfg.reward_specs.get('exec_reward_scale', 1.0)
reward *= scale
@@ -213,9 +215,10 @@ def if_use_transform_action(self):
def get_sim_obs(self):
obs = []
if 'root_offset' in self.sim_specs:
- root_pos = self.data.body_xpos[self.model._body_name2id[self.robot.bodies[0].name]]
-
- phase_ob = [(self.sim.data.qpos[0]) / 3.2 - int((self.sim.data.qpos[0]) / 3.2)]
+ root_body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, self.robot.bodies[0].name)
+ root_pos = self.data.body(root_body_id).xpos
+
+ phase_ob = [(self.data.qpos[0]) / 3.2 - int((self.data.qpos[0]) / 3.2)]
for i, body in enumerate(self.robot.bodies):
qvel = self.data.qvel.copy()
if self.clip_qvel:
@@ -226,9 +229,9 @@ def get_sim_obs(self):
qs, qe = get_single_body_qposaddr(self.model, body.name)
assert qe - qs == 1
obs_i = [phase_ob, self.data.qpos[qs:qe], np.zeros(1), qvel[qs:qe], np.zeros(2)]
- # print(qs)
if 'root_offset' in self.sim_specs:
- offset = self.data.body_xpos[self.model._body_name2id[body.name]][[0, 2]] - root_pos[[0, 2]]
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body.name)
+ offset = self.data.body(body_id).xpos[[0, 2]] - root_pos[[0, 2]]
obs_i.append(offset)
obs_i = np.concatenate(obs_i)
obs.append(obs_i)
@@ -334,7 +337,6 @@ def reset_state(self, add_noise):
has_contact = False
for contact in self.data.contact[:self.data.ncon]:
g1, g2 = contact.geom1, contact.geom2
- # print(f'g1: {g1} g2: {g2}')
if g1 in self.ground_geoms or g2 in self.ground_geoms:
has_contact = True
break
diff --git a/design_opt/envs/hopper.py b/design_opt/envs/hopper.py
index 8dc8050..7106758 100644
--- a/design_opt/envs/hopper.py
+++ b/design_opt/envs/hopper.py
@@ -1,10 +1,10 @@
import numpy as np
-from gym import utils
+from gymnasium import utils
from khrylib.rl.envs.common.mujoco_env_gym import MujocoEnv
from khrylib.robot.xml_robot import Robot
from khrylib.utils import get_single_body_qposaddr, get_graph_fc_edges
from copy import deepcopy
-import mujoco_py
+import mujoco
import time
import os
@@ -103,8 +103,11 @@ def action_to_control(self, a):
assert a.shape[0] == len(self.robot.bodies)
for body, body_a in zip(self.robot.bodies[1:], a[1:]):
aname = body.get_actuator_name()
- aind = self.model.actuator_names.index(aname)
- ctrl[aind] = body_a
+ try:
+ aind = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, aname)
+ ctrl[aind] = body_a
+ except ValueError:
+ pass
return ctrl
def step(self, a):
@@ -152,7 +155,7 @@ def step(self, a):
assert np.all(a[:, self.control_action_dim:] == 0)
control_a = a[:, :self.control_action_dim]
ctrl = self.action_to_control(control_a)
- posbefore = self.sim.data.qpos[0]
+ posbefore = self.data.qpos[0]
try:
self.do_simulation(ctrl, self.frame_skip)
@@ -160,7 +163,7 @@ def step(self, a):
print(self.cur_xml_str)
return self._get_obs(), 0, True, False, {'use_transform_action': False, 'stage': 'execution'}
- posafter, height, ang = self.sim.data.qpos[0:3]
+ posafter, height, ang = self.data.qpos[0:3]
alive_bonus = self.cfg.reward_specs.get('alive_bonus', 0.0)
reward = (posafter - posbefore) / self.dt
reward += alive_bonus
@@ -199,8 +202,9 @@ def if_use_transform_action(self):
def get_sim_obs(self):
obs = []
if 'root_offset' in self.sim_specs:
- root_pos = self.data.body_xpos[self.model._body_name2id[self.robot.bodies[0].name]]
-
+ root_body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, self.robot.bodies[0].name)
+ root_pos = self.data.body(root_body_id).xpos
+
for i, body in enumerate(self.robot.bodies):
qvel = self.data.qvel.copy()
if self.clip_qvel:
@@ -211,9 +215,9 @@ def get_sim_obs(self):
qs, qe = get_single_body_qposaddr(self.model, body.name)
assert qe - qs == 1
obs_i = [self.data.qpos[qs:qe], np.zeros(1), qvel[qs:qe], np.zeros(2)]
- # print(qs)
if 'root_offset' in self.sim_specs:
- offset = self.data.body_xpos[self.model._body_name2id[body.name]][[0, 2]] - root_pos[[0, 2]]
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body.name)
+ offset = self.data.body(body_id).xpos[[0, 2]] - root_pos[[0, 2]]
obs_i.append(offset)
obs_i = np.concatenate(obs_i)
obs.append(obs_i)
@@ -322,7 +326,6 @@ def reset_state(self, add_noise):
has_contact = False
for contact in self.data.contact[:self.data.ncon]:
g1, g2 = contact.geom1, contact.geom2
- # print(f'g1: {g1} g2: {g2}')
if g1 in self.ground_geoms or g2 in self.ground_geoms:
has_contact = True
break
diff --git a/design_opt/envs/swimmer.py b/design_opt/envs/swimmer.py
index 419c8db..348ffca 100644
--- a/design_opt/envs/swimmer.py
+++ b/design_opt/envs/swimmer.py
@@ -1,10 +1,10 @@
import numpy as np
-from gym import utils
+from gymnasium import utils
from khrylib.rl.envs.common.mujoco_env_gym import MujocoEnv
from khrylib.robot.xml_robot import Robot
from khrylib.utils import get_single_body_qposaddr, get_graph_fc_edges
from copy import deepcopy
-import mujoco_py
+import mujoco
import time
import os
@@ -82,7 +82,6 @@ def set_design_params(self, in_design_params):
design_params = in_design_params
for params, body in zip(design_params, self.robot.bodies):
body.set_params(params, pad_zeros=True, map_params=True)
- # new_params = body.get_params([], pad_zeros=True, demap_params=True)
body.sync_node()
xml_str = self.robot.export_xml_string()
@@ -103,8 +102,11 @@ def action_to_control(self, a):
assert a.shape[0] == len(self.robot.bodies)
for body, body_a in zip(self.robot.bodies[1:], a[1:]):
aname = body.get_actuator_name()
- aind = self.model.actuator_names.index(aname)
- ctrl[aind] = body_a
+ try:
+ aind = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, aname)
+ ctrl[aind] = body_a
+ except ValueError:
+ pass
return ctrl
def step(self, a):
@@ -153,7 +155,7 @@ def step(self, a):
control_a = a[:, :self.control_action_dim]
ctrl = self.action_to_control(control_a)
ctrl_cost_coeff = self.cfg.reward_specs.get('ctrl_cost_coeff', 1e-4)
- xposbefore = self.sim.data.qpos[0]
+ xposbefore = self.data.qpos[0]
try:
self.do_simulation(ctrl, self.frame_skip)
@@ -161,7 +163,7 @@ def step(self, a):
print(self.cur_xml_str)
return self._get_obs(), 0, True, False, {'use_transform_action': False, 'stage': 'execution'}
- xposafter = self.sim.data.qpos[0]
+ xposafter = self.data.qpos[0]
reward_fwd = (xposafter - xposbefore) / self.dt
reward_ctrl = - ctrl_cost_coeff * np.square(ctrl).mean()
reward = reward_fwd + reward_ctrl
@@ -195,8 +197,9 @@ def if_use_transform_action(self):
def get_sim_obs(self):
obs = []
if 'root_offset' in self.sim_specs:
- root_pos = self.data.body_xpos[self.model._body_name2id[self.robot.bodies[0].name]]
-
+ root_body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, self.robot.bodies[0].name)
+ root_pos = self.data.body(root_body_id).xpos
+
for i, body in enumerate(self.robot.bodies):
qvel = self.data.qvel.copy()
if self.clip_qvel:
@@ -207,9 +210,9 @@ def get_sim_obs(self):
qs, qe = get_single_body_qposaddr(self.model, body.name)
assert qe - qs == 1
obs_i = [self.data.qpos[qs:qe], qvel[qs:qe], np.zeros(2)]
- # print(qs)
if 'root_offset' in self.sim_specs:
- offset = self.data.body_xpos[self.model._body_name2id[body.name]][[0, 2]] - root_pos[[0, 2]]
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body.name)
+ offset = self.data.body(body_id).xpos[[0, 2]] - root_pos[[0, 2]]
obs_i.append(offset)
obs_i = np.concatenate(obs_i)
obs.append(obs_i)
diff --git a/design_opt/envs/walker.py b/design_opt/envs/walker.py
index e0f4e4e..3c82bba 100644
--- a/design_opt/envs/walker.py
+++ b/design_opt/envs/walker.py
@@ -1,10 +1,10 @@
import numpy as np
-from gym import utils
+from gymnasium import utils
from khrylib.rl.envs.common.mujoco_env_gym import MujocoEnv
from khrylib.robot.xml_robot import Robot
from khrylib.utils import get_single_body_qposaddr, get_graph_fc_edges
from copy import deepcopy
-import mujoco_py
+import mujoco
import time
import os
@@ -84,7 +84,6 @@ def set_design_params(self, in_design_params):
design_params = in_design_params
for params, body in zip(design_params, self.robot.bodies):
body.set_params(params, pad_zeros=True, map_params=True)
- # new_params = body.get_params([], pad_zeros=True, demap_params=True)
body.sync_node()
xml_str = self.robot.export_xml_string()
@@ -105,8 +104,11 @@ def action_to_control(self, a):
assert a.shape[0] == len(self.robot.bodies)
for body, body_a in zip(self.robot.bodies[1:], a[1:]):
aname = body.get_actuator_name()
- aind = self.model.actuator_names.index(aname)
- ctrl[aind] = body_a
+ try:
+ aind = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, aname)
+ ctrl[aind] = body_a
+ except ValueError:
+ pass
return ctrl
def step(self, a):
@@ -154,7 +156,7 @@ def step(self, a):
assert np.all(a[:, self.control_action_dim:] == 0)
control_a = a[:, :self.control_action_dim]
ctrl = self.action_to_control(control_a)
- posbefore = self.sim.data.qpos[0]
+ posbefore = self.data.qpos[0]
try:
self.do_simulation(ctrl, self.frame_skip)
@@ -162,7 +164,7 @@ def step(self, a):
print(self.cur_xml_str)
return self._get_obs(), 0, True, False, {'use_transform_action': False, 'stage': 'execution'}
- posafter, height, ang = self.sim.data.qpos[0:3]
+ posafter, height, ang = self.data.qpos[0:3]
alive_bonus = self.cfg.reward_specs.get('alive_bonus', 0.0)
reward = (posafter - posbefore) / self.dt
reward += alive_bonus
@@ -201,8 +203,9 @@ def if_use_transform_action(self):
def get_sim_obs(self):
obs = []
if 'root_offset' in self.sim_specs:
- root_pos = self.data.body_xpos[self.model._body_name2id[self.robot.bodies[0].name]]
-
+ root_body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, self.robot.bodies[0].name)
+ root_pos = self.data.body(root_body_id).xpos
+
for i, body in enumerate(self.robot.bodies):
qvel = self.data.qvel.copy()
if self.clip_qvel:
@@ -213,9 +216,9 @@ def get_sim_obs(self):
qs, qe = get_single_body_qposaddr(self.model, body.name)
assert qe - qs == 1
obs_i = [self.data.qpos[qs:qe], np.zeros(1), qvel[qs:qe], np.zeros(2)]
- # print(qs)
if 'root_offset' in self.sim_specs:
- offset = self.data.body_xpos[self.model._body_name2id[body.name]][[0, 2]] - root_pos[[0, 2]]
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body.name)
+ offset = self.data.body(body_id).xpos[[0, 2]] - root_pos[[0, 2]]
obs_i.append(offset)
obs_i = np.concatenate(obs_i)
obs.append(obs_i)
diff --git a/design_opt/models/bodygen_critic.py b/design_opt/models/bodygen_critic.py
index 0576fb5..c6c1e2e 100644
--- a/design_opt/models/bodygen_critic.py
+++ b/design_opt/models/bodygen_critic.py
@@ -1,3 +1,14 @@
+"""
+BodyGenValue: Value function network for robot body generation.
+
+This module estimates the expected return (value) for each design stage:
+- Skeleton transform stage: Value of morphology structure decisions
+- Attribute transform stage: Value of body parameter adjustments
+- Execution stage: Value of control policies
+
+Uses separate transformer encoders and value heads for each stage.
+"""
+
import torch.nn as nn
import torch
import numpy as np
@@ -12,18 +23,15 @@
class BodyGenValue(nn.Module):
- def __init__(self, cfg, agent):
super().__init__()
self.cfg = cfg
self.agent = agent
self.state_dim = agent.state_dim
if not agent.cfg.uni_obs_norm:
- # self.norm = RunningNorm(self.state_dim)
self.skel_norm = RunningNorm(self.state_dim)
self.attr_norm = RunningNorm(self.state_dim)
self.control_norm = RunningNorm(self.state_dim)
else:
- # self.norm = None
self.skel_norm = self.attr_norm = self.control_norm = None
cur_dim = self.state_dim
@@ -34,20 +42,17 @@ def __init__(self, cfg, agent):
cur_dim = self.control_transformer.out_dim
if 'mlp' in cfg:
- # self.mlp = MLP(cur_dim, cfg['mlp'], cfg['htype'])
self.skel_mlp = MLP(cur_dim, cfg['mlp'], cfg['htype'])
self.attr_mlp = MLP(cur_dim, cfg['mlp'], cfg['htype'])
self.control_mlp = MLP(cur_dim, cfg['mlp'], cfg['htype'])
cur_dim = self.control_mlp.out_dim
else:
self.skel_mlp = self.attr_mlp = self.control_mlp = None
-
- # self.value_head = nn.Linear(cur_dim, 1)
+
self.skel_value_head = nn.Linear(cur_dim, 1)
self.attr_value_head = nn.Linear(cur_dim, 1)
self.control_value_head = nn.Linear(cur_dim, 1)
-
- # init_fc_weights(self.value_head)
+
init_fc_weights(self.skel_value_head)
init_fc_weights(self.attr_value_head)
init_fc_weights(self.control_value_head)
diff --git a/design_opt/models/bodygen_policy.py b/design_opt/models/bodygen_policy.py
index 863e862..4a33c98 100644
--- a/design_opt/models/bodygen_policy.py
+++ b/design_opt/models/bodygen_policy.py
@@ -1,3 +1,16 @@
+"""
+BodyGenPolicy: Policy network for robot body generation.
+
+This module implements a three-stage policy network for co-designing robot
+body morphology and controller:
+
+1. Skeleton Transform Stage: Decides to add/remove body parts (discrete action)
+2. Attribute Transform Stage: Adjusts body parameters (continuous action)
+3. Execution Stage: Generates motor control signals (continuous action)
+
+Each stage uses a Transformer encoder to process graph-structured body observations.
+"""
+
from collections import defaultdict
from khrylib.utils.torch import LongTensor
import torch.nn as nn
diff --git a/design_opt/models/transformer.py b/design_opt/models/transformer.py
index c4f9e56..2c3939b 100644
--- a/design_opt/models/transformer.py
+++ b/design_opt/models/transformer.py
@@ -4,7 +4,15 @@
import math
import logging
+
class MaskedSelfAttention(nn.Module):
+ """Masked self-attention mechanism for processing variable-length node sequences.
+
+ Implements scaled dot-product attention with an attention mask to handle
+ padding in batched sequences. The mask ensures attention is only computed
+ between valid (non-padding) nodes.
+ """
+
def __init__(self, hidden_dim) -> None:
super().__init__()
self.fc_q = nn.Linear(hidden_dim, hidden_dim)
@@ -12,6 +20,18 @@ def __init__(self, hidden_dim) -> None:
self.fc_v = nn.Linear(hidden_dim, hidden_dim)
def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0):
+ """Compute scaled dot-product attention.
+
+ Args:
+ query: Query tensor of shape (B, L, D)
+ key: Key tensor of shape (B, S, D)
+ value: Value tensor of shape (B, S, D)
+ attn_mask: Boolean mask for attention (True = valid, False = padding)
+ dropout_p: Dropout probability
+
+ Returns:
+ Attention output tensor of shape (B, L, D)
+ """
L, S = query.size(-2), key.size(-2)
B = query.size(0)
scale_factor = 1 / math.sqrt(query.size(-1))
@@ -25,14 +45,31 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
-
+
def forward(self, x, attn_mask=None):
+ """Forward pass through masked self-attention.
+
+ Args:
+ x: Input tensor of shape (B, L, D)
+ attn_mask: Attention mask for padding
+
+ Returns:
+ Output tensor of shape (B, L, D)
+ """
q, k, v = self.fc_q(x), self.fc_k(x), self.fc_v(x)
-
out = self.scaled_dot_product_attention(q, k, v, attn_mask)
return out
+
class TransformerBlock(nn.Module):
+ """Single transformer block with masked self-attention and MLP.
+
+ Each block consists of:
+ 1. Masked self-attention layer
+ 2. Feed-forward MLP (4x hidden dimension)
+ 3. Layer normalization (pre-norm or post-norm)
+ """
+
def __init__(self, hidden_dim, norm_type="pre", act_layer=nn.SiLU) -> None:
super().__init__()
HIDDEN_RATIO = 4
@@ -60,38 +97,64 @@ def forward(self, x, attn_mask=None):
return x
class TransformerSimple(nn.Module):
+ """Transformer encoder for processing robot body graph observations.
+
+ This module processes variable-length sequences of body node observations
+ through stacked transformer blocks with position embeddings. It handles
+ padding masks to process batches of sequences with different lengths.
+
+ Args:
+ in_dim: Input feature dimension per node
+ cfg: Configuration dict containing:
+ - hidden_dim: Hidden dimension for transformer
+ - block_depth: Number of transformer blocks
+ - norm_type: "pre" or "post" normalization
+ - pos_emb_type: "index" or "travel" position embedding type
+ node_dim: Optional node type dimension (unused, kept for compatibility)
+ lapPE_k: Laplacian positional encoding dimension (unused, kept for compatibility)
+
+ Input:
+ transformer_obs: Dict containing:
+ - padded_obs: Padded node observations of shape (B, L_max, in_dim)
+ - padding_mask: Boolean mask of shape (B, L_max)
+ - padded_body_ind: Body indices for position embedding of shape (B, L_max)
+
+ Output:
+ Output tensor of shape (B, L_valid, hidden_dim) where L_valid is the
+ number of non-padding nodes in each sequence.
+ """
+
def __init__(self, in_dim, cfg, node_dim=0, lapPE_k=4):
super(TransformerSimple, self).__init__()
self.cfg = cfg
self.node_dim = node_dim
self.hidden_dim = hidden_dim = cfg['hidden_dim']
self.num_layers = num_layers = cfg['block_depth']
-
+
self.norm_type = norm_type = cfg['norm_type']
self.pos_emb_type = pos_emb_type = cfg['pos_emb_type']
self.out_dim = hidden_dim
-
+
self.in_fc = nn.Linear(in_dim, hidden_dim)
-
+
if pos_emb_type == "index":
self.index_embedding = nn.Embedding(256, hidden_dim)
elif pos_emb_type == "travel":
self.travel_embedding = nn.Embedding(256, hidden_dim)
-
+
self.blocks = nn.ModuleList([
TransformerBlock(hidden_dim, norm_type) for _ in range(num_layers)
])
-
- def forward(self, transformer_obs):
+ def forward(self, transformer_obs):
x = transformer_obs["padded_obs"]
padding_mask = transformer_obs["padding_mask"]
padded_body_ind = transformer_obs["padded_body_ind"]
-
- # project to hidden dimension
+
+ # Project input features to hidden dimension
x = self.in_fc(x)
- # position embedding
+ # Add position embedding
if self.pos_emb_type == "index":
pos_emb = self.index_embedding(padded_body_ind)
x = x + pos_emb
@@ -100,13 +163,15 @@ def forward(self, transformer_obs):
position_indices = torch.arange(0, L, dtype=torch.long, device=x.device).unsqueeze(0)
pos_emb = self.travel_embedding(position_indices)
x = x + pos_emb
-
- ## attention mask for padding
+
+ # Create attention mask for padding
attn_mask = padding_mask.unsqueeze(1) if padding_mask is not None else None
-
+
+ # Pass through transformer blocks
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
-
+
+ # Remove padding nodes
x = x[padding_mask]
-
+
return x
\ No newline at end of file
diff --git a/design_opt/train.py b/design_opt/train.py
index 737056c..c4380d1 100644
--- a/design_opt/train.py
+++ b/design_opt/train.py
@@ -29,6 +29,7 @@ def main_loop(FLAGS):
start_epoch = int(FLAGS.epoch) if FLAGS.epoch.isnumeric() else FLAGS.epoch
"""create agent"""
+ #实例化智能体,传入的cfg(配置)决定了训练的是什么构型(walker、hopper等等)以及transformer的层数等等
agent = BodyGenAgent(cfg=cfg, dtype=dtype, device=device, seed=cfg.seed, num_threads=FLAGS.num_threads, training=True, checkpoint=start_epoch)
if FLAGS.render:
diff --git a/design_opt/utils/config.py b/design_opt/utils/config.py
index 93183a8..266b63d 100644
--- a/design_opt/utils/config.py
+++ b/design_opt/utils/config.py
@@ -1,3 +1,10 @@
+"""
+Config: Configuration management for BodyGen training.
+
+Loads configuration from YAML files and CLI flags, sets up output directories,
+and provides access to all training, environment, and model parameters.
+"""
+
import yaml
import os
import glob
@@ -5,6 +12,7 @@
class Config:
+ """Configuration class that loads and manages all training parameters."""
def __init__(self, FLAG, project_path, base_dir=None):
cfg_id = FLAG.cfg
diff --git a/docs/README_CN.md b/docs/README_CN.md
new file mode 100644
index 0000000..a2eb2f7
--- /dev/null
+++ b/docs/README_CN.md
@@ -0,0 +1,400 @@
+# BodyGen 代码说明文档
+
+## 项目概述
+
+BodyGen 是一个基于强化学习的机器人形态与控制器协同设计(Co-Design)系统。该系统能够自动生成具有不同身体结构的机器人,并同时优化其运动控制器。系统采用三阶段设计流程:骨架变换阶段、属性变换阶段和执行阶段,通过 Transformer 架构处理图结构化的机器人身体信息。
+
+## 代码架构
+
+### 目录结构
+
+项目代码主要包含以下几个核心模块:
+
+```
+BodyGen-main/
+├── design_opt/ # 核心设计优化代码
+│ ├── agents/ # 训练智能体
+│ │ └── genesis_agent.py # BodyGenAgent 智能体实现
+│ ├── models/ # 神经网络模型
+│ │ ├── bodygen_policy.py # 策略网络(三阶段动作生成)
+│ │ ├── bodygen_critic.py # 价值函数网络
+│ │ └── transformer.py # Transformer 编码器
+│ ├── envs/ # 仿真环境
+│ │ ├── hopper.py # hopper 环境
+│ │ ├── swimmer.py # swimmer 环境
+│ │ ├── ant.py # ant 环境
+│ │ ├── gap.py # gap 地形环境
+│ │ └── walker.py # walker 环境
+│ ├── utils/ # 工具函数
+│ │ ├── config.py # 配置管理
+│ │ ├── logger.py # 日志记录
+│ │ └── tools.py # 辅助工具
+│ ├── cfg/ # 配置文件
+│ │ ├── walker-regular.yml
+│ │ ├── walker-medium.yml
+│ │ ├── walker-hard.yml
+│ │ ├── glider-*.yml
+│ │ └── ...
+│ ├── conf/ # Hydra 配置
+│ │ └── config.yaml
+│ ├── train.py # 训练入口
+│ └── eval.py # 评估入口
+├── khrylib/ # 基础强化学习库
+│ ├── rl/
+│ │ ├── agents/ # PPO 智能体基类
+│ │ ├── core/ # 核心组件
+│ │ │ ├── policy.py # 策略基类
+│ │ │ ├── critic.py # 价值函数基类
+│ │ │ ├── distributions.py # 动作分布
+│ │ │ └── running_norm.py # 归一化
+│ │ └── envs/ # MuJoCo 环境基类
+│ ├── models/ # 基础模型
+│ │ └── mlp.py # 多层感知机
+│ └── utils/ # 工具函数
+│ ├── torch.py # PyTorch 辅助
+│ ├── math.py # 数学工具
+│ └── mujoco.py # MuJoCo 辅助
+└── assets/ # 资源文件
+ ├── mujoco_envs/ # MuJoCo 机器人模型 XML
+ └── mujoco_terrains/ # 地形文件
+```
+
+### 核心类与功能
+
+**BodyGenAgent(genesis_agent.py)**
+- 继承自 AgentPPO,是整个训练流程的核心控制器
+- 负责环境创建、策略/价值网络初始化、优化器设置
+- 实现多线程轨迹采样、PPO 更新、日志记录等功能
+- 提供模型保存加载、策略可视化功能
+
+**BodyGenPolicy(bodygen_policy.py)**
+- 三阶段策略网络,分别处理骨架变换、属性变换和执行控制
+- 每个阶段包含独立的 Transformer 编码器和 MLP 头
+- 输出离散动作(骨架变换)或连续动作(属性/控制)
+
+**BodyGenValue(bodygen_critic.py)**
+- 价值函数网络,为每个设计阶段估计状态价值
+- 包含三个独立的 Transformer 编码器和价值头
+
+**TransformerSimple(transformer.py)**
+- 处理变长节点序列的 Transformer 编码器
+- 支持掩码注意力机制,处理批处理中的填充节点
+
+**环境类(hopper/swimmer/ant/gap/walker.py)**
+- 继承自 MuJoCo 环境,实现机器人仿真
+- 包含三阶段状态机管理、奖励计算、观察构建等功能
+
+## 三阶段设计流程
+
+BodyGen 采用三个连续的阶段来设计和控制机器人:
+
+### 1. 骨架变换阶段(Skeleton Transform Stage)
+
+该阶段负责机器人的拓扑结构变化,包括添加或移除身体部件。策略网络输出离散动作:0 表示保持不变,1 表示添加子部件,2 表示移除部件(如果允许)。此阶段持续 `skel_transform_nsteps` 步,每步可以修改一个节点。观察空间包含:固定属性(深度、关节范围、是否可添加/移除)、设计参数。
+
+### 2. 属性变换阶段(Attribute Transform Stage)
+
+在确定骨架结构后,该阶段优化身体部件的几何参数。策略网络输出连续动作,包括部件尺寸、关节位置、执行器参数等。动作被缩放后加到当前参数上,或直接作为绝对值(取决于 `abs_design` 设置)。此阶段仅持续 1 步。
+
+### 3. 执行阶段(Execution Stage)
+
+该阶段运行训练好的策略来控制机器人运动。策略输出执行器控制信号,奖励函数主要包含前进速度奖励和控制代价。环境在此阶段与标准强化学习环境类似,直到终止条件满足。
+
+## 配置参数详解
+
+### 全局训练参数(config.yaml)
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `num_threads` | 20 | 并行采样线程数,越大采样越快但显存占用越高 |
+| `gpu_index` | 0 | 使用的 GPU 编号 |
+| `seed` | 1 | 随机种子,确保实验可复现 |
+| `max_epoch_num` | 1000 | 训练总轮数 |
+| `min_batch_size` | 50000 | 每轮采样的最小轨迹步数 |
+| `mini_batch_size` | 2048 | PPO 更新时的小批量大小 |
+| `num_optim_epoch` | 10 | 每轮采样的策略更新轮数 |
+
+### PPO 算法参数
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `gamma` | 0.995 | 折扣因子,影响未来奖励的权重 |
+| `tau` | 0.95 | GAE 参数,影响优势函数估计的偏差-方差权衡 |
+| `clip_epsilon` | 0.2 | PPO 截断系数,控制策略更新幅度 |
+| `max_grad_norm` | 40 | 梯度裁剪阈值,防止梯度爆炸 |
+| `policy_lr` | 5e-5 | 策略网络学习率 |
+| `value_lr` | 3e-4 | 价值网络学习率 |
+| `lr_decay` | false | 是否启用学习率衰减 |
+| `norm_return` | true | 是否对回报进行归一化 |
+| `norm_advantage` | true | 是否对优势函数进行归一化 |
+
+### 策略网络配置(policy_specs)
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `attr_log_std` | -2.3 | 属性动作的对数标准差初始值 |
+| `fix_control_std` | false | 是否固定控制动作的标准差 |
+| `fix_attr_std` | false | 是否固定属性动作的标准差 |
+
+**Transformer 配置(skel_transformer_specs 等):**
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `hidden_dim` | 64 | Transformer 隐藏层维度 |
+| `block_depth` | 3 | Transformer 块的数量 |
+| `norm_type` | pre | 归一化类型,可选 "pre" 或 "post" |
+| `pos_emb_type` | index | 位置编码类型,可选 "index" 或 "travel" |
+
+### 价值网络配置(value_specs)
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `mlp` | [512, 256] | MLP 隐藏层维度列表 |
+| `htype` | tanh | MLP 激活函数类型 |
+
+### 环境配置(env_name.yml)
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `env_name` | walker | 环境名称,可选 hopper/swimmer/ant/gap/walker |
+| `gamma` | 0.995 | 环境特定的折扣因子 |
+| `skel_transform_nsteps` | 5 | 骨架变换阶段的步数 |
+| `robot_param_scale` | 1 | 动作参数缩放系数 |
+| `max_body_depth` | 4 | 允许的最大身体深度(添加层级数) |
+| `enable_remove` | true | 是否允许移除身体部件 |
+
+**奖励配置(reward_specs):**
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `ctrl_cost_coeff` | 1e-4 | 控制代价系数 |
+| `alive_bonus` | 0.0 | 存活奖励 |
+| `exec_reward_scale` | 1.0 | 执行阶段奖励缩放因子 |
+
+**终止条件(done_condition):**
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `min_height` | 0.7 | 最小高度,低于此值终止 |
+| `max_height` | 2.0 | 最大高度,高于此值终止 |
+| `max_ang` | 60 | 最大角度(度),超过此值终止 |
+| `max_nsteps` | 1000 | 最大步数限制 |
+
+**观察空间配置(obs_specs):**
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `clip_qvel` | true | 是否裁剪速度观测 |
+| `use_projected_params` | true | 是否使用投影后的参数 |
+| `abs_design` | false | 是否使用绝对参数 |
+| `use_body_ind` | true | 是否包含身体索引 |
+| `use_body_depth_height` | true | 是否包含深度和高度信息 |
+| `use_shortest_distance` | true | 是否包含最短距离信息 |
+| `use_position_encoding` | true | 是否包含位置编码 |
+| `fc_graph` | false | 是否使用全连接图边 |
+
+**添加身体条件(add_body_condition):**
+
+| 参数名 | 默认值 | 说明 |
+|--------|--------|------|
+| `max_nchild` | 1 | 每个节点最大子节点数 |
+| `min_nchild` | 0 | 每个节点最小子节点数 |
+
+## 快速上手指南
+
+### 环境安装
+
+```bash
+# 创建 conda 环境
+conda create -n bodygen python=3.8
+conda activate bodygen
+
+# 安装依赖
+pip install torch mujoco mujoco-py gymnasium
+pip install hydra-core omegaconf wandb tensorboard
+pip install numpy pyyaml
+```
+
+### 训练新模型
+
+1. 复制配置文件进行修改:
+
+```bash
+cp design_opt/cfg/walker-regular.yml design_opt/cfg/my_config.yml
+```
+
+2. 修改配置文件中的参数(参考上文参数说明)
+
+3. 运行训练:
+
+```bash
+python -m design_opt.train
+```
+
+或者通过命令行覆盖配置:
+
+```bash
+python -m design_opt.train num_threads=10 max_epoch_num=500 policy_lr=1e-4
+```
+
+### 评估与可视化
+
+```bash
+# 评估训练好的模型
+python -m design_opt.eval --train_dir results/walker-regular/0 --epoch 100
+
+# 可视化策略(带视频录制)
+python -m design_opt.eval --train_dir results/walker-regular/0 --save_video
+
+# 暂停设计阶段,仅展示控制效果
+python -m design_opt.eval --train_dir results/walker-regular/0 --save_video --pause_design
+```
+
+### 使用不同环境
+
+修改配置文件中的 `env_name`:
+
+```yaml
+# hopper 环境
+env_name: hopper
+
+# swimmer 环境
+env_name: swimmer
+
+# ant 环境
+env_name: ant
+
+# 穿越 gap 地形
+env_name: gap
+```
+
+## 常用修改示例
+
+### 1. 调整 Transformer 规模
+
+```yaml
+# 更大的模型
+policy_specs:
+ skel_transformer_specs:
+ hidden_dim: 128
+ block_depth: 5
+ attr_transformer_specs:
+ hidden_dim: 128
+ block_depth: 5
+ control_transformer_specs:
+ hidden_dim: 128
+ block_depth: 5
+```
+
+### 2. 调整学习率
+
+```yaml
+policy_lr: 1e-4
+value_lr: 5e-4
+```
+
+### 3. 调整采样和更新
+
+```yaml
+# 增大批量大小(更稳定但更慢)
+min_batch_size: 100000
+mini_batch_size: 4096
+
+# 减少更新轮数(更快但可能不稳定)
+num_optim_epoch: 5
+```
+
+### 4. 调整骨架变换行为
+
+```yaml
+# 增加骨架变换步数(更多添加/移除机会)
+skel_transform_nsteps: 10
+
+# 允许更复杂的结构
+add_body_condition:
+ max_nchild: 3
+max_body_depth: 5
+```
+
+### 5. 调整奖励函数
+
+```yaml
+reward_specs:
+ ctrl_cost_coeff: 1e-3 # 增加控制代价(更平滑的动作)
+ alive_bonus: 1.0 # 添加存活奖励
+ exec_reward_scale: 0.1 # 缩放整体奖励
+```
+
+## 输出文件说明
+
+训练过程中会在 `results/<配置名>/<种子编号>/` 目录下生成以下文件(例如 `results/walker-regular/0/`):
+
+| 文件/目录 | 说明 |
+|-----------|------|
+| `models/epoch_*.p` | 每隔 `save_model_interval` 保存的检查点 |
+| `models/best.p` | 最佳模型检查点 |
+| `log/log_train.txt` | 训练日志 |
+| `tb/` | TensorBoard 日志(可用 `tensorboard --logdir tb` 查看) |
+
+## 扩展指南
+
+### 添加新环境
+
+1. 在 `design_opt/envs/` 下创建新环境文件,继承 `MujocoEnv`
+
+2. 实现核心方法:
+ - `__init__`:初始化环境参数
+ - `step`:执行动作,返回 (观察, 奖励, 终止, 截断, 信息)
+ - `reset_model`:重置环境状态
+ - `get_sim_obs`:构建仿真观测
+ - `get_attr_fixed`:构建固定属性观测
+ - `get_attr_design`:构建设计参数观测
+
+3. 在 `design_opt/envs/__init__.py` 中注册环境
+
+4. 在 `design_opt/envs/__init__.py` 的 `env_dict` 中添加环境映射
+
+### 添加新的观察特征
+
+在环境类的 `_get_obs` 方法中修改:
+
+```python
+def _get_obs(self):
+    all_obs = []
+ # 添加新观察
+ if 'new_feature' in self.cfg.obs_specs:
+ new_obs = self.compute_new_feature()
+ all_obs.append(new_obs)
+ # ... 其他代码
+ return all_obs
+```
+
+### 修改网络结构
+
+修改 `transformer.py` 中的 `TransformerBlock` 或 `TransformerSimple` 类,可自定义注意力机制、位置编码、归一化方式等。
+
+## 常见问题
+
+### 1. 显存不足
+
+- 减小 `num_threads`(减少并行采样数)
+- 减小 `mini_batch_size`
+- 减小 Transformer 的 `hidden_dim`
+
+### 2. 训练不稳定
+
+- 减小学习率 `policy_lr` / `value_lr`
+- 增大 `clip_epsilon`
+- 启用 `norm_return` 和 `norm_advantage`
+
+### 3. 生成的机器人结构过于简单
+
+- 增加 `skel_transform_nsteps`
+- 增大 `max_body_depth`
+- 增大 `add_body_condition.max_nchild`
+
+### 4. 评估奖励与训练奖励差异大
+
+- 使用 `--epoch best` 加载最佳检查点
+- 检查 `norm_return` 是否启用
+- 增加 `eval_batch_size` 以获得更稳定的评估结果
diff --git a/khrylib/rl/agents/agent_ppo.py b/khrylib/rl/agents/agent_ppo.py
index a3575c0..4d36531 100644
--- a/khrylib/rl/agents/agent_ppo.py
+++ b/khrylib/rl/agents/agent_ppo.py
@@ -4,6 +4,11 @@
class AgentPPO(AgentPG):
+ """Proximal Policy Optimization (PPO) Agent implementation.
+
+ PPO is a policy gradient method that optimizes a clipped surrogate objective
+ to balance exploration and exploitation while ensuring stable updates.
+ """
def __init__(self, clip_epsilon=0.2, mini_batch_size=64, use_mini_batch=False,
policy_grad_clip=None, **kwargs):
@@ -12,54 +17,3 @@ def __init__(self, clip_epsilon=0.2, mini_batch_size=64, use_mini_batch=False,
self.mini_batch_size = mini_batch_size
self.use_mini_batch = use_mini_batch
self.policy_grad_clip = policy_grad_clip
-
- # def update_policy(self, states, actions, returns, advantages, exps):
- # """update policy"""
- # with to_test(*self.update_modules):
- # with torch.no_grad():
- # fixed_log_probs = self.policy_net.get_log_prob(self.trans_policy(states), actions)
-
- # for _ in range(self.opt_num_epochs):
- # if self.use_mini_batch:
- # perm = np.arange(states.shape[0])
- # np.random.shuffle(perm)
- # perm = LongTensor(perm).to(self.device)
-
- # states, actions, returns, advantages, fixed_log_probs, exps = \
- # states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
- # fixed_log_probs[perm].clone(), exps[perm].clone()
-
- # optim_iter_num = int(math.floor(states.shape[0] / self.mini_batch_size))
- # for i in range(optim_iter_num):
- # ind = slice(i * self.mini_batch_size, min((i + 1) * self.mini_batch_size, states.shape[0]))
- # states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b, exps_b = \
- # states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind], exps[ind]
- # ind = exps_b.nonzero(as_tuple=False).squeeze(1)
- # self.update_value(states_b, returns_b)
- # surr_loss = self.ppo_loss(states_b, actions_b, advantages_b, fixed_log_probs_b, ind)
- # self.optimizer_policy.zero_grad()
- # surr_loss.backward()
- # self.clip_policy_grad()
- # self.optimizer_policy.step()
- # else:
- # ind = exps.nonzero(as_tuple=False).squeeze(1)
- # self.update_value(states, returns)
- # surr_loss = self.ppo_loss(states, actions, advantages, fixed_log_probs, ind)
- # self.optimizer_policy.zero_grad()
- # surr_loss.backward()
- # self.clip_policy_grad()
- # self.optimizer_policy.step()
-
- # def clip_policy_grad(self):
- # if self.policy_grad_clip is not None:
- # for params, max_norm in self.policy_grad_clip:
- # torch.nn.utils.clip_grad_norm_(params, max_norm)
-
- # def ppo_loss(self, states, actions, advantages, fixed_log_probs, ind):
- # log_probs = self.policy_net.get_log_prob(self.trans_policy(states)[ind], actions[ind])
- # ratio = torch.exp(log_probs - fixed_log_probs[ind])
- # advantages = advantages[ind]
- # surr1 = ratio * advantages
- # surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * advantages
- # surr_loss = -torch.min(surr1, surr2).mean()
- # return surr_loss
diff --git a/khrylib/rl/envs/common/mjviewer.py b/khrylib/rl/envs/common/mjviewer.py
index 3a78aee..8d36957 100644
--- a/khrylib/rl/envs/common/mjviewer.py
+++ b/khrylib/rl/envs/common/mjviewer.py
@@ -1,159 +1,76 @@
-from threading import Lock
+"""
+Viewer for MuJoCo 3.x
+Uses the official mujoco.viewer module with custom extensions.
+"""
+import mujoco
import glfw
-from mujoco_py.builder import cymj
-from mujoco_py.generated import const
+from threading import Lock
import time
import copy
-from multiprocessing import Process, Queue
-from mujoco_py.utils import rec_copy, rec_assign
import numpy as np
import imageio
-class MjViewerBasic(cymj.MjRenderContextWindow):
+class MjViewerBasic:
"""
- A simple display GUI showing the scene of an :class:`.MjSim` with a mouse-movable camera.
- :class:`.MjViewer` extends this class to provide more sophisticated playback and interaction controls.
- Parameters
- ----------
- sim : :class:`.MjSim`
- The simulator to display.
+ A simple display GUI showing the scene of a simulation with a mouse-movable camera.
+ Uses MuJoCo 3.x native viewer functionality.
"""
- def __init__(self, sim):
- super().__init__(sim)
-
+ def __init__(self, model, data):
+ self.model = model
+ self.data = data
self._gui_lock = Lock()
self._button_left_pressed = False
self._button_right_pressed = False
self._last_mouse_x = 0
self._last_mouse_y = 0
- framebuffer_width, _ = glfw.get_framebuffer_size(self.window)
- window_width, _ = glfw.get_window_size(self.window)
- self._scale = framebuffer_width * 1.0 / window_width
-
- glfw.set_cursor_pos_callback(self.window, self._cursor_pos_callback)
- glfw.set_mouse_button_callback(
- self.window, self._mouse_button_callback)
- glfw.set_scroll_callback(self.window, self._scroll_callback)
- glfw.set_key_callback(self.window, self.key_callback)
+ # Create offscreen renderer for rendering
+ self._renderer = None
+ self._run_speed = 1.0
+ self._paused = False
def render(self):
- """
- Render the current simulation state to the screen or off-screen buffer.
- Call this in your main loop.
- """
- if self.window is None:
- return
- elif glfw.window_should_close(self.window):
- exit(0)
-
- with self._gui_lock:
- super().render()
-
- glfw.poll_events()
-
- def key_callback(self, window, key, scancode, action, mods):
- if action == glfw.RELEASE and key == glfw.KEY_ESCAPE:
- print("Pressed ESC")
- print("Quitting.")
- exit(0)
-
- def _cursor_pos_callback(self, window, xpos, ypos):
- if not (self._button_left_pressed or self._button_right_pressed):
+ """Render the current simulation state."""
+ if self._renderer is None:
return
-
- # Determine whether to move, zoom or rotate view
- mod_shift = (
- glfw.get_key(window, glfw.KEY_LEFT_SHIFT) == glfw.PRESS or
- glfw.get_key(window, glfw.KEY_RIGHT_SHIFT) == glfw.PRESS)
- if self._button_right_pressed:
- action = const.MOUSE_MOVE_H if mod_shift else const.MOUSE_MOVE_V
- elif self._button_left_pressed:
- action = const.MOUSE_ROTATE_H if mod_shift else const.MOUSE_ROTATE_V
- else:
- action = const.MOUSE_ZOOM
-
- # Determine
- dx = int(self._scale * xpos) - self._last_mouse_x
- dy = int(self._scale * ypos) - self._last_mouse_y
- width, height = glfw.get_framebuffer_size(window)
-
with self._gui_lock:
- self.move_camera(action, dx / height, dy / height)
-
- self._last_mouse_x = int(self._scale * xpos)
- self._last_mouse_y = int(self._scale * ypos)
-
- def _mouse_button_callback(self, window, button, act, mods):
- self._button_left_pressed = (
- glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_LEFT) == glfw.PRESS)
- self._button_right_pressed = (
- glfw.get_mouse_button(window, glfw.MOUSE_BUTTON_RIGHT) == glfw.PRESS)
-
- x, y = glfw.get_cursor_pos(window)
- self._last_mouse_x = int(self._scale * x)
- self._last_mouse_y = int(self._scale * y)
-
- def _scroll_callback(self, window, x_offset, y_offset):
- with self._gui_lock:
- self.move_camera(const.MOUSE_ZOOM, 0, -0.05 * y_offset)
+ self._renderer.update_scene(self.data)
+ self._renderer.render()
class MjViewer(MjViewerBasic):
"""
- Extends :class:`.MjViewerBasic` to add video recording, interactive time and interaction controls.
- The key bindings are as follows:
- - TAB: Switch between MuJoCo cameras.
- - H: Toggle hiding all GUI components.
- - SPACE: Pause/unpause the simulation.
- - RIGHT: Advance simulation by one step.
- - V: Start/stop video recording.
- - T: Capture screenshot.
- - I: Drop into ``ipdb`` debugger.
- - S/F: Decrease/Increase simulation playback speed.
- - C: Toggle visualization of contact forces (off by default).
- - D: Enable/disable frame skipping when rendering lags behind real time.
- - R: Toggle transparency of geoms.
- - M: Toggle display of mocap bodies.
- - 0-4: Toggle display of geomgroups
- Parameters
- ----------
- sim : :class:`.MjSim`
- The simulator to display.
+ Extends MjViewerBasic to add video recording and interactive controls.
+ Uses mujoco.viewer.launch_passive for the main window.
"""
- def __init__(self, sim):
- super().__init__(sim)
+ def __init__(self, model, data, use_passive=True):
+ super().__init__(model, data)
- self._ncam = sim.model.ncam
- self._paused = False # is viewer paused.
- # should we advance viewer just by one step.
+ self._paused = False
self._advance_by_one_step = False
- # Vars for recording video
+ # Video recording vars
self._record_video = False
- self._video_queue = Queue()
+ self._video_queue = None
self._video_idx = 0
self._video_path = "/tmp/video_%07d.mp4"
- # vars for capturing screen
+ # Screen capture vars
self._image_idx = 0
self._image_path = "/tmp/frame_%07d.png"
- # run_speed = x1, means running real time, x2 means fast-forward times
- # two.
self._run_speed = 1.0
self._loop_count = 0
self._render_every_frame = False
- self._show_mocap = True # Show / hide mocap bodies.
- self._transparent = False # Make everything transparent.
+ self._show_mocap = True
+ self._transparent = False
- # this variable is estamated as a running average.
self._time_per_render = 1 / 60.0
- self._hide_overlay = False # hide the entire overlay.
+ self._hide_overlay = False
self._user_overlay = {}
self.video_fps = 30
@@ -161,220 +78,151 @@ def __init__(self, sim):
self.sim_time = 0
self.custom_key_callback = None
- def render(self):
- """
- Render the current simulation state to the screen or off-screen buffer.
- Call this in your main loop.
- """
-
- def render_inner_loop(self):
- render_start = time.time()
-
- self._overlay.clear()
- if not self._hide_overlay:
- for k, v in self._user_overlay.items():
- self._overlay[k] = copy.deepcopy(v)
- self._create_full_overlay()
-
- super().render()
- if self._record_video:
- frame = self._read_pixels_as_in_window()
- self._video_queue.put(frame)
- else:
- self._time_per_render = 0.9 * self._time_per_render + \
- 0.1 * (time.time() - render_start)
-
- self._user_overlay = copy.deepcopy(self._overlay)
- # Render the same frame if paused.
- if self._paused:
- while self._paused:
- render_inner_loop(self)
- if self._advance_by_one_step:
- self._advance_by_one_step = False
- break
- else:
- # inner_loop runs "_loop_count" times in expectation (where "_loop_count" is a float).
- # Therefore, frames are displayed in the real-time.
- self._loop_count += (self.sim.model.opt.timestep * self.frame_skip - self.sim_time) / \
- (self._time_per_render * self._run_speed)
- if self._render_every_frame:
- self._loop_count = 1
- while self._loop_count > 0:
- render_inner_loop(self)
- self._loop_count -= 1
- # Markers and overlay are regenerated in every pass.
- self._markers[:] = []
- self._overlay.clear()
-
- def _read_pixels_as_in_window(self):
- # Reads pixels with markers and overlay from the same camera as screen.
- resolution = glfw.get_framebuffer_size(
- self.sim._render_context_window.window)
-
- resolution = np.array(resolution)
- resolution = resolution * min(1000 / np.min(resolution), 1)
- resolution = resolution.astype(np.int32)
- resolution -= resolution % 16
- if self.sim._render_context_offscreen is None:
- self.sim.render(resolution[0], resolution[1])
- offscreen_ctx = self.sim._render_context_offscreen
- window_ctx = self.sim._render_context_window
- # Save markers and overlay from offscreen.
- saved = [copy.deepcopy(offscreen_ctx._markers),
- copy.deepcopy(offscreen_ctx._overlay),
- rec_copy(offscreen_ctx.cam)]
- # Copy markers and overlay from window.
- offscreen_ctx._markers[:] = window_ctx._markers[:]
- offscreen_ctx._overlay.clear()
- offscreen_ctx._overlay.update(window_ctx._overlay)
- rec_assign(offscreen_ctx.cam, rec_copy(window_ctx.cam))
-
- img = self.sim.render(*resolution)
- img = img[::-1, :, :] # Rendered images are upside-down.
- # Restore markers and overlay to offscreen.
- offscreen_ctx._markers[:] = saved[0][:]
- offscreen_ctx._overlay.clear()
- offscreen_ctx._overlay.update(saved[1])
- rec_assign(offscreen_ctx.cam, saved[2])
- return img
-
- def _create_full_overlay(self):
- if self._render_every_frame:
- self.add_overlay(const.GRID_TOPLEFT, "", "")
- else:
- self.add_overlay(const.GRID_TOPLEFT, "Run speed = %.3f x real time" %
- self._run_speed, "[S]lower, [F]aster")
- self.add_overlay(
- const.GRID_TOPLEFT, "Ren[d]er every frame", "Off" if self._render_every_frame else "On")
- self.add_overlay(const.GRID_TOPLEFT, "Switch camera (#cams = %d)" % (self._ncam + 1),
- "[Tab] (camera ID = %d)" % self.cam.fixedcamid)
- self.add_overlay(const.GRID_TOPLEFT, "[C]ontact forces", "Off" if self.vopt.flags[
- 10] == 1 else "On")
- self.add_overlay(
- const.GRID_TOPLEFT, "Referenc[e] frames", "Off" if self.vopt.frame == 1 else "On")
- self.add_overlay(const.GRID_TOPLEFT,
- "T[r]ansparent", "On" if self._transparent else "Off")
- self.add_overlay(
- const.GRID_TOPLEFT, "Display [M]ocap bodies", "On" if self._show_mocap else "Off")
- if self._paused is not None:
+ # Use MuJoCo 3.x built-in viewer
+ self._passive_viewer = None
+ self._offscreen_renderer = None
+ self._use_passive = use_passive
+
+ # For screenshot compatibility
+ self.window = None
+
+ def _init_passive_viewer(self):
+ """Initialize the interactive MuJoCo 3.x viewer"""
+ if self._passive_viewer is None:
+ self._passive_viewer = mujoco.viewer.launch_passive(
+ self.model,
+ self.data,
+ show_left_ui=True,
+ show_right_ui=True
+ )
+ # Get window handle from the viewer
+ self.window = getattr(self._passive_viewer, 'window', None)
+
+ def _init_offscreen_renderer(self):
+ """Initialize offscreen renderer"""
+ if self._offscreen_renderer is None:
+ self._offscreen_renderer = mujoco.Renderer(self.model, 800, 600)
+
+ def render(self, width=None, height=None):
+ """Render the current simulation state."""
+ if self._use_passive:
+ # Use interactive viewer
+ self._init_passive_viewer()
if not self._paused:
- self.add_overlay(const.GRID_TOPLEFT, "Stop", "[Space]")
- else:
- self.add_overlay(const.GRID_TOPLEFT, "Start", "[Space]")
- self.add_overlay(const.GRID_TOPLEFT,
- "Advance simulation by one step", "[right arrow]")
- self.add_overlay(const.GRID_TOPLEFT, "[H]ide Menu", "")
- if self._record_video:
- ndots = int(7 * (time.time() % 1))
- dots = ("." * ndots) + (" " * (6 - ndots))
- self.add_overlay(const.GRID_TOPLEFT,
- "Record [V]ideo (On) " + dots, "")
- else:
- self.add_overlay(const.GRID_TOPLEFT, "Record [V]ideo (Off) ", "")
- if self._video_idx > 0:
- fname = self._video_path % (self._video_idx - 1)
- self.add_overlay(const.GRID_TOPLEFT, " saved as %s" % fname, "")
-
- self.add_overlay(const.GRID_TOPLEFT, "Cap[t]ure frame", "")
- if self._image_idx > 0:
- fname = self._image_path % (self._image_idx - 1)
- self.add_overlay(const.GRID_TOPLEFT, " saved as %s" % fname, "")
- self.add_overlay(const.GRID_TOPLEFT, "Start [i]pdb", "")
- if self._record_video:
- extra = " (while video is not recorded)"
+ self._passive_viewer.render()
else:
- extra = ""
- self.add_overlay(const.GRID_BOTTOMLEFT, "FPS", "%d%s" %
- (1 / self._time_per_render, extra))
- self.add_overlay(const.GRID_BOTTOMLEFT, "Solver iterations", str(
- self.sim.data.solver_iter + 1))
- step = round(self.sim.data.time / self.sim.model.opt.timestep)
- self.add_overlay(const.GRID_BOTTOMRIGHT, "Step", str(step))
- self.add_overlay(const.GRID_TOPLEFT, "Toggle geomgroup visibility", "0-4")
-
- def key_callback(self, window, key, scancode, action, mods):
- if self.custom_key_callback is not None:
- res = self.custom_key_callback(key, action, mods)
- if res:
+ # Use offscreen rendering
+ self._init_offscreen_renderer()
+
+ # For rendering to array (offscreen)
+ if width is not None and height is not None:
+ self._offscreen_renderer.update_scene(self.data)
+ self._offscreen_renderer.render(width, height)
return
- if action != glfw.RELEASE:
- return
- elif key == glfw.KEY_TAB: # Switches cameras.
- self.cam.fixedcamid += 1
- self.cam.type = const.CAMERA_FIXED
- if self.cam.fixedcamid >= self._ncam:
- self.cam.fixedcamid = -1
- self.cam.type = const.CAMERA_FREE
- elif key == glfw.KEY_H: # hides all overlay.
- self._hide_overlay = not self._hide_overlay
- elif key == glfw.KEY_SPACE and self._paused is not None: # stops simulation.
- self._paused = not self._paused
- # Advances simulation by one step.
- elif key == glfw.KEY_RIGHT and self._paused is not None:
- self._advance_by_one_step = True
- self._paused = True
- elif key == glfw.KEY_V or \
- (key == glfw.KEY_ESCAPE and self._record_video): # Records video. Trigers with V or if in progress by ESC.
- self._record_video = not self._record_video
- if self._record_video:
- fps = self.video_fps # (1 / self._time_per_render)
- self._video_process = Process(target=save_video,
- args=(self._video_queue, self._video_path % self._video_idx, fps))
- self._video_process.start()
- if not self._record_video:
- self._video_queue.put(None)
- self._video_process.join()
- self._video_idx += 1
- elif key == glfw.KEY_T: # capture screenshot
- img = self._read_pixels_as_in_window()
- imageio.imwrite(self._image_path % self._image_idx, img)
- self._image_idx += 1
- elif key == glfw.KEY_I: # drops in debugger.
- print('You can access the simulator by self.sim')
- import ipdb
- ipdb.set_trace()
- elif key == glfw.KEY_S: # Slows down simulation.
- self._run_speed /= 2.0
- elif key == glfw.KEY_F: # Speeds up simulation.
- self._run_speed *= 2.0
- elif key == glfw.KEY_C: # Displays contact forces.
- vopt = self.vopt
- vopt.flags[10] = vopt.flags[11] = not vopt.flags[10]
- elif key == glfw.KEY_D: # turn off / turn on rendering every frame.
- self._render_every_frame = not self._render_every_frame
- elif key == glfw.KEY_E:
- vopt = self.vopt
- vopt.frame = 1 - vopt.frame
- elif key == glfw.KEY_R: # makes everything little bit transparent.
- self._transparent = not self._transparent
- if self._transparent:
- self.sim.model.geom_rgba[:, 3] /= 5.0
- else:
- self.sim.model.geom_rgba[:, 3] *= 5.0
- elif key == glfw.KEY_M: # Shows / hides mocap bodies
- self._show_mocap = not self._show_mocap
- for body_idx1, val in enumerate(self.sim.model.body_mocapid):
- if val != -1:
- for geom_idx, body_idx2 in enumerate(self.sim.model.geom_bodyid):
- if body_idx1 == body_idx2:
- if not self._show_mocap:
- # Store transparency for later to show it.
- self.sim.extras[
- geom_idx] = self.sim.model.geom_rgba[geom_idx, 3]
- self.sim.model.geom_rgba[geom_idx, 3] = 0
- else:
- self.sim.model.geom_rgba[
- geom_idx, 3] = self.sim.extras[geom_idx]
- elif key in (glfw.KEY_0, glfw.KEY_1, glfw.KEY_2, glfw.KEY_3, glfw.KEY_4):
- self.vopt.geomgroup[key - glfw.KEY_0] ^= 1
- super().key_callback(window, key, scancode, action, mods)
-
-# Separate Process to save video. This way visualization is
-# less slowed down.
+ self._offscreen_renderer.update_scene(self.data)
+ self._offscreen_renderer.render()
+
+ def read_pixels(self, width, height, depth=False):
+ """Read pixels from the rendering context."""
+ self._init_offscreen_renderer()
+ self._offscreen_renderer.update_scene(self.data)
+ self._offscreen_renderer.render(width, height)
+ return self._offscreen_renderer.read_pixels(width, height, depth=depth)
+
+ def setup_camera(self, camera_id, width, height):
+ """Setup camera for rendering."""
+ self._init_offscreen_renderer()
+ self._offscreen_renderer.update_scene(self.data)
+
+ @property
+ def cam(self):
+ """Return camera object for manipulation."""
+ if self._use_passive and self._passive_viewer is not None:
+ return self._passive_viewer.cam
+ elif self._offscreen_renderer is not None:
+ return self._offscreen_renderer.scn.camera
+ return None
+
+ def is_running(self):
+ """Check if viewer is still running."""
+ return self._passive_viewer is not None and self._passive_viewer.is_running()
+
+ def close(self):
+ """Close the viewer."""
+ if self._passive_viewer is not None:
+ self._passive_viewer.close()
+ self._passive_viewer = None
+ self.window = None
+
+
+class MjViewerPassive:
+ """
+ Interactive viewer using MuJoCo 3.x's passive viewer functionality.
+ This provides a full-featured window with keyboard/mouse controls.
+ """
+
+ def __init__(self, model, data):
+ self.model = model
+ self.data = data
+ self.viewer = None
+ self._loop_count = 0
+ self._paused = False
+ self._run_speed = 1.0
+ self._time_per_render = 1 / 60.0
+ self.video_fps = 30
+ self.frame_skip = 1
+ self.custom_key_callback = None
+
+ def render(self):
+ """Render frame (called in main loop)."""
+ if self.viewer is not None:
+ self.viewer.render()
+
+ def setup(self):
+ """Launch the interactive viewer."""
+ self.viewer = mujoco.viewer.launch_passive(
+ self.model,
+ self.data,
+ show_left_ui=True,
+ show_right_ui=True
+ )
+ return self.viewer
+
+ def is_running(self):
+ """Check if viewer is still running."""
+ return self.viewer is not None and self.viewer.is_running()
+
+ def close(self):
+ """Close the viewer."""
+ if self.viewer is not None:
+ self.viewer.close()
+ self.viewer = None
+
+
+def launch_passive(model, data, key_callback=None):
+ """
+ Launch a passive MuJoCo viewer with custom key handling.
+
+ Args:
+ model: MuJoCo model
+ data: MuJoCo data
+ key_callback: Optional callback for key events
+
+ Returns:
+ mujoco.viewer.Handle object
+ """
+ viewer = mujoco.viewer.launch_passive(model, data)
+
+ if key_callback is not None:
+ # Install custom key callback
+ pass
+
+ return viewer
def save_video(queue, filename, fps):
+ """Save video from queue to file."""
writer = imageio.get_writer(filename, fps=fps)
while True:
frame = queue.get()
diff --git a/khrylib/rl/envs/common/mujoco_env.py b/khrylib/rl/envs/common/mujoco_env.py
index a2e3412..fdb1d49 100644
--- a/khrylib/rl/envs/common/mujoco_env.py
+++ b/khrylib/rl/envs/common/mujoco_env.py
@@ -1,9 +1,9 @@
-from gym import spaces
-from gym.utils import seeding
+from gymnasium import spaces
+from gymnasium.utils import seeding
import numpy as np
from os import path
from pathlib import Path
-import mujoco_py
+import mujoco
from khrylib.rl.envs.common.mjviewer import MjViewer
DEFAULT_SIZE = 500
@@ -11,6 +11,7 @@
class MujocoEnv:
"""Superclass for all MuJoCo environments.
+ Adapted for MuJoCo 3.x
"""
def __init__(self, fullpath, frame_skip):
@@ -20,9 +21,8 @@ def __init__(self, fullpath, frame_skip):
if not path.exists(fullpath):
raise IOError("File %s does not exist" % fullpath)
self.frame_skip = frame_skip
- self.model = mujoco_py.load_model_from_path(fullpath)
- self.sim = mujoco_py.MjSim(self.model)
- self.data = self.sim.data
+ self.model = mujoco.MjModel.from_xml_path(fullpath)
+ self.data = mujoco.MjData(self.model)
self.viewer = None
self._viewers = {}
self.obs_dim = None
@@ -36,8 +36,8 @@ def __init__(self, fullpath, frame_skip):
'video.frames_per_second': int(np.round(1.0 / self.dt))
}
- self.init_qpos = self.sim.data.qpos.ravel().copy()
- self.init_qvel = self.sim.data.qvel.ravel().copy()
+ self.init_qpos = self.data.qpos.ravel().copy()
+ self.init_qvel = self.data.qvel.ravel().copy()
self.prev_qpos = None
self.prev_qvel = None
self.seed()
@@ -82,8 +82,8 @@ def viewer_setup(self, mode):
# -----------------------------
- def reset(self):
- self.sim.reset()
+ def reset(self, seed=None, options=None):
+ mujoco.mj_resetData(self.model, self.data)
self.cur_t = 0
ob = self.reset_model()
old_viewer = self.viewer
@@ -91,24 +91,28 @@ def reset(self):
self.viewer = v
self.viewer_setup(mode)
self.viewer = old_viewer
- return ob
+ # gymnasium reset returns (observation, info)
+ return ob, {}
def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
- old_state = self.sim.get_state()
- new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
- old_state.act, old_state.udd_state)
- self.sim.set_state(new_state)
- self.sim.forward()
+ self.data.qpos[:] = qpos
+ self.data.qvel[:] = qvel
+        # data.time left unchanged: mujoco_py's set_state preserved old_state.time
+ self.forward()
+
+ def forward(self):
+        """Compute derived quantities without advancing time (mujoco_py's sim.forward)"""
+        mujoco.mj_forward(self.model, self.data)
    @property
    def dt(self):
        return self.model.opt.timestep * self.frame_skip
    def do_simulation(self, ctrl, n_frames):
-        self.sim.data.ctrl[:] = ctrl
+        self.data.ctrl[:] = ctrl
        for _ in range(n_frames):
-            self.sim.step()
+            mujoco.mj_step(self.model, self.data)
def render(self, mode='human', width=DEFAULT_SIZE, height=DEFAULT_SIZE):
if mode == 'image':
@@ -130,9 +134,12 @@ def _get_viewer(self, mode):
self.viewer = self._viewers.get(mode)
if self.viewer is None:
if mode == 'human':
- self.viewer = MjViewer(self.sim)
+ # For human mode, use passive viewer for full interactivity
+ self.viewer = MjViewer(self.model, self.data, use_passive=True)
+ # Initialize the viewer to get the window
+ self.viewer._init_passive_viewer()
elif mode == 'image':
- self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, 0)
+            self.viewer = mujoco.Renderer(self.model, DEFAULT_SIZE, DEFAULT_SIZE)
self._viewers[mode] = self.viewer
self.viewer_setup(mode)
return self.viewer
@@ -141,22 +148,24 @@ def set_custom_key_callback(self, key_func):
self._get_viewer('human').custom_key_callback = key_func
def get_body_com(self, body_name):
- return self.data.get_body_xpos(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ return self.data.body(body_id).xpos
def state_vector(self):
return np.concatenate([
- self.sim.data.qpos.flat,
- self.sim.data.qvel.flat
+ self.data.qpos.flat,
+ self.data.qvel.flat
])
def vec_body2world(self, body_name, vec):
- body_xmat = self.data.get_body_xmat(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ body_xmat = self.data.body(body_id).xmat.reshape(3, 3)
vec_world = (body_xmat @ vec[:, None]).ravel()
return vec_world
def pos_body2world(self, body_name, pos):
- body_xpos = self.data.get_body_xpos(body_name)
- body_xmat = self.data.get_body_xmat(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ body_xpos = self.data.body(body_id).xpos
+ body_xmat = self.data.body(body_id).xmat.reshape(3, 3)
pos_world = (body_xmat @ pos[:, None]).ravel() + body_xpos
return pos_world
-
diff --git a/khrylib/rl/envs/common/mujoco_env_gym.py b/khrylib/rl/envs/common/mujoco_env_gym.py
index 15a4c26..39e9429 100644
--- a/khrylib/rl/envs/common/mujoco_env_gym.py
+++ b/khrylib/rl/envs/common/mujoco_env_gym.py
@@ -1,21 +1,21 @@
from collections import OrderedDict
import os
-from gym import error, spaces
-from gym.utils import seeding
+from gymnasium import error, spaces
+from gymnasium.utils import seeding
import numpy as np
from os import path
-import gym
+from pathlib import Path
+import gymnasium as gym
try:
- import mujoco_py
+ import mujoco
except ImportError as e:
- raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))
+ raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco>=3.0.0)".format(e))
DEFAULT_SIZE = 500
-from khrylib.rl.envs.common.mjviewer import MjViewer
-
+from khrylib.rl.envs.common.mjviewer import MjViewer, MjViewerPassive
def convert_observation_to_space(observation):
@@ -39,22 +39,22 @@ def convert_observation_to_space(observation):
class MujocoEnv(gym.Env):
"""Superclass for all MuJoCo environments.
+ Adapted for MuJoCo 3.x
"""
def __init__(self, fullpath, frame_skip, mujoco_xml=None):
if mujoco_xml is not None:
- self.model = mujoco_py.load_model_from_xml(mujoco_xml)
+ self.model = mujoco.MjModel.from_xml_string(mujoco_xml)
else:
if not path.exists(fullpath):
# try the default assets path
fullpath = path.join(Path(__file__).parent.parent.parent.parent, 'assets/mujoco_models', path.basename(fullpath))
if not path.exists(fullpath):
raise IOError("File %s does not exist" % fullpath)
- self.model = mujoco_py.load_model_from_path(fullpath)
+ self.model = mujoco.MjModel.from_xml_path(fullpath)
self.frame_skip = frame_skip
- self.sim = mujoco_py.MjSim(self.model)
- self.data = self.sim.data
+ self.data = mujoco.MjData(self.model)
self.viewer = None
self._viewers = {}
@@ -63,8 +63,8 @@ def __init__(self, fullpath, frame_skip, mujoco_xml=None):
'video.frames_per_second': int(np.round(1.0 / self.dt))
}
- self.init_qpos = self.sim.data.qpos.ravel().copy()
- self.init_qvel = self.sim.data.qvel.ravel().copy()
+ self.init_qpos = self.data.qpos.ravel().copy()
+ self.init_qvel = self.data.qvel.ravel().copy()
self.is_inited = False
self._set_action_space()
@@ -94,16 +94,14 @@ def seed(self, seed=None):
return [seed]
def reload_sim_model(self, xml_str):
- del self.sim
del self.model
del self.data
del self.viewer
del self._viewers
- self.model = mujoco_py.load_model_from_xml(xml_str)
- self.sim = mujoco_py.MjSim(self.model)
- self.data = self.sim.data
- self.init_qpos = self.sim.data.qpos.copy()
- self.init_qvel = self.sim.data.qvel.copy()
+ self.model = mujoco.MjModel.from_xml_string(xml_str)
+ self.data = mujoco.MjData(self.model)
+ self.init_qpos = self.data.qpos.copy()
+ self.init_qvel = self.data.qvel.copy()
self.viewer = None
self._viewers = {}
@@ -127,27 +125,31 @@ def viewer_setup(self):
# -----------------------------
- def reset(self):
- self.sim.reset()
+ def reset(self, seed=None, options=None):
+ mujoco.mj_resetData(self.model, self.data)
ob = self.reset_model()
- return ob
+ # gymnasium reset returns (observation, info)
+ return ob, {}
def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
- old_state = self.sim.get_state()
- new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
- old_state.act, old_state.udd_state)
- self.sim.set_state(new_state)
- self.sim.forward()
+ self.data.qpos[:] = qpos
+ self.data.qvel[:] = qvel
+        # data.time left unchanged: mujoco_py's set_state preserved old_state.time
+ self.forward()
+
+ def forward(self):
+        """Compute derived quantities without advancing time (mujoco_py's sim.forward)"""
+        mujoco.mj_forward(self.model, self.data)
    @property
    def dt(self):
        return self.model.opt.timestep * self.frame_skip
    def do_simulation(self, ctrl, n_frames):
-        self.sim.data.ctrl[:] = ctrl
+        self.data.ctrl[:] = ctrl
        for _ in range(n_frames):
-            self.sim.step()
+            mujoco.mj_step(self.model, self.data)
def render(self,
mode='human',
@@ -164,20 +166,20 @@ def render(self,
if no_camera_specified:
camera_name = 'track'
- if camera_id is None and camera_name in self.model._camera_name2id:
- camera_id = self.model.camera_name2id(camera_name)
+        if camera_id is None:
+            # mj_name2id returns -1 (it does not raise) when the name is absent
+            camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, camera_name)
+            if camera_id == -1:
+                camera_id = None
self._get_viewer(mode).render(width, height, camera_id=camera_id)
if mode == 'rgb_array':
- # window size used for old mujoco-py:
data = self._get_viewer(mode).read_pixels(width, height, depth=False)
# original image is upside-down, so flip it
return data[::-1, :, :]
elif mode == 'depth_array':
self._get_viewer(mode).render(width, height)
- # window size used for old mujoco-py:
- # Extract depth part of the read_pixels() tuple
data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
# original image is upside-down, so flip it
return data[::-1, :]
@@ -186,7 +188,6 @@ def render(self,
def close(self):
if self.viewer is not None:
- # self.viewer.finish()
self.viewer = None
self._viewers = {}
@@ -194,9 +195,12 @@ def _get_viewer(self, mode):
self.viewer = self._viewers.get(mode)
if self.viewer is None:
if mode == 'human':
- self.viewer = MjViewer(self.sim)
+ # For human mode, use passive viewer for full interactivity
+ self.viewer = MjViewer(self.model, self.data, use_passive=True)
+ # Initialize the viewer to get the window
+ self.viewer._init_passive_viewer()
elif mode == 'rgb_array' or mode == 'depth_array':
- self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)
+            self.viewer = mujoco.Renderer(self.model, DEFAULT_SIZE, DEFAULT_SIZE)
self._viewers[mode] = self.viewer
self.viewer_setup()
@@ -206,22 +210,24 @@ def set_custom_key_callback(self, key_func):
self._get_viewer('human').custom_key_callback = key_func
def get_body_com(self, body_name):
- return self.data.get_body_xpos(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ return self.data.body(body_id).xpos
def state_vector(self):
return np.concatenate([
- self.sim.data.qpos.flat,
- self.sim.data.qvel.flat
+ self.data.qpos.flat,
+ self.data.qvel.flat
])
def vec_body2world(self, body_name, vec):
- body_xmat = self.data.get_body_xmat(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ body_xmat = self.data.body(body_id).xmat.reshape(3, 3)
vec_world = (body_xmat @ vec[:, None]).ravel()
return vec_world
def pos_body2world(self, body_name, pos):
- body_xpos = self.data.get_body_xpos(body_name)
- body_xmat = self.data.get_body_xmat(body_name)
+ body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, body_name)
+ body_xpos = self.data.body(body_id).xpos
+ body_xmat = self.data.body(body_id).xmat.reshape(3, 3)
pos_world = (body_xmat @ pos[:, None]).ravel() + body_xpos
return pos_world
-
diff --git a/khrylib/rl/envs/visual/humanoid_vis.py b/khrylib/rl/envs/visual/humanoid_vis.py
index b2f6b48..c8fc93c 100644
--- a/khrylib/rl/envs/visual/humanoid_vis.py
+++ b/khrylib/rl/envs/visual/humanoid_vis.py
@@ -20,7 +20,7 @@ def reset_model(self):
return None
def sim_forward(self):
- self.sim.forward()
+ self.forward()
def viewer_setup(self, mode):
self.viewer.cam.trackbodyid = 1
diff --git a/khrylib/robot/xml_robot.py b/khrylib/robot/xml_robot.py
index f221e62..28e04b7 100644
--- a/khrylib/robot/xml_robot.py
+++ b/khrylib/robot/xml_robot.py
@@ -1,5 +1,6 @@
import numpy as np
import math
+import warnings
from copy import deepcopy
from collections import defaultdict
from lxml.etree import XMLParser, parse, ElementTree, Element, SubElement
@@ -62,7 +63,6 @@ def __init__(self, node, body):
self.axis = vec_to_polar(parse_vec(node.attrib['axis']))
if self.local_coord:
self.pos += body.pos
- assert(np.all(self.pos == body.pos))
def __repr__(self):
return 'joint_' + self.name
@@ -520,7 +520,13 @@ def __init__(self, cfg, xml, is_xml_str=False):
def load_from_xml(self, xml, is_xml_str=False):
parser = XMLParser(remove_blank_text=True)
self.tree = parse(BytesIO(xml) if is_xml_str else xml, parser=parser)
- self.local_coord = self.tree.getroot().find('.//compiler').attrib['coordinate'] == 'local'
+ # MuJoCo 3.x removed coordinate="global", default is local coordinates
+ compiler = self.tree.getroot().find('.//compiler')
+ if compiler is not None and 'coordinate' in compiler.attrib:
+ self.local_coord = compiler.attrib['coordinate'] == 'local'
+ else:
+ # MuJoCo 3.x default is local coordinates
+ self.local_coord = True
root = self.tree.getroot().find('worldbody').find('body')
self.add_body(root, None)
@@ -583,7 +589,7 @@ def export_xml_string(self):
def demap_params(self, params):
if not np.all((params <= 1.0) & (params >= -1.0)):
- print(f'param out of bounds: {params}')
+ warnings.warn(f'param out of bounds: {params}', stacklevel=2)
params = np.clip(params, -1.0, 1.0)
if self.param_mapping == 'sin':
params = np.arcsin(params) / (0.5 * np.pi)
@@ -691,11 +697,10 @@ def get_laplacian_position_encoding(self, pos_enc_dim=4):
import time
import yaml
sys.path.append(os.getcwd())
- from mujoco_py import load_model_from_path, MjSim, MjViewer
+    import mujoco.viewer
model_name = 'ant'
cfg_path = f'khrylib/assets/ant.yml'
- # model = load_model_from_path(f'assets/mujoco_envs/{model_name}.xml')
yml = yaml.safe_load(open(cfg_path, 'r'))
cfg = yml['robot']
xml_robot = Robot(cfg, xml=f'assets/mujoco_envs/{model_name}.xml')
@@ -710,13 +715,15 @@ def get_laplacian_position_encoding(self, pos_enc_dim=4):
os.makedirs('out', exist_ok=True)
xml_robot.write_xml(f'out/{model_name}_test.xml')
- model = load_model_from_path(f'out/{model_name}_test.xml')
- sim = MjSim(model)
- viewer = MjViewer(sim)
+ model = mujoco.MjModel.from_xml_path(f'out/{model_name}_test.xml')
+ data = mujoco.MjData(model)
+
+ # Use MuJoCo 3.x viewer
+ viewer = mujoco.viewer.launch_passive(model, data)
viewer.cam.distance = 10
- while True:
- sim.data.qpos[2] = 1.0
- sim.data.qpos[7:] = np.pi / 6
- sim.forward()
+ while viewer.is_running():
+ data.qpos[2] = 1.0
+ data.qpos[7:] = np.pi / 6
+        mujoco.mj_forward(model, data)
viewer.render()
\ No newline at end of file
diff --git a/khrylib/utils/mujoco.py b/khrylib/utils/mujoco.py
index bdbb754..2492eb1 100644
--- a/khrylib/utils/mujoco.py
+++ b/khrylib/utils/mujoco.py
@@ -1,11 +1,13 @@
from khrylib.utils.math import *
+import mujoco
def get_single_body_qposaddr(model, body):
- i = model.body_names.index(body)
- start_joint = model.body_jntadr[i]
+ # MuJoCo 3.x uses mujoco.mj_name2id instead of body_names
+ body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, body)
+ start_joint = model.body_jntadr[body_id]
# assert start_joint >= 0
- end_joint = start_joint + model.body_jntnum[i]
+ end_joint = start_joint + model.body_jntnum[body_id]
start_qposaddr = model.jnt_qposadr[start_joint]
if end_joint < len(model.jnt_qposadr):
end_qposaddr = model.jnt_qposadr[end_joint]
diff --git a/requirements.txt b/requirements.txt
index 0c8bc42..3d3af0e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,12 @@
-torch==2.0.1
-gym==0.15.4
+torch
+gymnasium>=0.29.0
numpy<1.23.0
opencv-python
glfw
pyyaml
tensorboard
lxml
-cython<3
-mujoco-py<2.2,>=2.1
+mujoco>=3.4.0
torch-geometric==2.4.0
hydra-core==1.1.1
hydra-submitit-launcher==1.1.5