Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,6 @@ redis-data/*

sotopia/cli/install/redis-data/*
redis-stack-server-*/

*.rdb
examples/experimental/negotiation_arena/redis-data/*
80 changes: 80 additions & 0 deletions docs/pages/examples/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,86 @@ Currently this script would run over 100 simulations on the Sotopia Hard tasks.

An example script is provided in `scripts/display_benchmark_results.sh`

## Using Customized Agents with Benchmark

The default `sotopia benchmark` command uses `LLMAgent` for all agents. If you want to use a customized agent class (e.g., a subclass of `LLMAgent` with custom behavior), you can create your own benchmark script that calls `_benchmark_impl` directly.

Here's an example of how to create a custom benchmark command that uses a customized agent:

```python
import typer
from sotopia.cli.benchmark.benchmark import _benchmark_impl
from sotopia.agents import LLMAgent
from typing import Any, Type
from typing_extensions import Annotated

app = typer.Typer(pretty_exceptions_enable=False)


# Define your custom agent class
class CustomSocialWorldModelAgent(LLMAgent):
    """Example custom agent: an LLMAgent subclass with extra behavior."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Fall back to a default social world model when the caller does
        # not supply one, then delegate construction to LLMAgent.
        kwargs.setdefault("social_world_model_name", "gpt-4.1-2025-04-14")
        super().__init__(*args, **kwargs)


@app.command(name="run-custom-benchmark")
def run_custom_benchmark(
    models: Annotated[
        str, typer.Option(help="Comma-separated list of models to benchmark")
    ] = "gpt-4.1-2025-04-14",
    partner_model: Annotated[
        str, typer.Option(help="Partner model to use")
    ] = "gpt-4o-2024-08-06",
    experiment_tag: Annotated[
        str, typer.Option(help="Tag for the benchmark run")
    ] = "custom_agent_trial",
    batch_size: Annotated[int, typer.Option(help="Batch size for processing")] = 100,
    push_to_db: Annotated[
        bool, typer.Option(help="Whether to push results to database")
    ] = True,
    evaluator_model: Annotated[
        str, typer.Option(help="Model to use for evaluation")
    ] = "gpt-4o",
    task: Annotated[str, typer.Option(help="Task difficulty level")] = "hard",
) -> None:
    """Run benchmark with custom agent class.

    Splits the comma-separated ``models`` option into a list and forwards
    all CLI options to ``_benchmark_impl``, substituting
    ``CustomSocialWorldModelAgent`` for the default ``LLMAgent``.
    """
    # Call _benchmark_impl with your custom agent class
    _benchmark_impl(
        models=models.split(","),  # CLI supplies one comma-separated string
        agent_class=CustomSocialWorldModelAgent,  # Use your custom agent
        partner_model=partner_model,
        evaluator_model=evaluator_model,
        batch_size=batch_size,
        task=task,
        push_to_db=push_to_db,
        tag=experiment_tag,  # _benchmark_impl takes the run tag as ``tag``
    )


if __name__ == "__main__":
app()
```

### Key Points:

1. **Custom Agent Class**: Your custom agent must be a subclass of `LLMAgent` (or another agent class that implements the same interface).

2. **Using `_benchmark_impl`**: The `_benchmark_impl` function accepts an `agent_class` parameter that allows you to specify which agent class to use for the benchmark.

3. **Agent Initialization**: When creating your custom agent, make sure it accepts the same initialization parameters as `LLMAgent` (e.g., `agent_profile` and `model_name`) and passes them to the parent class.

4. **Running the Custom Benchmark**: Save your script and run it like any other Python script:
```bash
python your_custom_benchmark.py run-custom-benchmark --models gpt-4.1-2025-04-14
```

For more information on creating custom agents, see the [Creating your own agents](/concepts/agents#creating-your-own-agents) section.

# Benchmark your model as an evaluator

```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sotopia.database.persistent_profile import AgentProfile
from typing import Any

from sotopia.generation_utils import agenerate, StrOutputParser, custom_temperature
from sotopia.generation_utils import agenerate, StrOutputParser

# Check Python version
if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -131,7 +131,7 @@ async def aact(self, obs: Observation) -> AgentAction:
"agent_name": self.name,
},
output_parser=StrOutputParser(),
temperature=custom_temperature(0.7),
temperature=0.7,
)

return AgentAction(
Expand Down
3 changes: 1 addition & 2 deletions examples/generation_api/custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sotopia.generation_utils import (
ListOfIntOutputParser,
agenerate,
default_temperature,
)
import logging

Expand All @@ -27,7 +26,7 @@ async def generate_n_random_numbers(n: int) -> list[int]:
template="Generate {n} random integer numbers. {format_instructions}",
input_values={"n": str(n)},
output_parser=ListOfIntOutputParser(n),
temperature=default_temperature(0.0),
temperature=0.0,
)


Expand Down
2 changes: 1 addition & 1 deletion examples/minimalist_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
asyncio.run(
run_async_server(
model_dict={
"env": "gpt-4",
"env": "gpt-4o-mini",
"agent1": "gpt-4o-mini",
"agent2": "gpt-4o-mini",
},
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ realtime = [
"pyaudio>=0.2.14,<0.3.0",
]

[tool.uv]
dev-dependencies = [
[dependency-groups]
dev = [
"pre-commit",
"nbmake",
"types-setuptools",
Expand Down Expand Up @@ -85,6 +85,7 @@ aact = { git = "https://github.com/ProKil/aact" , branch = "main" }
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
asyncio_default_fixture_loop_scope = "function"
markers = [
"real_llm: marks tests as requiring real LLM API calls (deselect with '-m \"not real_llm\"')",
"slow: marks tests as slow running (deselect with '-m \"not slow\"')",
Expand Down
7 changes: 4 additions & 3 deletions sotopia/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def aact(self, obs: Observation) -> AgentAction:
)

if len(obs.available_actions) == 1 and "none" in obs.available_actions:
return AgentAction(action_type="none", argument="")
return AgentAction(action_type="none", argument="", to=[])
else:
# Use agent names from script_background if available
agent_names = (
Expand All @@ -84,6 +84,7 @@ async def aact(self, obs: Observation) -> AgentAction:
agent=self.agent_name,
goal=self.goal,
script_like=self.script_like,
structured_output=True,
agent_names=agent_names,
sender=self.agent_name,
)
Expand Down Expand Up @@ -167,7 +168,7 @@ def act(self, obs: Observation) -> AgentAction:
action_type = obs.available_actions[int(input("Action type: "))]
argument = input("Argument: ")

return AgentAction(action_type=action_type, argument=argument)
return AgentAction(action_type=action_type, argument=argument, to=[])

async def aact(self, obs: Observation) -> AgentAction:
self.recv_message("Environment", obs)
Expand Down Expand Up @@ -197,7 +198,7 @@ async def aact(self, obs: Observation) -> AgentAction:
else:
argument = ""

return AgentAction(action_type=action_type, argument=argument)
return AgentAction(action_type=action_type, argument=argument, to=[])


class Agents(dict[str, BaseAgent[Observation, AgentAction]]):
Expand Down
6 changes: 3 additions & 3 deletions sotopia/agents/redis_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def aact(
)
)
last_timestamp = sorted_message_list[-1][0]
return AgentAction(action_type="none", argument="")
return AgentAction(action_type="none", argument="", to=[])
else:
async with aiohttp.ClientSession() as session:
# 1. post observation to the message list
Expand Down Expand Up @@ -137,7 +137,7 @@ async def aact(
f"{self._URL}/lock/{self.session_id}/{self.sender_id}/no%20action",
)
self.reset("Someone has left or the conversation is too long.")
return AgentAction(action_type="leave", argument="")
return AgentAction(action_type="leave", argument="", to=[])
action_string = sorted_message_list[-1][2]
try:
action = AgentAction.model_validate_json(action_string)
Expand All @@ -149,7 +149,7 @@ async def aact(
)
)
return AgentAction(
action_type="speak", argument=sorted_message_list[-1][2]
action_type="speak", argument=sorted_message_list[-1][2], to=[]
)

def reset(
Expand Down
Loading