Description
Hello,
I have looked at the report you provided and am trying to reproduce its results with PPO. I succeed on an easy task (CheetahRun) but fail on HumanoidWalk (I expected a return of around 800 but am getting about 20). From what I understand, the hyperparameters are the following:
'num_timesteps': 60_000_000, 'num_envs': 2048, 'learning_rate': 1e-3, 'batch_size': 1024, 'discounting': 0.9995, 'entropy_cost': 1e-2, 'episode_length': 1000, 'num_minibatches': 32, 'num_updates_per_batch': 16, 'unroll_length': 30, 'reward_scaling': 10.0, 'action_repeat': 1, 'policy_hidden_sizes': '128,128,128,128', 'value_hidden_sizes': '128,128,128,128', 'activation': 'swish',
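For comparison, here is a minimal sketch of how I understand the reference hyperparameters can be pulled straight from mujoco_playground and forwarded to Brax's PPO trainer (assuming `dm_control_suite_params.brax_ppo_config` covers HumanoidWalk; my full script below sets the same fields by hand instead):

```python
import functools

from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from mujoco_playground import registry, wrapper
from mujoco_playground.config import dm_control_suite_params

env_name = "HumanoidWalk"
env = registry.load(env_name)

# Reference PPO hyperparameters shipped with the playground for this suite.
ppo_params = dict(dm_control_suite_params.brax_ppo_config(env_name))

# Network sizes/activation live in a nested "network_factory" entry.
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks, **ppo_params.pop("network_factory")
    )

make_inference_fn, params, _ = ppo.train(
    environment=env,
    network_factory=network_factory,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    **ppo_params,
)
```

If that helper reproduces the reported HumanoidWalk numbers, I can diff its config against mine to see which field I have wrong.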
The script I use for training is pasted below. I also tried increasing num_timesteps to 10 times the value above, but that did not perform well either.
`"""
Train a PPO agent on MuJoCo Playground environments using Brax's PPO implementation.
This script uses:
- brax.training.agents.ppo: JAX-based PPO implementation
- mujoco_playground: JAX-based MuJoCo environments (MJX)
Available environments include:
- DM Control Suite: CheetahRun, WalkerWalk, HumanoidWalk, HopperHop, etc.
- Locomotion: Go1JoystickFlatTerrain, SpotFlatTerrainJoystick, etc.
Usage:
python train_ppo_mjx.py --env HumanoidWalk
python train_ppo_mjx.py --env CheetahRun --num_timesteps 10000000
"""
import argparse
import os
import sys
# Parse --gpus argument BEFORE importing JAX to set CUDA_VISIBLE_DEVICES
def _get_gpu_arg():
for i, arg in enumerate(sys.argv):
if arg == '--gpus' and i + 1 < len(sys.argv):
return sys.argv[i + 1]
return None
_gpu_arg = _get_gpu_arg()
if _gpu_arg:
os.environ['CUDA_VISIBLE_DEVICES'] = _gpu_arg
print(f"Setting CUDA_VISIBLE_DEVICES={_gpu_arg}")
# Set JAX memory settings
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Set MuJoCo to use EGL for headless rendering (no display required)
os.environ["MUJOCO_GL"] = "egl"
# Now import JAX and other libraries
import functools
import time
import json
import pickle
import jax
import jax.numpy as jnp
import numpy as np
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground import registry
from mujoco_playground import wrapper
from mujoco_playground.config import dm_control_suite_params
import wandb
import matplotlib.pyplot as plt
import imageio
def parse_args():
parser = argparse.ArgumentParser(description='PPO Training on MuJoCo Playground')
parser.add_argument('--env', type=str, default='CheetahRun',
help='Environment name (e.g., HumanoidWalk, CheetahRun)')
parser.add_argument('--num_timesteps', type=int, default=10_000_000,
help='Total number of environment timesteps')
parser.add_argument('--seed', type=int, default=42,
help='Random seed')
parser.add_argument('--gpus', type=str, default=None,
help='Comma-separated GPU IDs to use (e.g., "0,3,4"). Default: all available')
# PPO hyperparameters
parser.add_argument('--num_envs', type=int, default=2048,
help='Number of parallel environments (default 2048)')
parser.add_argument('--num_eval_envs', type=int, default=128,
help='Number of evaluation environments (default 128)')
parser.add_argument('--episode_length', type=int, default=1000,
help='Max episode length (default 1000)')
parser.add_argument('--unroll_length', type=int, default=10,
help='Unroll length for PPO (default 10)')
parser.add_argument('--num_minibatches', type=int, default=32,
help='Number of minibatches (default 32)')
parser.add_argument('--num_updates_per_batch', type=int, default=8,
help='Number of PPO updates per batch (default 8)')
parser.add_argument('--learning_rate', type=float, default=3e-4,
help='Learning rate (default 3e-4)')
parser.add_argument('--entropy_cost', type=float, default=1e-2,
help='Entropy bonus coefficient (default 1e-2)')
parser.add_argument('--discounting', type=float, default=0.97,
help='Discount factor gamma (default 0.97)')
parser.add_argument('--reward_scaling', type=float, default=0.1,
help='Reward scaling factor (default 0.1)')
parser.add_argument('--clipping_epsilon', type=float, default=0.3,
help='PPO clipping epsilon (default 0.3)')
parser.add_argument('--gae_lambda', type=float, default=0.95,
help='GAE lambda (default 0.95)')
parser.add_argument('--normalize_observations', action='store_true', default=True,
help='Normalize observations (default True)')
parser.add_argument('--policy_hidden_sizes', type=str, default='32,32,32,32',
help='Policy network hidden layer sizes (default "32,32,32,32")')
parser.add_argument('--value_hidden_sizes', type=str, default='32,32,32,32',
help='Value network hidden layer sizes (default "32,32,32,32")')
parser.add_argument('--activation', type=str, default='tanh', choices=['tanh', 'swish'],
help='Activation function (default "tanh", brax default is "swish")')
parser.add_argument('--batch_size', type=int, default=256,
help='Batch size (default 256)')
parser.add_argument('--max_grad_norm', type=float, default=1.0,
help='Max gradient norm for clipping (default 1.0)')
parser.add_argument('--action_repeat', type=int, default=1,
help='Action repeat (default 1)')
# Eval and logging
parser.add_argument('--num_evals', type=int, default=10,
help='Number of evaluations during training (default 10)')
parser.add_argument('--num_eval_videos', type=int, default=5,
help='Number of evaluation videos to save (default 5)')
return parser.parse_args()
def final_evaluation(env, inference_fn, key, env_name,
num_eval_trials=100, max_steps=1000, save_dir="eval_results",
num_gifs=10):
"""Run final evaluation with GIF saving, JSON rewards, and histogram.
Evaluation is parallelized on GPU using vmap.
Only a subset of trials (num_gifs) have trajectories saved for GIF rendering.
"""
print("\n" + "=" * 60)
print(f"Running Final Evaluation ({num_eval_trials} trials)...")
print("=" * 60)
# Create save directory
os.makedirs(save_dir, exist_ok=True)
gifs_dir = os.path.join(save_dir, "gifs")
os.makedirs(gifs_dir, exist_ok=True)
jit_inference_fn = jax.jit(inference_fn)
# Define a single rollout function that returns reward only (GIFs done separately)
def rollout_episode(key):
"""Rollout a single episode, return total reward."""
key, reset_key = jax.random.split(key)
def step_fn(carry, _):
state, total_reward, done_flag, key = carry
key, act_key = jax.random.split(key)
# Get action from policy
action, _ = jit_inference_fn(state.obs, act_key)
# Step environment
next_state = env.step(state, action)
reward = next_state.reward
done = next_state.done
# Accumulate reward (only if not done)
total_reward = total_reward + reward * (1.0 - done_flag)
done_flag = jnp.maximum(done_flag, done)
return (next_state, total_reward, done_flag, key), None
# Reset and run episode
state = env.reset(reset_key)
(_, total_reward, _, _), _ = jax.lax.scan(
step_fn, (state, 0.0, 0.0, key), None, length=max_steps
)
return total_reward
# Vectorize over all trials
keys = jax.random.split(key, num_eval_trials)
# JIT compile the vectorized evaluation
@jax.jit
def eval_all(keys):
return jax.vmap(rollout_episode)(keys)
print(" Running parallel evaluation on GPU...")
rewards = eval_all(keys)
rewards = np.array(rewards)
print(f" Evaluation complete! Mean reward: {np.mean(rewards):.2f}")
# Save GIFs for a subset of trials (rendering is CPU-bound, so do sequential rollouts)
print(f" Saving {num_gifs} GIFs...")
gif_indices = np.linspace(0, num_eval_trials - 1, num_gifs, dtype=int)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
for i, trial_idx in enumerate(gif_indices):
reward_val = float(rewards[trial_idx])
# Do a fresh rollout for this GIF (env.render needs a list of states)
gif_key = jax.random.key(trial_idx)
state = jit_reset(gif_key)
rollout = [state]
for _ in range(max_steps):
gif_key, act_key = jax.random.split(gif_key)
action, _ = jit_inference_fn(state.obs, act_key)
state = jit_step(state, action)
rollout.append(state)
if state.done:
break
try:
# Render every 2nd frame for smaller GIFs
images = env.render(rollout[::2], height=240, width=320)
gif_path = os.path.join(gifs_dir, f"trial_{trial_idx:03d}_reward_{reward_val:.1f}.gif")
imageio.mimsave(gif_path, images, fps=30, loop=0)
print(f" GIF {i+1}/{num_gifs}: trial {trial_idx}, reward = {reward_val:.2f}")
except Exception as e:
print(f" GIF {i+1}/{num_gifs}: trial {trial_idx}, reward = {reward_val:.2f} (failed: {e})")
# Save rewards to JSON
rewards_list = rewards.tolist()
rewards_data = {
"env_name": env_name,
"num_trials": num_eval_trials,
"rewards": rewards_list,
"mean": float(np.mean(rewards)),
"std": float(np.std(rewards)),
"min": float(np.min(rewards)),
"max": float(np.max(rewards)),
}
json_path = os.path.join(save_dir, "eval_rewards.json")
with open(json_path, 'w') as f:
json.dump(rewards_data, f, indent=2)
print(f"\n Rewards saved to: {json_path}")
# Create histogram
plt.figure(figsize=(10, 6))
plt.hist(rewards, bins=20, edgecolor='black', alpha=0.7)
plt.axvline(np.mean(rewards), color='red', linestyle='--', linewidth=2,
label=f'Mean: {np.mean(rewards):.2f}')
plt.axvline(np.median(rewards), color='orange', linestyle='--', linewidth=2,
label=f'Median: {np.median(rewards):.2f}')
plt.xlabel('Reward', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.title(f'Final Evaluation Rewards - {env_name} (PPO)\n({num_eval_trials} trials)', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
hist_path = os.path.join(save_dir, "reward_histogram.png")
plt.savefig(hist_path, dpi=150, bbox_inches='tight')
plt.close()
print(f" Histogram saved to: {hist_path}")
# Log to wandb
wandb.log({
"eval/mean_reward": float(np.mean(rewards)),
"eval/std_reward": float(np.std(rewards)),
"eval/min_reward": float(np.min(rewards)),
"eval/max_reward": float(np.max(rewards)),
})
wandb.log({"eval/histogram": wandb.Image(hist_path)})
# Log sample GIFs to wandb
sample_gifs = sorted(os.listdir(gifs_dir))[:5]
for gif_name in sample_gifs:
gif_path = os.path.join(gifs_dir, gif_name)
try:
wandb.log({f"eval/gif_{gif_name}": wandb.Video(gif_path, fps=30, format="gif")})
except Exception:
pass
print(f"\n Final Evaluation Results:")
print(f" Mean reward: {np.mean(rewards):.2f}")
print(f" Std reward: {np.std(rewards):.2f}")
print(f" Min reward: {np.min(rewards):.2f}")
print(f" Max reward: {np.max(rewards):.2f}")
return rewards_list
def main():
args = parse_args()
# Environment name
env_name = args.env
# Environment-specific hyperparameter defaults (from successful runs)
env_specific_configs = {
'CheetahRun': {
'num_timesteps': 150_000_000,
'num_envs': 4096,
'learning_rate': 1e-4,
'batch_size': 256,
'discounting': 0.97,
'entropy_cost': 0.01,
'episode_length': 1000,
'num_minibatches': 32,
'num_updates_per_batch': 8,
'unroll_length': 10,
'reward_scaling': 0.1,
'action_repeat': 1,
},
'HumanoidWalk': {
'num_timesteps': 609_000_000,
'num_envs': 2048,
'learning_rate': 1e-3,
'batch_size': 1024,
'discounting': 0.9995,
'entropy_cost': 1e-2,
'episode_length': 1000,
'num_minibatches': 32,
'num_updates_per_batch': 16,
'unroll_length': 30,
'reward_scaling': 10.0,
'action_repeat': 1,
'policy_hidden_sizes': '128,128,128,128',
'value_hidden_sizes': '128,128,128,128',
'activation': 'swish',
},
}
# Default values from argparse - used to check if user explicitly set something
param_defaults = {
'num_timesteps': 10_000_000,
'num_envs': 2048,
'episode_length': 1000,
'unroll_length': 10,
'num_minibatches': 32,
'num_updates_per_batch': 8,
'learning_rate': 3e-4,
'entropy_cost': 1e-2,
'discounting': 0.97,
'reward_scaling': 0.1,
'batch_size': 256,
'action_repeat': 1,
'policy_hidden_sizes': '32,32,32,32',
'value_hidden_sizes': '32,32,32,32',
'activation': 'tanh',
}
# Apply environment-specific config (only for params not explicitly set by user)
if env_name in env_specific_configs:
env_config = env_specific_configs[env_name]
print(f"\nApplying environment-specific hyperparameters for {env_name}:")
for key, val in env_config.items():
if hasattr(args, key):
current_val = getattr(args, key)
is_default = key in param_defaults and current_val == param_defaults[key]
if is_default:
setattr(args, key, val)
print(f" {key}: {val}")
else:
print(f" {key}: {current_val} (CLI override, env-specific: {val})")
# Parse hidden layer sizes
policy_hidden_sizes = tuple(int(x) for x in args.policy_hidden_sizes.split(','))
value_hidden_sizes = tuple(int(x) for x in args.value_hidden_sizes.split(','))
# Get device info
devices = jax.devices()
num_devices = len(devices)
print("=" * 60)
print(f"PPO Training on {env_name} (MuJoCo Playground/MJX)")
print("=" * 60)
print(f"\nUsing {num_devices} GPU(s): {[str(d) for d in devices]}")
print(f"Available environments: {registry.ALL_ENVS[:10]}...")
# Initialize wandb
wandb.init(
project="ppo-mjx",
name=f"ppo-{env_name}",
config={
"algorithm": "PPO",
"env_name": env_name,
"num_timesteps": args.num_timesteps,
"num_envs": args.num_envs,
"num_eval_envs": args.num_eval_envs,
"episode_length": args.episode_length,
"unroll_length": args.unroll_length,
"num_minibatches": args.num_minibatches,
"num_updates_per_batch": args.num_updates_per_batch,
"learning_rate": args.learning_rate,
"entropy_cost": args.entropy_cost,
"discounting": args.discounting,
"reward_scaling": args.reward_scaling,
"clipping_epsilon": args.clipping_epsilon,
"gae_lambda": args.gae_lambda,
"normalize_observations": args.normalize_observations,
"policy_hidden_sizes": policy_hidden_sizes,
"value_hidden_sizes": value_hidden_sizes,
"batch_size": args.batch_size,
"max_grad_norm": args.max_grad_norm,
"action_repeat": args.action_repeat,
"seed": args.seed,
"num_gpus": num_devices,
}
)
# Load environment
print(f"\nInitializing {env_name} environment...")
env = registry.load(env_name)
eval_env = registry.load(env_name)
# Get observation and action dimensions
key = jax.random.key(args.seed)
key, reset_key = jax.random.split(key)
state = env.reset(reset_key)
obs_dim = state.obs.shape[-1]
action_dim = env.action_size
print(f" Observation dim: {obs_dim}")
print(f" Action dim: {action_dim}")
print(f" Episode length: {args.episode_length}")
# Print PPO hyperparameters
print("\nPPO Hyperparameters:")
print(f" Num timesteps: {args.num_timesteps:,}")
print(f" Num envs: {args.num_envs}")
print(f" Unroll length: {args.unroll_length}")
print(f" Num minibatches: {args.num_minibatches}")
print(f" Num updates per batch: {args.num_updates_per_batch}")
print(f" Learning rate: {args.learning_rate}")
print(f" Entropy cost: {args.entropy_cost}")
print(f" Discounting: {args.discounting}")
print(f" Reward scaling: {args.reward_scaling}")
print(f" Clipping epsilon: {args.clipping_epsilon}")
print(f" GAE lambda: {args.gae_lambda}")
print(f" Policy hidden sizes: {policy_hidden_sizes}")
print(f" Value hidden sizes: {value_hidden_sizes}")
print(f" Activation: {args.activation}")
# Get activation function
activation_fn = jax.nn.swish if args.activation == 'swish' else jax.nn.tanh
# Network factory
network_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=policy_hidden_sizes,
value_hidden_layer_sizes=value_hidden_sizes,
activation=activation_fn,
)
# Progress callback for logging
times = [time.time()]
best_reward = -float('inf')
def progress_callback(num_steps, metrics):
times.append(time.time())
# Log to wandb
wandb.log(metrics, step=num_steps)
nonlocal best_reward
if 'eval/episode_reward' in metrics:
reward = metrics['eval/episode_reward']
if reward > best_reward:
best_reward = reward
elapsed = times[-1] - times[0]
print(f"Step {num_steps:>10,} | "
f"Reward: {reward:8.2f} | "
f"Best: {best_reward:8.2f} | "
f"Time: {elapsed:6.1f}s")
# Training function
train_fn = functools.partial(
ppo.train,
num_timesteps=args.num_timesteps,
num_evals=args.num_evals,
reward_scaling=args.reward_scaling,
episode_length=args.episode_length,
normalize_observations=args.normalize_observations,
action_repeat=args.action_repeat,
unroll_length=args.unroll_length,
num_minibatches=args.num_minibatches,
num_updates_per_batch=args.num_updates_per_batch,
discounting=args.discounting,
learning_rate=args.learning_rate,
entropy_cost=args.entropy_cost,
num_envs=args.num_envs,
batch_size=args.batch_size,
max_grad_norm=args.max_grad_norm,
clipping_epsilon=args.clipping_epsilon,
gae_lambda=args.gae_lambda,
network_factory=network_factory,
seed=args.seed,
num_eval_envs=args.num_eval_envs,
wrap_env_fn=wrapper.wrap_for_brax_training,
)
# Train
print("\n" + "=" * 60)
print("Starting PPO training...")
print("=" * 60 + "\n")
start_time = time.time()
make_inference_fn, params, _ = train_fn(
environment=env,
progress_fn=progress_callback,
eval_env=eval_env,
)
total_time = time.time() - start_time
print("\n" + "=" * 60)
print("Training Complete!")
print("=" * 60)
print(f" Environment: {env_name}")
print(f" Total timesteps: {args.num_timesteps:,}")
print(f" Total time: {total_time:.1f}s")
print(f" Best reward: {best_reward:.2f}")
# Save trained parameters
save_path = f"best_{env_name.lower()}_ppo_policy.pkl"
with open(save_path, 'wb') as f:
pickle.dump({
'params': params,
'config': {
'env_name': env_name,
'policy_hidden_sizes': policy_hidden_sizes,
'value_hidden_sizes': value_hidden_sizes,
'obs_dim': obs_dim,
'action_dim': action_dim,
}
}, f)
print(f" Best policy saved to: {save_path}")
# Create inference function for evaluation
inference_fn = make_inference_fn(params, deterministic=True)
# Run final evaluation
key, eval_key = jax.random.split(key)
eval_save_dir = f"eval_results_{env_name.lower()}_ppo"
final_evaluation(
eval_env, inference_fn, eval_key, env_name,
num_eval_trials=100, max_steps=args.episode_length, save_dir=eval_save_dir
)
# Finish wandb run
wandb.finish()
return best_reward, params
if name == "main":
main()
```