-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathtest_tool_failure_simulation.py
More file actions
241 lines (206 loc) · 9.37 KB
/
test_tool_failure_simulation.py
File metadata and controls
241 lines (206 loc) · 9.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Example test demonstrating tool failure simulation with real LLM tool calling.
This example shows how to test agent resilience by simulating tool failures,
timeouts, and other error conditions while using actual LLM tool calling.
"""
import pytest
import scenario
from unittest.mock import patch
import litellm
import json
def call_external_service(endpoint: str) -> str:
    """Stand-in for a real call to an external service API.

    There is no real implementation: the tests in this module replace this
    function with a mock, so reaching this body means it was not patched.

    Args:
        endpoint: The API endpoint the caller wants to hit.

    Raises:
        NotImplementedError: Always.
    """
    unpatched_message = "This should be mocked in tests"
    raise NotImplementedError(unpatched_message)
class ResilientAgent(scenario.AgentAdapter):
    """Agent that uses real LLM tool calling and handles external service failures gracefully.

    The LLM decides whether to call ``call_external_service``. Every requested
    tool call is answered with a ``tool`` message: the service result on
    success, or an ``"Error: ..."`` string when the call raises, the arguments
    cannot be parsed, or the tool name is unknown. The LLM then writes the final
    reply from those results.
    """

    # Model used for both the tool-selection request and the follow-up request.
    _MODEL = "openai/gpt-4o-mini"

    # Function schema describing the external service tool to the LLM.
    # Built once here instead of on every call.
    _TOOL_SCHEMAS = [
        {
            "type": "function",
            "function": {
                "name": "call_external_service",
                "description": "Call an external service API endpoint",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "endpoint": {
                            "type": "string",
                            "description": "The API endpoint to call",
                        }
                    },
                    "required": ["endpoint"],
                },
            },
        }
    ]

    async def call(self, input: scenario.AgentInput) -> scenario.AgentReturnTypes:
        """Answer the conversation, executing any tool calls the LLM requests.

        Returns:
            The LLM's text reply: the follow-up reply when tools were called,
            otherwise the direct reply. An empty string replaces missing content.
        """
        # Use the async variant so the LLM request does not block the event loop.
        response = await litellm.acompletion(
            model=self._MODEL,
            messages=input.messages,
            tools=self._TOOL_SCHEMAS,
            tool_choice="auto",
        )
        message = response.choices[0].message  # type: ignore[attr-defined] # litellm response has dynamic attributes

        # No tools requested: return the LLM's direct answer.
        if not message.tool_calls:
            return message.content or ""

        # The assistant message carrying tool_calls must be followed by one tool
        # message per tool_call_id, so every call gets a response, including failures.
        tool_responses = [
            self._execute_tool_call(tool_call) for tool_call in message.tool_calls
        ]

        # Let the LLM phrase its final answer from the tool results.
        follow_up_response = await litellm.acompletion(
            model=self._MODEL,
            messages=input.messages + [message] + tool_responses,
        )
        return follow_up_response.choices[0].message.content or ""  # type: ignore[attr-defined] # litellm response has dynamic attributes

    @staticmethod
    def _execute_tool_call(tool_call) -> dict:
        """Run one LLM-requested tool call and wrap its outcome in a ``tool`` message."""
        tool_name = tool_call.function.name
        if tool_name != "call_external_service":
            content = f"Error: unknown tool {tool_name!r}"
        else:
            try:
                # The LLM supplies the arguments (e.g. the endpoint) as a JSON string.
                # Parsing happens inside the try so malformed JSON is reported, not raised.
                args = json.loads(tool_call.function.arguments)
                # Resolved as a module global at call time, so tests can patch it.
                content = str(call_external_service(**args))
            except Exception as e:
                # Deliberately broad: any service failure is reported back to the LLM.
                # Graceful handling is the behavior under test.
                content = f"Error: {e}"
        return {"role": "tool", "tool_call_id": tool_call.id, "content": content}
# Phrases the LLM might use when reporting a failed or timed-out service call.
_ERROR_INDICATORS = ("error", "timeout", "timed out", "failed", "issue")
# Phrases the LLM might use when reporting a rate-limit failure.
_RATE_LIMIT_INDICATORS = (
    "rate limit",
    "exceeded",
    "limit exceeded",
    "too many requests",
)
# Phrases the LLM might use when reporting a successful service call.
_SUCCESS_INDICATORS = ("successful", "success", "completed", "call was successful")


def _assert_assistant_mentions(
    state: "scenario.ScenarioState", indicators: "tuple[str, ...]", what: str
) -> None:
    """Assert the last message is from the assistant and mentions one of *indicators*.

    Matching is a case-insensitive substring search, because the LLM's exact
    wording varies between runs.

    Args:
        state: Scenario state whose ``last_message()`` is inspected.
        indicators: Lower-case phrases, any one of which satisfies the check.
        what: Human-readable description used in the failure message.

    Raises:
        AssertionError: If the last message is not from the assistant or
            mentions none of the indicators.
    """
    last_msg = state.last_message()
    # Fail loudly rather than silently passing when the assistant did not speak last.
    # Otherwise the check would verify nothing.
    assert last_msg["role"] == "assistant", (
        f"expected the last message to be from the assistant, got role {last_msg['role']!r}"
    )
    # Treat missing or None content as empty text instead of the literal "None".
    content = last_msg.get("content") or ""
    text = (content if isinstance(content, str) else str(content)).lower()
    assert any(indicator in text for indicator in indicators), (
        f"expected the assistant message to mention {what} "
        f"(one of {indicators!r}), got: {text!r}"
    )


def check_error_in_message(state: "scenario.ScenarioState") -> None:
    """Check that the agent's message contains error or timeout information."""
    _assert_assistant_mentions(state, _ERROR_INDICATORS, "an error or timeout")


