Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,21 +1346,60 @@ def get_pos(self, envs_idx=None):
return self._solver.get_links_pos(self.base_link_idx, envs_idx)[..., 0, :]

@gs.assert_built
def get_quat(self, envs_idx=None):
def get_quat(self, envs_idx=None, *, relative=False):
"""
Returns quaternion of the entity's base link.

Parameters
----------
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
relative : bool, optional
If True, return the quaternion relative to the initial (not current!) quaternion.
The returned quaternion ``delta`` satisfies
``abs_quat == transform_quat_by_quat(init_quat, delta)``.
Equivalently, ``delta == transform_quat_by_quat(inv_quat(init_quat), abs_quat)``.
Defaults to False.

Returns
-------
quat : torch.Tensor, shape (4,) or (n_envs, 4)
The quaternion of the entity's base link.
"""
return self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :]
The quaternion of the entity's base link (absolute or relative).
"""
abs_quat = self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :]
if not relative:
return abs_quat

has_free_root_qpos = self.base_link.n_joints == 1 and self.base_link.joints[0].type == gs.JOINT_TYPE.FREE
if not has_free_root_qpos:
if self._solver._options.batch_links_info:
init_quat = qd_to_torch(
self._solver.links_info.quat,
envs_idx,
self.base_link_idx,
transpose=True,
copy=True,
)
if self._solver.n_envs == 0:
init_quat = init_quat[0, 0]
else:
init_quat = init_quat[:, 0]
else:
init_quat = torch.as_tensor(self.base_link.quat, dtype=abs_quat.dtype, device=abs_quat.device)
else:
q_start = self.base_link.q_start
init_quat = qd_to_torch(
self._solver.qpos0,
envs_idx,
slice(q_start + 3, q_start + 7),
transpose=True,
Comment on lines +1390 to +1395
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle attached fixed-base entities in relative quaternion getter

When relative=True, this branch assumes any non-fixed base link has a free-root quaternion in qpos0, but attach() can flip base_link.is_fixed to False for originally fixed-base entities attached to a non-fixed parent (see link._is_fixed &= parent_link.is_fixed). In that case the base link does not have a root quaternion slot, so slice(q_start + 3, q_start + 7) reads unrelated qpos0 entries (or from an invalid start), producing incorrect relative quaternions for attached entities. This regression is new because the previous get_quat() only returned absolute state and did not depend on q_start.

Useful? React with 👍 / 👎.

copy=True,
)
Comment on lines +1391 to +1397
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve unbatched quaternion shape in relative getter

When scene.build() is used without parallel envs (n_envs == 0) and the base link is not fixed, this branch pulls init_quat from qpos0 without removing the singleton env dimension. abs_quat is 1D in this mode, so transform_quat_by_quat(...) returns a (1, 4) tensor instead of (4,), which breaks the documented/output shape parity with get_quat() and can cause silent broadcasting bugs in downstream quaternion math.

Useful? React with 👍 / 👎.

if self._solver.n_envs == 0:
init_quat = init_quat[0]

init_quat = init_quat.to(dtype=abs_quat.dtype, device=abs_quat.device)
return gu.transform_quat_by_quat(gu.inv_quat(init_quat), abs_quat)

@gs.assert_built
def get_vel(self, envs_idx=None):
Expand Down
83 changes: 83 additions & 0 deletions tests/test_rigid_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,85 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol):
quat_ref = quat_delta
assert_allclose(quat, quat_ref, tol=tol)

if relative:
quat_rel_ref = quat_delta
else:
quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(quat_zero), quat_delta)
assert_allclose(entity.get_quat(relative=True), quat_rel_ref, tol=tol)
# Verify get_quat(relative=False) matches get_quat() (preserves old behavior)
assert_allclose(entity.get_quat(relative=False), quat, tol=tol)


@pytest.mark.required
def test_get_quat_relative_heterogeneous_initial_quat(show_viewer, tol):
scene = gs.Scene(
rigid_options=gs.options.RigidOptions(batch_links_info=True),
show_viewer=show_viewer,
show_FPS=False,
)
box = scene.add_entity(
morph=(
gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 0.0, 0.0)),
gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 45.0, 0.0)),
),
)
scene.build(n_envs=4)

quat_delta = torch.tensor(
[
[0.9238795, 0.3826834, 0.0, 0.0],
[0.8660254, 0.0, 0.5, 0.0],
[0.7071068, 0.0, 0.0, 0.7071068],
[1.0, 0.0, 0.0, 0.0],
],
dtype=gs.tc_float,
device=gs.device,
)
quat_delta = quat_delta / torch.linalg.norm(quat_delta, dim=-1, keepdim=True)

box.set_quat(quat_delta, relative=True)

assert_allclose(box.get_quat(relative=True), quat_delta, tol=tol)
assert_allclose(box.get_quat(envs_idx=[2, 3], relative=True), quat_delta[2:], tol=tol)


@pytest.mark.required
def test_get_quat_relative_non_parallel(show_viewer, tol):
scene = gs.Scene(show_viewer=show_viewer, show_FPS=False)
box = scene.add_entity(gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 30.0, 0.0)))
scene.build()

quat_delta = torch.tensor([0.9238795, 0.0, 0.3826834, 0.0], dtype=gs.tc_float, device=gs.device)
quat_delta = quat_delta / torch.linalg.norm(quat_delta)

box.set_quat(quat_delta, relative=True)
quat_rel = box.get_quat(relative=True)
assert quat_rel.shape == quat_delta.shape
assert_allclose(quat_rel, quat_delta, tol=tol)


@pytest.mark.required
def test_get_quat_relative_non_parallel_batched_link_info(show_viewer, tol):
scene = gs.Scene(
rigid_options=gs.options.RigidOptions(batch_links_info=True),
show_viewer=show_viewer,
show_FPS=False,
)
box = scene.add_entity(
gs.morphs.Box(
fixed=True,
batch_fixed_verts=True,
size=(0.04, 0.04, 0.04),
pos=(0.0, 0.0, 0.1),
euler=(0.0, 30.0, 0.0),
)
)
scene.build()

quat_rel = box.get_quat(relative=True)
assert quat_rel.shape == (4,)
assert_allclose(quat_rel, gu.identity_quat(), tol=tol)


@pytest.mark.required
def test_normalized_quat(show_viewer, tol):
Expand Down Expand Up @@ -5062,6 +5141,10 @@ def test_merge_entities(is_fixed, merge_fixed_links, show_viewer, tol, monkeypat

attach_link = franka.get_link("attachment")
assert_allclose(attach_link.get_pos(), hand.links[0].get_pos(), tol=gs.EPS)
hand_quat_rel = hand.get_quat(relative=True)
hand_init_quat = torch.as_tensor(hand.base_link.quat, dtype=gs.tc_float, device=gs.device)
hand_quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(hand_init_quat), hand.get_quat())
assert_allclose(hand_quat_rel, hand_quat_rel_ref, tol=tol)
offset_quat = gu.transform_quat_by_quat(hand.links[0].get_quat(), gu.inv_quat(attach_link.get_quat()))
assert_allclose(gu.quat_to_xyz(offset_quat, rpy=False, degrees=True), EULER_OFFSET, tol=tol)
for link in hand.links[slice(0, None) if merge_fixed_links else slice(1, -1)]:
Expand Down