36 changes: 7 additions & 29 deletions tests/rl/agentic/rewards/reward_test.py
@@ -15,7 +15,6 @@
from absl.testing import parameterized

from tunix.rl.agentic.rewards import reward
from tunix.rl.agentic.rewards import reward_types


class RewardTest(parameterized.TestCase):
@@ -34,7 +33,7 @@ def test_registry(self):
"""Tests the reward function registry mechanism."""
# A simple reward function for testing registration
def test_fn(task, action):
return reward_types.RewardOutput(0.5, {})
return 0.5

reward.register("test_fn")(test_fn)
self.assertIs(reward.get_reward_fn("test_fn"), test_fn)
@@ -58,8 +57,7 @@ def test_exact_match(self, ground_truth, action, expected_score):
"""Tests the exact_match reward function."""
task = {"ground_truth": ground_truth}
result = reward.exact_match(task, action)
self.assertEqual(result.reward, expected_score)
self.assertEqual(result.metadata["exact_match"], expected_score)
self.assertEqual(result, expected_score)

@parameterized.named_parameters(
("integer_string", "2", 1.0),
@@ -71,14 +69,12 @@ def test_exact_match(self, ground_truth, action, expected_score):
def test_is_two_reward(self, action, expected_score):
"""Tests the is_two_reward function."""
result = reward.is_two_reward({}, action)
self.assertEqual(result.reward, expected_score)
self.assertEqual(result.metadata["is_two"], expected_score)
self.assertEqual(result, expected_score)

def test_dummy_reward(self):
"""Tests the dummy_reward function."""
result = reward.dummy_reward({}, "any action")
self.assertEqual(result.reward, 0.0)
self.assertEqual(result.metadata, {})
self.assertEqual(result, 0.0)

@parameterized.named_parameters(
("correct", "2 + 2 = ?", "The answer is 4", 1.0),
@@ -91,24 +87,6 @@ def test_calculate_reward(self, question, action, expected_score):
"""Tests the calculate_reward function."""
task = {"question": question}
result = reward.calculate_reward(task, action)
self.assertEqual(result.reward, expected_score)
self.assertEqual(result.metadata["calculate_correct"], expected_score)

def test_combine_rewards(self):
"""Tests the reward combination logic."""
weights = {"exact_match": 0.7, "dummy": 0.3}
combined_fn = reward.combine_rewards(weights)

# Case 1: Matches exact_match
task = {"ground_truth": "hello"}
action = "hello"
result = combined_fn(task, action)
self.assertAlmostEqual(result.reward, 0.7)
self.assertEqual(result.metadata, {"exact_match": 1.0})

# Case 2: Does not match exact_match
task = {"ground_truth": "world"}
action = "hello"
result = combined_fn(task, action)
self.assertAlmostEqual(result.reward, 0.0)
self.assertEqual(result.metadata, {"exact_match": 0.0})
self.assertEqual(result, expected_score)


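For reference, the contract these tests now exercise: a reward function is a bare callable mapping (task, action) to a float, with no `RewardOutput` wrapper, registered by name. A minimal sketch mirroring `test_registry` above; the example function and its scoring rule are hypothetical, not part of the library:

```python
from tunix.rl.agentic.rewards import reward

def starts_with_answer(task, action):
    # Scalar contract: return a plain float, no RewardOutput wrapper.
    # Illustrative rule only.
    return 1.0 if str(action).startswith("The answer is") else 0.0

# Registration and lookup, mirroring test_registry above.
reward.register("starts_with_answer")(starts_with_answer)
fn = reward.get_reward_fn("starts_with_answer")
assert fn({"question": "2 + 2 = ?"}, "The answer is 4") == 1.0
```
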
@@ -20,7 +20,6 @@
from tunix.rl.agentic.agents import agent_types
from tunix.rl.agentic.agents import base_agent
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.rewards import reward_types
from tunix.rl.agentic.trajectory import trajectory_collect_engine
from tunix.rl.agentic import utils

@@ -40,7 +39,7 @@ def setUp(self):

self.mock_model_call = mock.Mock()
self.mock_final_reward_fn = mock.Mock(
return_value=reward_types.RewardOutput(reward=0.5)
return_value=0.5
)
self.mock_tokenizer = mock.Mock()
self.mock_tokenizer.encode.return_value = [1, 2, 3]
@@ -177,11 +176,11 @@ def test_collect_with_tokenization(self, mock_convert):
'prompt_tokens': [101],
'conversation_tokens': [201, 202, 301, 302, 203, 204, 303, 304],
'conversation_masks': [1, 1, 1, 1, 1, 1, 1, 1],
'status': 'SUCCEEDED',
'trajectory_reward': 3.0, # 1.0 + 2.0
'policy_version': None,
'original_input': {'some': 'task'},
'group_id': None,
'status': 'SUCCEEDED',
}
self.assertEqual(token_data, expected_tokens)

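The expected token dict above carries `trajectory_reward` as the sum of the per-step scalar rewards (1.0 + 2.0). A minimal sketch of that aggregation under the scalar contract; the `Step` type here is a stand-in for illustration, not the library's trajectory class:

```python
from dataclasses import dataclass

@dataclass
class Step:
    reward: float  # per-step scalar reward

# Two steps matching the 1.0 + 2.0 in the expected dict above.
steps = [Step(reward=1.0), Step(reward=2.0)]
trajectory_reward = sum(s.reward for s in steps)
assert trajectory_reward == 3.0
```
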
69 changes: 20 additions & 49 deletions tests/rl/experimental/agentic_grpo_learner_test.py
@@ -56,14 +56,8 @@
TrainingInputT = agentic_grpo_learner.TrainingInputT


def reward_fn_1(prompts, completions, **kwargs):
del prompts, kwargs
return [float(i) for i in range(len(completions))]


def reward_fn_2(answer, **kwargs):
del kwargs
return [float(i) for i in range(len(answer))]
def reward_fn_1(task, action):
return float(len(str(action)))


_MOCK_RESPONSES = [
@@ -342,7 +336,7 @@ def test_num_iterations_greater_than_1(self):
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
metric_fns=[lambda **kwargs: {"test_metric": (1.0, np.mean)}],
chat_parser=MockChatParser(),
@@ -491,7 +485,7 @@ def create_learner(
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -535,9 +529,8 @@ def test_micro_batch_training(
mini_batch_size,
train_micro_batch_size,
):
def reward_fn(prompts, **kwargs):
del kwargs
return [1.0] * len(prompts)
def reward_fn(task, action):
return float(len(str(action)))

def create_learner(
mini_batch_size,
@@ -591,7 +584,7 @@ def create_learner(
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn,
reward_fn=reward_fn,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -696,7 +689,7 @@ def create_learner(
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn,
reward_fn=reward_fn,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -769,7 +762,7 @@ def test_exception_handling(self):
grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=10)
learner = _LearnerWithException(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -779,25 +772,17 @@

@parameterized.named_parameters(
dict(
testcase_name="single_reward_fn",
reward_fns=reward_fn_1,
loss_algo="grpo",
),
dict(
testcase_name="multiple_reward_fns",
reward_fns=[
reward_fn_1,
reward_fn_2,
],
testcase_name="grpo_algo",
reward_fn=reward_fn_1,
loss_algo="grpo",
),
dict(
testcase_name="single_reward_fn_gspo",
reward_fns=reward_fn_1,
testcase_name="gspo_algo",
reward_fn=reward_fn_1,
loss_algo="gspo-token",
),
)
def test_grpo_learner(self, reward_fns, loss_algo):
def test_grpo_learner(self, reward_fn, loss_algo):
vocab = _mock_vocab()
tokenizer = tokenizer_adapter.TokenizerAdapter(vocab)
model = test_common.ToyTransformer(
@@ -846,7 +831,7 @@ def test_grpo_learner(self, reward_fns, loss_algo):
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fns,
reward_fn=reward_fn,
algo_config=grpo_config,
metric_fns=[lambda **kwargs: {"test_metric": (1.0, np.mean)}],
chat_parser=MockChatParser(),
@@ -870,28 +855,14 @@ def test_grpo_learner(self, reward_fns, loss_algo):

rl_metric_logger = grpo_learner.rl_cluster._rl_metrics_logger

rewards_metrics = (
("rewards/" + f.__name__ for f in reward_fns)
if isinstance(reward_fns, list)
else ("rewards/" + reward_fns.__name__,)
)
for metric_name in [
"rewards/sum",
*rewards_metrics,
"generation/prompts/mean_length",
"generation/prompts/max_length",
"generation/prompts/min_length",
"generation/completions/mean_length",
"generation/completions/max_length",
"generation/completions/min_length",
"generation/completions/clip_ratio",
"perf/global_step_time",
"global/test_metric",
]:
if metric_name == "rewards/reward_fn_2" and not isinstance(
reward_fns, list
):
continue
# We log metrics per step, and sometimes one extra step is logged due to
# buffer flushing. So we check if length is close to global_steps.
prefix, metric_name = metric_name.split("/", maxsplit=1)
@@ -994,7 +965,7 @@ def test_on_off_policy_training(self, offpolicy_steps):
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
metric_fns=[lambda **kwargs: {"test_metric": (1.0, np.mean)}],
chat_parser=MockChatParser(),
@@ -1051,7 +1022,7 @@ def test_put_prompts_to_queue(self):
grpo_config = agentic_grpo_learner.GRPOConfig()
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -1120,7 +1091,7 @@ def test_trajectory_logging(self):
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
metric_fns=[lambda **kwargs: {"test_metric": (1.0, np.mean)}],
chat_parser=MockChatParser(),
@@ -1209,7 +1180,7 @@ def test_grpo_with_lora_model(self):

grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
chat_parser=MockChatParser(),
)
@@ -1332,7 +1303,7 @@ def update_from_model(self, response, **kwargs):
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
reward_fn=reward_fn_1,
algo_config=grpo_config,
metric_fns=[lambda **kwargs: {"test_metric": (1.0, np.mean)}],
chat_parser=MockChatParser(),
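Every construction site in this file changes the same way: `reward_fns`, which accepted a single function or a list of functions, becomes `reward_fn`, a single callable with the (task, action) -> float signature (and the per-function `rewards/<name>` metrics go away, leaving `rewards/sum`). A sketch of the new call, reusing names from the tests above; the cluster and parser are assumed to be built as in those tests:

```python
def reward_fn(task, action):
    # Scalar contract, as in reward_fn_1 above.
    return float(len(str(action)))

grpo_learner = agentic_grpo_learner.GRPOLearner(
    rl_cluster=rl_cluster,        # built as in the tests above
    reward_fn=reward_fn,          # single callable, no list
    algo_config=agentic_grpo_learner.GRPOConfig(max_response_length=10),
    chat_parser=MockChatParser(),
)
```
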
16 changes: 7 additions & 9 deletions tunix/rl/agentic/environments/task_environment.py
@@ -19,7 +19,6 @@
from absl import logging
from tunix.rl.agentic.agents import agent_types
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.rewards import reward


class TaskEnvironment(base_environment.BaseTaskEnv):
@@ -46,18 +45,17 @@ def __init__(
Args:
single_example: A single prompt.
reward_fn: Reward function that takes (task, action) and returns
RewardOutput with `.reward` and `.metadata` fields. If None, defaults to
`dummy_reward`.
a scalar reward. If None, defaults to a function that returns 0.0.
**kwargs: Extra arguments ignored by this environment but accepted for
compatibility with a common environment config interface.
"""
if reward_fn is None:
logging.log_first_n(
logging.WARNING,
"No reward_fn provided, defaulting to dummy_reward().",
"No reward_fn provided, defaulting to returning 0.0.",
1,
)
reward_fn = reward.dummy_reward
reward_fn = lambda *_, **__: 0.0  # must accept keyword calls (task=..., action=...)

super().__init__(
task=single_example, reward_fn=reward_fn, max_steps=1, **kwargs
@@ -78,16 +76,16 @@ def _step_impl(self, action: Any) -> base_environment.EnvStepResult:

Returns:
An `EnvStepResult` containing an empty observation, the calculated reward,
done=True, and info including the agent's response and reward metadata.
done=True, and info including the agent's response.
"""
if isinstance(action, agent_types.Action):
action = action.action
r_out = self.reward_fn(task=self.task, action=action)
reward_val = self.reward_fn(task=self.task, action=action)
return base_environment.EnvStepResult(
observation={},
reward=r_out.reward,
reward=reward_val,
done=True,
info={"response": action, "metadata": r_out.metadata},
info={"response": action},
)

@classmethod
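A usage sketch for the updated environment, assuming `single_example` and `reward_fn` are passed as the docstring above describes and that the base class exposes a public `step()` wrapping `_step_impl`; the task and scoring rule are illustrative:

```python
from tunix.rl.agentic.environments import task_environment

def exact_match_reward(task, action):
    # Scalar contract: return a float, no RewardOutput wrapper.
    return 1.0 if action == task.get("ground_truth") else 0.0

env = task_environment.TaskEnvironment(
    single_example={"ground_truth": "4"},
    reward_fn=exact_match_reward,
)
# step() is assumed to be the base-class entry point around _step_impl.
result = env.step("4")
assert result.reward == 1.0 and result.done
```
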
16 changes: 7 additions & 9 deletions tunix/rl/agentic/environments/tool_environment.py
@@ -27,7 +27,6 @@
from absl import logging
from tunix.rl.agentic.agents import agent_types
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.rewards import reward
from tunix.rl.agentic.tools import base_tool
from tunix.rl.agentic.tools import tool_manager

@@ -64,19 +63,18 @@ def __init__(
tool_map (Dict[str, type[BaseTool]]): Mapping of tool names to their
implementation classes for tool discovery and execution.
reward_fn: Reward function that takes (task, action) and returns
RewardOutput with `.reward` and `.metadata` fields. If None, defaults to
`dummy_reward` with a warning.
a scalar reward. If None, defaults to returning 0.0 with a warning.
max_steps (int): Maximum number of interaction steps before forced
termination. Prevents infinite loops and controls episode length.
**kwargs: Additional arguments reserved for future extensions.
"""
if reward_fn is None:
logging.log_first_n(
logging.WARNING,
"No reward_fn provided, defaulting to dummy_reward().",
"No reward_fn provided, defaulting to returning 0.0.",
1,
)
reward_fn = reward.dummy_reward
reward_fn = lambda *_, **__: 0.0  # must accept keyword calls (task=..., action=...)

# Let BaseTaskEnv handle task, reward_fn, step_count, and max_steps.
super().__init__(
@@ -140,12 +138,12 @@ def _step_impl(self, action: Any) -> base_environment.EnvStepResult:
# Handle episode termination: compute final reward.
if done:
llm_answer = self._extract_llm_answer(action)
r_out = self.reward_fn(task=self.task, action=llm_answer)
reward_val = self.reward_fn(task=self.task, action=llm_answer)
return base_environment.EnvStepResult(
observation={},
reward=r_out.reward,
reward=reward_val,
done=True,
info={"response": action, "metadata": r_out.metadata},
info={"response": action},
)

# Handle continuing episode: execute tools and return intermediate results.
@@ -155,7 +153,7 @@ def _step_impl(self, action: Any) -> base_environment.EnvStepResult:
observation=obs,
reward=0.0,
done=False,
info={"response": action, "metadata": {}},
info={"response": action},
)

@staticmethod
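One detail both environments share: the reward function is invoked with keyword arguments (`self.reward_fn(task=self.task, action=...)`), so the no-op default must accept keywords as well; a positional-only `lambda *_: 0.0` would raise a TypeError at the call site, which is why the defaults above are written to swallow both. A quick check:

```python
# Positional-only catch-all: rejects keyword calls.
bad_default = lambda *_: 0.0
try:
    bad_default(task={}, action="x")
except TypeError:
    pass  # TypeError: <lambda>() got an unexpected keyword argument 'task'

# Catch-all for both positional and keyword arguments, matching the call sites.
good_default = lambda *_, **__: 0.0
assert good_default(task={}, action="x") == 0.0
```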