25 changes: 25 additions & 0 deletions tests/rl/common_test.py
@@ -346,6 +346,31 @@ def test_invalid_norm(self, norm_val, loss_agg_mode):
           norm=norm_val,
       )
 
+  def test_compute_kl_divergence_bf16(self):
+    per_token_logps = jnp.array([-10.0, -1.0, 0.0], dtype=jnp.bfloat16)
+    ref_per_token_logps = jnp.array([-1.0, -10.0, 0.0], dtype=jnp.bfloat16)
+
+    kl = common.compute_kl_divergence(
+        per_token_logps, ref_per_token_logps, method="low_var_kl"
+    )
+    self.assertEqual(kl.dtype, jnp.float32)
+    expected_kl = common.compute_kl_divergence(
+        per_token_logps.astype(jnp.float32),
+        ref_per_token_logps.astype(jnp.float32),
+        method="low_var_kl",
+    )
+    np.testing.assert_allclose(kl, expected_kl, rtol=1e-3)
+
+  def test_aggregate_loss_bf16(self):
+    per_token_loss = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
+    completion_mask = jnp.array([1, 1, 0], dtype=jnp.int32)
+
+    loss = common.aggregate_loss(
+        per_token_loss, completion_mask, loss_agg_mode="token-mean"
+    )
+    self.assertEqual(loss.dtype, jnp.float32)
+    self.assertAlmostEqual(loss, 1.5, places=5)
+
 
 if __name__ == "__main__":
   absltest.main()
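
Why these tests pin the output dtype: bfloat16 carries only a 7-bit mantissa (roughly 2 to 3 decimal digits), so estimators built from exp() followed by subtraction can lose most of their significant digits when evaluated in bf16. A minimal standalone sketch of the gap, assuming only jax is installed (the variable names are illustrative, not from the PR):

import jax.numpy as jnp

# Illustrative only: the same values in bf16 and in float32.
x_bf16 = jnp.array([-10.0, -1.0, 0.0], dtype=jnp.bfloat16)
x_f32 = x_bf16.astype(jnp.float32)

# jnp.exp preserves the input dtype, so the first result is rounded
# to bf16's 7-bit mantissa while the second keeps float32 precision.
print(jnp.exp(x_bf16))  # e.g. exp(-10) ~ 4.54e-05 rounds coarsely in bf16
print(jnp.exp(x_f32))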
9 changes: 8 additions & 1 deletion tunix/rl/common.py
@@ -128,6 +128,10 @@ def compute_kl_divergence(
   Returns:
     KL divergence.
   """
+  per_token_logps = per_token_logps.astype(jnp.float32)
+  if ref_per_token_logps is not None:
+    ref_per_token_logps = ref_per_token_logps.astype(jnp.float32)
+
   if method == "kl":
     return per_token_logps - ref_per_token_logps
   elif method == "mse_kl":
@@ -373,8 +377,11 @@ def aggregate_loss(
     Aggregated loss.
   """
 
+  per_token_loss = per_token_loss.astype(jnp.float32)
+
   if loss_agg_mode == "token-mean":
-    # sum all the token loss, and average by total number of completion token in the batch
+    # sum all the token loss, and average by total number of completion tokens
+    # in the batch
     loss = (per_token_loss * completion_mask).sum() / (
         jnp.clip(completion_mask.sum(), min=1)
     )
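
The "low_var_kl" branch exercised by the new test sits outside this hunk. If it follows the common low-variance ("k3") estimator from Schulman's KL-approximation note, it computes exp(ref - logp) - (ref - logp) - 1 per token, which is exactly the exp-then-subtract arithmetic that benefits from the float32 upcast above. A sketch under that assumption (the function name is illustrative, not the tunix implementation):

import jax.numpy as jnp

def low_var_kl_estimate(per_token_logps, ref_per_token_logps):
  # k3 estimator: non-negative and unbiased for KL(pi || pi_ref).
  # With inputs upcast to float32 (as this PR does), the exp() no
  # longer collapses onto bf16's coarse grid near 1.0.
  log_ratio = ref_per_token_logps - per_token_logps
  return jnp.exp(log_ratio) - log_ratio - 1.0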
15 changes: 10 additions & 5 deletions tunix/rl/experimental/agentic_grpo_learner.py
@@ -517,12 +517,15 @@ def grpo_loss_fn(
       stop_gradient=False,
       return_logits=False,
   )
-  advantages = train_example.advantages
+  per_token_logps = jnp.astype(per_token_logps, jnp.float32)
+  advantages = jnp.astype(train_example.advantages, jnp.float32)
 
   if train_example.old_per_token_logps is None:
     old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
   else:
-    old_per_token_logps = train_example.old_per_token_logps
+    old_per_token_logps = jnp.astype(
+        train_example.old_per_token_logps, jnp.float32
+    )
 
   seq_importance_ratio = per_token_logps - old_per_token_logps
   # TODO(sizhi): Refactor this to a separate function.
@@ -544,7 +547,7 @@
   per_token_loss = -jnp.minimum(
       coef_1 * jnp.expand_dims(advantages, 1),
       coef_2 * jnp.expand_dims(advantages, 1),
-  )
+  ).astype(jnp.float32)
 
   aux = {"kl": 0.0}
   if beta is not None and beta != 0.0:
@@ -554,8 +557,9 @@
     per_token_loss = per_token_loss + beta * kl
 
     # Log mean KL.
-    aux["kl"] = (kl * completion_mask).sum() / jnp.clip(
-        completion_mask.sum(), min=1
+    aux["kl"] = jnp.astype(
+        (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
+        jnp.float32,
     )
 
   loss = common.aggregate_loss(
@@ -576,6 +580,7 @@ def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array:
   Returns:
     Group relative advantages.
   """
+  rewards = jnp.astype(rewards, jnp.float32)
   mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1)
   std_grouped_rewards = rewards.reshape(-1, num_generations).std(
       axis=-1, ddof=1
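
The last hunk cuts off right after the grouped standard deviation. For readers unfamiliar with GRPO, a typical completion of this computation (a sketch of the standard group-relative formula, not necessarily the exact remainder of compute_advantages; the eps value is an assumption) z-scores each reward against its group's statistics and flattens back out:

import jax.numpy as jnp

def grouped_advantages(rewards, num_generations, eps=1e-4):
  # Group-relative advantages, GRPO-style: normalize each reward
  # within its group of num_generations samples from the same prompt.
  rewards = jnp.astype(rewards, jnp.float32)
  grouped = rewards.reshape(-1, num_generations)
  mean = grouped.mean(axis=-1, keepdims=True)
  std = grouped.std(axis=-1, ddof=1, keepdims=True)
  return ((grouped - mean) / (std + eps)).reshape(-1)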