Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,25 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> None:
"""Callback executed when an event carries state changes.

This callback is invoked after an event with a non-empty
``state_delta`` is yielded from the runner. It is observational, but
returning a non-`None` value will short-circuit subsequent plugins.

Args:
callback_context: The context for the current invocation.
state_delta: A copy of the state changes carried by the event.
Mutating this dict does not affect the original state.

Returns:
None
"""
pass
7 changes: 0 additions & 7 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,13 +2510,6 @@ async def after_tool_callback(
parent_span_id_override=parent_span_id,
)

if tool_context.actions.state_delta:
await self._log_event(
"STATE_DELTA",
tool_context,
state_delta=tool_context.actions.state_delta,
)

async def on_tool_error_callback(
self,
*,
Expand Down
14 changes: 14 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_state_change_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -257,6 +258,19 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> Optional[Any]:
"""Runs the `on_state_change_callback` for all plugins."""
return await self._run_callbacks(
"on_state_change_callback",
callback_context=callback_context,
state_delta=state_delta,
)

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down
23 changes: 20 additions & 3 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,6 @@ async def _exec_with_plugin(
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
Expand Down Expand Up @@ -839,9 +838,19 @@ async def _exec_with_plugin(
_apply_run_config_custom_metadata(
modified_event, invocation_context.run_config
)
yield modified_event
final_event = modified_event
else:
yield event
final_event = event
yield final_event

# Step 3b: Notify plugins of state changes, if any.
if final_event.actions.state_delta:
from .agents.callback_context import CallbackContext

await plugin_manager.run_on_state_change_callback(
callback_context=CallbackContext(invocation_context),
state_delta=dict(final_event.actions.state_delta),
)

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
Expand Down Expand Up @@ -1485,6 +1494,14 @@ async def _handle_new_message(
state_delta=state_delta,
)

if state_delta:
from .agents.callback_context import CallbackContext

await invocation_context.plugin_manager.run_on_state_change_callback(
callback_context=CallbackContext(invocation_context),
state_delta=dict(state_delta),
)

def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/plugins/test_base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str:
async def on_model_error_callback(self, **kwargs) -> str:
return "overridden_on_model_error"

async def on_state_change_callback(self, **kwargs) -> str:
return "overridden_on_state_change"


def test_base_plugin_initialization():
"""Tests that a plugin is initialized with the correct name."""
Expand Down Expand Up @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none():
)
is None
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_context,
state_delta={},
)
is None
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
)
== "overridden_on_model_error"
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_callback_context,
state_delta={"key": "value"},
)
== "overridden_on_state_change"
)
38 changes: 10 additions & 28 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,16 +1560,19 @@ async def test_after_tool_callback_logs_correctly(
assert content_dict["result"] == {"res": "success"}

@pytest.mark.asyncio
async def test_after_tool_callback_state_delta_logging(
async def test_after_tool_callback_no_inline_state_delta(
self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema
):
"""after_tool_callback does not log STATE_DELTA inline.

STATE_DELTA is logged exclusively via on_state_change_callback.
"""
mock_tool = mock.create_autospec(
base_tool_lib.BaseTool, instance=True, spec_set=True
)
type(mock_tool).name = mock.PropertyMock(return_value="StateTool")
type(mock_tool).description = mock.PropertyMock(return_value="Sets state")

# Simulate a tool modifying the state
tool_context.actions.state_delta["new_key"] = "new_value"

bigquery_agent_analytics_plugin.TraceManager.push_span(tool_context)
Expand All @@ -1581,31 +1584,11 @@ async def test_after_tool_callback_state_delta_logging(
)
await asyncio.sleep(0.01)

# We should have two events appended: TOOL_COMPLETED and STATE_DELTA
assert mock_write_client.append_rows.call_count >= 1

# Retrieve all flushed events
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
assert len(rows) == 2

# Sort by event_type to reliably access them
rows.sort(key=lambda x: x["event_type"])

state_delta_event = (
rows[0] if rows[0]["event_type"] == "STATE_DELTA" else rows[1]
)
tool_event = (
rows[1] if rows[1]["event_type"] == "TOOL_COMPLETED" else rows[0]
# Only TOOL_COMPLETED should be logged
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)

assert state_delta_event["event_type"] == "STATE_DELTA"
assert tool_event["event_type"] == "TOOL_COMPLETED"

# Verify STATE_DELTA payload
attributes = json.loads(state_delta_event["attributes"])
assert "state_delta" in attributes
assert attributes["state_delta"] == {"new_key": "new_value"}
assert state_delta_event["content"] is None
assert log_entry["event_type"] == "TOOL_COMPLETED"

@pytest.mark.asyncio
async def test_on_state_change_callback_logs_correctly(
Expand All @@ -1615,6 +1598,7 @@ async def test_on_state_change_callback_logs_correctly(
callback_context,
dummy_arrow_schema,
):
"""STATE_DELTA is logged via on_state_change_callback."""
state_delta = {"key": "value", "new_key": 123}
bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)
await bq_plugin_inst.on_state_change_callback(
Expand All @@ -1625,10 +1609,8 @@ async def test_on_state_change_callback_logs_correctly(
mock_write_client, dummy_arrow_schema
)
_assert_common_fields(log_entry, "STATE_DELTA")
# content should be None (as raw_content was not passed)
assert log_entry["content"] is None

# state_delta should be in attributes
attributes = json.loads(log_entry["attributes"])
assert attributes["state_delta"] == state_delta

Expand Down
62 changes: 62 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_state_change_callback(self, **kwargs):
return await self._handle_callback("on_state_change_callback")


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_state_change_callback(
callback_context=mock_context,
state_delta={"key": "value"},
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_state_change_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -317,3 +325,57 @@ async def slow_close():
assert "Failed to close plugins: 'plugin1': TimeoutError" in str(
excinfo.value
)


# --- on_state_change_callback tests ---


@pytest.mark.asyncio
async def test_run_on_state_change_callback(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that run_on_state_change_callback invokes the callback and returns None."""
service.register_plugin(plugin1)
result = await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)
assert result is None
assert "on_state_change_callback" in plugin1.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_calls_all_plugins(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_state_change_callback is called on all plugins."""
service.register_plugin(plugin1)
service.register_plugin(plugin2)

await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "on_state_change_callback" in plugin1.call_log
assert "on_state_change_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_wraps_exceptions(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that exceptions in on_state_change_callback are wrapped in RuntimeError."""
original_exception = ValueError("state change error")
plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_state_change_callback" in str(excinfo.value)
assert excinfo.value.__cause__ is original_exception
47 changes: 47 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(self):
self.enable_user_message_callback = False
self.enable_event_callback = False
self.user_content_seen_in_before_run_callback = None
self.state_change_deltas: list[dict] = []

async def on_user_message_callback(
self,
Expand All @@ -169,6 +170,9 @@ async def before_run_callback(
invocation_context.user_content
)

async def on_state_change_callback(self, *, callback_context, state_delta, **kwargs):
self.state_change_deltas.append(state_delta)

async def on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
Expand Down Expand Up @@ -853,6 +857,49 @@ async def test_runner_passes_plugin_close_timeout(self):
)
assert runner.plugin_manager._close_timeout == 10.0

@pytest.mark.asyncio
async def test_state_delta_in_run_async_triggers_on_state_change_callback(
self,
):
"""Test that caller-supplied state_delta triggers on_state_change_callback."""
await self.session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
state_delta = {"lang": "en", "theme": "dark"}
events = []
async for event in self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(
role="user", parts=[types.Part(text="Hello")]
),
state_delta=state_delta,
):
events.append(event)

assert len(self.plugin.state_change_deltas) >= 1
assert self.plugin.state_change_deltas[0] == state_delta

@pytest.mark.asyncio
async def test_no_state_delta_does_not_trigger_on_state_change_callback(
self,
):
"""Test that on_state_change_callback is not called when no state_delta is provided."""
await self.session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
events = []
async for event in self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(
role="user", parts=[types.Part(text="Hello")]
),
):
events.append(event)

assert len(self.plugin.state_change_deltas) == 0

@pytest.mark.filterwarnings(
"ignore:The `plugins` argument is deprecated:DeprecationWarning"
)
Expand Down
Loading