def check_rate_limit_in_message(state: "scenario.ScenarioState") -> None:
    """Check that the agent's message contains rate limit error information."""
    _assert_assistant_mentions(state, _RATE_LIMIT_INDICATORS, "a rate limit")


def check_success_in_message(state: "scenario.ScenarioState") -> None:
    """Check that the agent's message contains success information."""
    _assert_assistant_mentions(state, _SUCCESS_INDICATORS, "success")
@pytest.mark.agent_test
@pytest.mark.flaky(reruns=2)
@pytest.mark.asyncio
async def test_tool_timeout_simulation():
    """Test agent's ability to handle tool timeouts.

    The external service is patched to raise ``TimeoutError``. The agent must
    still call it with the endpoint the LLM extracted, then report the failure.
    """
    # Patch via this module's actual import name so the patch lands on the same
    # module object the agent resolves ``call_external_service`` from. This holds
    # however pytest imported this file (bare or package-qualified module name).
    with patch(f"{__name__}.call_external_service") as mock_service:
        # Simulate a timeout error from the service.
        mock_service.side_effect = TimeoutError("Request timeout")
        result = await scenario.run(
            name="tool timeout test",
            description="Test agent's ability to handle tool timeouts",
            agents=[
                ResilientAgent(),
                scenario.UserSimulatorAgent(model="openai/gpt-4o-mini"),
            ],
            script=[
                scenario.user("Call the external service at endpoint /api/data"),
                scenario.agent(),
                # Verify the mock was called with the specific endpoint extracted by the LLM.
                # This proves the LLM correctly extracted "/api/data" from the user message.
                lambda state: mock_service.assert_called_once_with(
                    endpoint="/api/data"
                ),
                check_error_in_message,
                scenario.succeed(),
            ],
        )
    assert result.success
@pytest.mark.agent_test
@pytest.mark.flaky(reruns=2)
@pytest.mark.asyncio
async def test_tool_rate_limit_simulation():
    """Test agent's ability to handle rate limits.

    The external service is patched to raise a rate-limit error. The agent must
    still call it with the endpoint the LLM extracted, then report the rate limit.
    """
    # Patch via this module's actual import name so the patch lands on the same
    # module object the agent resolves ``call_external_service`` from. This holds
    # however pytest imported this file (bare or package-qualified module name).
    with patch(f"{__name__}.call_external_service") as mock_service:
        # Simulate a rate limit error from the service.
        mock_service.side_effect = Exception("Rate limit exceeded")
        result = await scenario.run(
            name="tool rate limit test",
            description="Test agent's ability to handle rate limits",
            agents=[
                ResilientAgent(),
                scenario.UserSimulatorAgent(model="openai/gpt-4o-mini"),
            ],
            script=[
                scenario.user("Call the external service at endpoint /api/data"),
                scenario.agent(),
                # Verify the mock was called with the specific endpoint extracted by the LLM.
                # This proves the LLM correctly extracted "/api/data" from the user message.
                lambda state: mock_service.assert_called_once_with(
                    endpoint="/api/data"
                ),
                check_rate_limit_in_message,
                scenario.succeed(),
            ],
        )
    assert result.success
@pytest.mark.agent_test
@pytest.mark.flaky(reruns=2)
@pytest.mark.asyncio
async def test_tool_success_simulation():
    """Test agent's ability to handle successful tool calls.

    The external service is patched to return a success string. The agent must
    call it with the endpoint the LLM extracted, then report the success.
    """
    # Patch via this module's actual import name so the patch lands on the same
    # module object the agent resolves ``call_external_service`` from. This holds
    # however pytest imported this file (bare or package-qualified module name).
    with patch(f"{__name__}.call_external_service") as mock_service:
        # Simulate a successful service call.
        mock_service.return_value = "Service call successful"
        result = await scenario.run(
            name="tool success test",
            description="Test agent's ability to handle successful tool calls",
            agents=[
                ResilientAgent(),
                scenario.UserSimulatorAgent(model="openai/gpt-4o-mini"),
            ],
            script=[
                scenario.user("Call the external service at endpoint /api/data"),
                scenario.agent(),
                # Verify the mock was called with the specific endpoint extracted by the LLM.
                # This proves the LLM correctly extracted "/api/data" from the user message.
                lambda state: mock_service.assert_called_once_with(
                    endpoint="/api/data"
                ),
                check_success_in_message,
                scenario.succeed(),
            ],
        )
    assert result.success