From 6386ad7394b5f1fe2f165f6901b3b892390f5f47 Mon Sep 17 00:00:00 2001 From: Travor <3488616445@qq.com> Date: Thu, 7 May 2026 01:14:17 +0800 Subject: [PATCH] Add relative quaternion getter --- .../entities/rigid_entity/rigid_entity.py | 47 ++++++++++- tests/test_rigid_physics.py | 83 +++++++++++++++++++ 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index cc21a4a92..a453d053b 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1346,7 +1346,7 @@ def get_pos(self, envs_idx=None): return self._solver.get_links_pos(self.base_link_idx, envs_idx)[..., 0, :] @gs.assert_built - def get_quat(self, envs_idx=None): + def get_quat(self, envs_idx=None, *, relative=False): """ Returns quaternion of the entity's base link. @@ -1354,13 +1354,52 @@ def get_quat(self, envs_idx=None): ---------- envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. + relative : bool, optional + If True, return the quaternion relative to the initial (not current!) quaternion. + The returned quaternion ``delta`` satisfies + ``abs_quat == transform_quat_by_quat(init_quat, delta)``. + Equivalently, ``delta == transform_quat_by_quat(inv_quat(init_quat), abs_quat)``. + Defaults to False. Returns ------- quat : torch.Tensor, shape (4,) or (n_envs, 4) - The quaternion of the entity's base link. - """ - return self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :] + The quaternion of the entity's base link (absolute or relative). + """ + abs_quat = self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :] + if not relative: + return abs_quat + + has_free_root_qpos = self.base_link.n_joints == 1 and self.base_link.joints[0].type == gs.JOINT_TYPE.FREE + if not has_free_root_qpos: + if self._solver._options.batch_links_info: + init_quat = qd_to_torch( + self._solver.links_info.quat, + envs_idx, + self.base_link_idx, + transpose=True, + copy=True, + ) + if self._solver.n_envs == 0: + init_quat = init_quat[0, 0] + else: + init_quat = init_quat[:, 0] + else: + init_quat = torch.as_tensor(self.base_link.quat, dtype=abs_quat.dtype, device=abs_quat.device) + else: + q_start = self.base_link.q_start + init_quat = qd_to_torch( + self._solver.qpos0, + envs_idx, + slice(q_start + 3, q_start + 7), + transpose=True, + copy=True, + ) + if self._solver.n_envs == 0: + init_quat = init_quat[0] + + init_quat = init_quat.to(dtype=abs_quat.dtype, device=abs_quat.device) + return gu.transform_quat_by_quat(gu.inv_quat(init_quat), abs_quat) @gs.assert_built def get_vel(self, envs_idx=None): diff --git a/tests/test_rigid_physics.py b/tests/test_rigid_physics.py index 79c91bb60..1419cd562 100644 --- a/tests/test_rigid_physics.py +++ b/tests/test_rigid_physics.py @@ -1604,6 +1604,85 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol): quat_ref = quat_delta assert_allclose(quat, quat_ref, tol=tol) + if relative: + quat_rel_ref = quat_delta + else: + quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(quat_zero), quat_delta) + assert_allclose(entity.get_quat(relative=True), quat_rel_ref, tol=tol) + # Verify get_quat(relative=False) matches get_quat() (preserves old behavior) + assert_allclose(entity.get_quat(relative=False), quat, tol=tol) + + +@pytest.mark.required +def test_get_quat_relative_heterogeneous_initial_quat(show_viewer, tol): + scene = gs.Scene( + rigid_options=gs.options.RigidOptions(batch_links_info=True), + show_viewer=show_viewer, + show_FPS=False, + ) + box = scene.add_entity( + morph=( + gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 0.0, 0.0)), + gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 45.0, 0.0)), + ), + ) + scene.build(n_envs=4) + + quat_delta = torch.tensor( + [ + [0.9238795, 0.3826834, 0.0, 0.0], + [0.8660254, 0.0, 0.5, 0.0], + [0.7071068, 0.0, 0.0, 0.7071068], + [1.0, 0.0, 0.0, 0.0], + ], + dtype=gs.tc_float, + device=gs.device, + ) + quat_delta = quat_delta / torch.linalg.norm(quat_delta, dim=-1, keepdim=True) + + box.set_quat(quat_delta, relative=True) + + assert_allclose(box.get_quat(relative=True), quat_delta, tol=tol) + assert_allclose(box.get_quat(envs_idx=[2, 3], relative=True), quat_delta[2:], tol=tol) + + +@pytest.mark.required +def test_get_quat_relative_non_parallel(show_viewer, tol): + scene = gs.Scene(show_viewer=show_viewer, show_FPS=False) + box = scene.add_entity(gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 30.0, 0.0))) + scene.build() + + quat_delta = torch.tensor([0.9238795, 0.0, 0.3826834, 0.0], dtype=gs.tc_float, device=gs.device) + quat_delta = quat_delta / torch.linalg.norm(quat_delta) + + box.set_quat(quat_delta, relative=True) + quat_rel = box.get_quat(relative=True) + assert quat_rel.shape == quat_delta.shape + assert_allclose(quat_rel, quat_delta, tol=tol) + + +@pytest.mark.required +def test_get_quat_relative_non_parallel_batched_link_info(show_viewer, tol): + scene = gs.Scene( + rigid_options=gs.options.RigidOptions(batch_links_info=True), + show_viewer=show_viewer, + show_FPS=False, + ) + box = scene.add_entity( + gs.morphs.Box( + fixed=True, + batch_fixed_verts=True, + size=(0.04, 0.04, 0.04), + pos=(0.0, 0.0, 0.1), + euler=(0.0, 30.0, 0.0), + ) + ) + scene.build() + + quat_rel = box.get_quat(relative=True) + assert quat_rel.shape == (4,) + assert_allclose(quat_rel, gu.identity_quat(), tol=tol) + @pytest.mark.required def test_normalized_quat(show_viewer, tol): @@ -5062,6 +5141,10 @@ def test_merge_entities(is_fixed, merge_fixed_links, show_viewer, tol, monkeypat attach_link = franka.get_link("attachment") assert_allclose(attach_link.get_pos(), hand.links[0].get_pos(), tol=gs.EPS) + hand_quat_rel = hand.get_quat(relative=True) + hand_init_quat = torch.as_tensor(hand.base_link.quat, dtype=gs.tc_float, device=gs.device) + hand_quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(hand_init_quat), hand.get_quat()) + assert_allclose(hand_quat_rel, hand_quat_rel_ref, tol=tol) offset_quat = gu.transform_quat_by_quat(hand.links[0].get_quat(), gu.inv_quat(attach_link.get_quat())) assert_allclose(gu.quat_to_xyz(offset_quat, rpy=False, degrees=True), EULER_OFFSET, tol=tol) for link in hand.links[slice(0, None) if merge_fixed_links else slice(1, -1)]: