
Vision Colab broken #268

@wittenator

Hi,

I tried running the Colab notebook that uses the Madrona batch renderer (https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1_t4.ipynb).
After following the setup instructions, creating the environment fails because an internal JAX API appears to be missing. Is this a known issue? Here is the error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipython-input-1084582484.py in <cell line: 0>()
     15 
     16 env_name = "CartpoleBalance"
---> 17 env = dm_control_suite.load(
     18     env_name, config_overrides=config_overrides
     19 )

4 frames
/usr/local/lib/python3.12/dist-packages/mujoco_playground/_src/dm_control_suite/__init__.py in load(env_name, config, config_overrides)
    155     )
    156   config = config or get_default_config(env_name)
--> 157   return _envs[env_name](config=config, config_overrides=config_overrides)

/usr/local/lib/python3.12/dist-packages/mujoco_playground/_src/dm_control_suite/cartpole.py in __init__(self, swing_up, sparse, config, config_overrides)
    109         warnings.warn("Madrona MJX not installed. Cannot use vision with.")
    110         return
--> 111       self.renderer = BatchRenderer(
    112           m=self._mjx_model,
    113           gpu_id=self._config.vision_config.gpu_id,

/content/modules/madrona_mjx/src/madrona_mjx/renderer.py in __init__(self, m, gpu_id, num_worlds, batch_render_view_width, batch_render_view_height, enabled_geom_groups, enabled_cameras, add_cam_debug_geo, use_rasterizer, viz_gpu_hdls)
    209     )
    210 
--> 211     init_fn, render_fn = _setup_jax_primitives(
    212         self.madrona,
    213         num_worlds,

/content/modules/madrona_mjx/src/madrona_mjx/renderer.py in _setup_jax_primitives(renderer, num_worlds, num_geoms, num_cams, render_width, render_height)
    366   render_custom_call_name = f"{custom_call_prefix}_render"
    367 
--> 368   xla_client.register_custom_call_target(
    369       init_custom_call_name,
    370       init_custom_call_capsule,

/usr/local/lib/python3.12/dist-packages/jax/_src/deprecations.py in getattr(name)
     55       warnings.warn(message, DeprecationWarning, stacklevel=2)
     56       return fn
---> 57     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     58 
     59   return getattr

AttributeError: module 'jax.lib.xla_client' has no attribute 'register_custom_call_target'
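The failing call is `xla_client.register_custom_call_target`, which the installed JAX no longer exposes through `jax.lib.xla_client`. As far as I can tell, newer JAX releases offer `jax.ffi.register_ffi_target` instead, where `api_version=0` selects the older custom-call convention. Below is a minimal, untested sketch of a version-tolerant lookup that `_setup_jax_primitives` could use. It assumes those two entry points and assumes madrona_mjx's capsules use the legacy (api_version 0) convention; the helper name `resolve_registrar` is hypothetical, and the platform string is left to the caller:

```python
from typing import Any, Callable, Optional


def resolve_registrar(jax_module: Any) -> Optional[Callable[[str, Any, str], None]]:
    """Hypothetical helper: return a callable (name, capsule, platform) that
    registers a legacy custom-call target, or None if no known API exists."""
    # Newer JAX (assumption): jax.ffi.register_ffi_target(name, fn, platform=..., api_version=...)
    ffi = getattr(jax_module, "ffi", None)
    register_ffi = getattr(ffi, "register_ffi_target", None)
    if register_ffi is not None:

        def register_via_ffi(name: str, capsule: Any, platform: str) -> None:
            # api_version=0 requests the pre-FFI "custom call" convention
            # (assumption: this is what madrona_mjx's capsules were built for).
            register_ffi(name, capsule, platform=platform, api_version=0)

        return register_via_ffi

    # Older JAX: jax.lib.xla_client.register_custom_call_target. Where it was
    # removed, the module's deprecation hook raises AttributeError, which
    # getattr(..., None) converts to None.
    lib = getattr(jax_module, "lib", None)
    xla_client = getattr(lib, "xla_client", None)
    legacy = getattr(xla_client, "register_custom_call_target", None)
    if legacy is not None:

        def register_via_xla_client(name: str, capsule: Any, platform: str) -> None:
            legacy(name, capsule, platform=platform)

        return register_via_xla_client

    # Neither entry point found: the caller should fail with a clear message.
    return None
```

Usage would be roughly `register = resolve_registrar(jax)` followed by `register(init_custom_call_name, init_custom_call_capsule, platform)`, with `platform` set to whatever madrona_mjx passes today. On some JAX versions `jax.ffi` may need an explicit `import jax.ffi` before it shows up as an attribute.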

I've tried updating JAX to match the installation instructions in the madrona-mjx repo, but I get the same error.
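For comparing the environment against the madrona-mjx install instructions, here is a small sketch that reports installed versions using only the standard library. The distribution names in the usage line are my assumptions; in particular, madrona_mjx was built from source under `/content/modules/madrona_mjx`, so its distribution name may differ:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed (or installed under a different distribution name).
            versions[name] = None
    return versions
```

For example: `print(installed_versions(["jax", "jaxlib", "mujoco", "mujoco-mjx", "playground", "madrona-mjx"]))`.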
