diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs
new file mode 100644
index 00000000000..5fdfae1be2b
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.DiagnosticIds;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// Provides context for an iteration within the function invocation loop.
+///
+/// This context is provided to the
+/// callback after each iteration of the function invocation loop completes.
+///
+[Experimental(DiagnosticIds.Experiments.AIIterationCompleted, UrlFormat = DiagnosticIds.UrlFormat)]
+public class FunctionInvocationIterationContext
+{
+ /// Gets or sets the current iteration number (0-based).
+ ///
+ /// The initial request to the client that passes along the chat contents provided to the
+ /// is iteration 0. If the client responds with
+ /// a function call request that is processed, the next iteration is 1, and so on.
+ ///
+ public int Iteration { get; set; }
+
+ /// Gets or sets the aggregated usage details across all iterations so far.
+ ///
+ /// This includes usage from all inner client calls and is updated after each iteration.
+ /// May be if the inner client doesn't provide usage information.
+ ///
+ public UsageDetails? TotalUsage { get; set; }
+
+ /// Gets or sets the messages accumulated during the function-calling loop.
+ ///
+ /// This includes all messages from all iterations, including function call and result contents.
+ ///
+ public IReadOnlyList Messages
+ {
+ get;
+ set => field = Throw.IfNull(value);
+ } = [];
+
+ /// Gets or sets the response from the most recent inner client call.
+ ///
+ /// This is the response that triggered the current iteration's function invocations.
+ ///
+ public ChatResponse Response
+ {
+ get;
+ set => field = Throw.IfNull(value);
+ } = new([]);
+
+ /// Gets or sets a value indicating whether to terminate the loop after this iteration.
+ ///
+ ///
+ /// Setting this to will cause the function invocation loop to exit
+ /// after the current iteration completes. The function calls for this iteration will have
+ /// already been processed before this callback is invoked.
+ ///
+ ///
+ /// This is similar to setting from within
+ /// a function, but can be triggered based on external criteria like usage thresholds.
+ ///
+ ///
+ public bool Terminate { get; set; }
+
+ /// Gets or sets a value indicating whether the iteration is part of a streaming operation.
+ public bool IsStreaming { get; set; }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index cd449248867..2b4b84e461f 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
@@ -12,6 +12,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Shared.DiagnosticIds;
using Microsoft.Shared.Diagnostics;
#pragma warning disable CA2213 // Disposable fields should be disposed
@@ -254,6 +255,39 @@ public int MaximumConsecutiveErrorsPerRequest
///
public bool TerminateOnUnknownCalls { get; set; }
+ /// Gets or sets a delegate called after each iteration of the function invocation loop completes.
+ ///
+ ///
+ /// This delegate is invoked after each iteration completes - specifically, after function invocations
+ /// are processed and before the next request is made to the inner client. This timing ensures that:
+ ///
+ ///
+ /// - Function calls are not lost mid-execution
+ /// - Function results are available for inspection
+ /// - Aggregated usage information is up-to-date
+ ///
+ ///
+ /// Common use cases include:
+ ///
+ ///
+ /// - Token usage monitoring and context compaction triggers
+ /// - Cost limit enforcement
+ /// - Content guardrails and moderation
+ /// - Time limit enforcement
+ /// - Custom iteration-level logging or metrics
+ ///
+ ///
+ /// Set to
+ /// to stop the loop. The function calls for the current iteration will have already been processed.
+ ///
+ ///
+ /// Changing the value of this property while the client is in use might result in inconsistencies
+ /// as to whether the callback is invoked for an in-flight request.
+ ///
+ ///
+ [Experimental(DiagnosticIds.Experiments.AIIterationCompleted, UrlFormat = DiagnosticIds.UrlFormat)]
+ public Func? IterationCompleted { get; set; }
+
/// Gets or sets a delegate used to invoke instances.
///
/// By default, the protected method is called for each to be invoked,
@@ -399,6 +433,27 @@ public override async Task GetResponseAsync(
break;
}
+ // Call the iteration completed hook if configured
+ if (IterationCompleted is { } iterationCallback)
+ {
+ var iterationContext = new FunctionInvocationIterationContext
+ {
+ Iteration = iteration,
+ TotalUsage = CloneUsageDetails(totalUsage),
+ Messages = responseMessages,
+ Response = response,
+ IsStreaming = false
+ };
+
+ await iterationCallback(iterationContext, cancellationToken);
+
+ if (iterationContext.Terminate)
+ {
+ LogIterationCallbackRequestedTermination(iteration);
+ break;
+ }
+ }
+
UpdateOptionsForNextIteration(ref options, response.ConversationId);
}
@@ -421,7 +476,10 @@ public override async IAsyncEnumerable GetStreamingResponseA
// Create an activity to group them together for better observability. If there's already a genai "invoke_agent"
// span that's current, however, we just consider that the group and don't add a new one.
using Activity? activity = CurrentActivityIsInvokeAgent ? null : _activitySource?.StartActivity(OpenTelemetryConsts.GenAI.OrchestrateToolsName);
- UsageDetails? totalUsage = activity is { IsAllDataRequested: true } ? new() : null; // tracked usage across all turns, to be used for activity purposes
+
+ // Track usage if needed for telemetry or for the iteration callback
+ bool needsUsageTracking = activity is { IsAllDataRequested: true } || IterationCompleted is not null;
+ UsageDetails? totalUsage = needsUsageTracking ? new() : null;
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
@@ -640,6 +698,27 @@ public override async IAsyncEnumerable GetStreamingResponseA
break;
}
+ // Call the iteration completed hook if configured
+ if (IterationCompleted is { } iterationCallback)
+ {
+ var iterationContext = new FunctionInvocationIterationContext
+ {
+ Iteration = iteration,
+ TotalUsage = CloneUsageDetails(totalUsage),
+ Messages = responseMessages,
+ Response = response,
+ IsStreaming = true
+ };
+
+ await iterationCallback(iterationContext, cancellationToken);
+
+ if (iterationContext.Terminate)
+ {
+ LogIterationCallbackRequestedTermination(iteration);
+ break;
+ }
+ }
+
UpdateOptionsForNextIteration(ref options, response.ConversationId);
}
@@ -677,6 +756,31 @@ private static void AddUsageTags(Activity? activity, UsageDetails? usage)
}
}
+ /// Creates a defensive copy of usage details to prevent mutation after callback returns.
+ private static UsageDetails? CloneUsageDetails(UsageDetails? usage)
+ {
+ if (usage is null)
+ {
+ return null;
+ }
+
+ var clone = new UsageDetails
+ {
+ InputTokenCount = usage.InputTokenCount,
+ OutputTokenCount = usage.OutputTokenCount,
+ TotalTokenCount = usage.TotalTokenCount,
+ CachedInputTokenCount = usage.CachedInputTokenCount,
+ ReasoningTokenCount = usage.ReasoningTokenCount,
+ };
+
+ if (usage.AdditionalCounts is { } additionalCounts)
+ {
+ clone.AdditionalCounts = new(additionalCounts);
+ }
+
+ return clone;
+ }
+
/// Prepares the various chat message lists after a response from the inner client and before invoking functions.
/// The original messages provided by the caller.
/// The messages reference passed to the inner client.
@@ -1851,6 +1955,9 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
[LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' requested termination of the processing loop.")]
private partial void LogFunctionRequestedTermination(string functionName);
+ [LoggerMessage(LogLevel.Debug, "Iteration {Iteration}: IterationCompleted callback requested termination.")]
+ private partial void LogIterationCallbackRequestedTermination(int iteration);
+
/// Provides information about the invocation of a function call.
public sealed class FunctionInvocationResult
{
diff --git a/src/Shared/DiagnosticIds/DiagnosticIds.cs b/src/Shared/DiagnosticIds/DiagnosticIds.cs
index 4cc736a7252..0a8b3e24a11 100644
--- a/src/Shared/DiagnosticIds/DiagnosticIds.cs
+++ b/src/Shared/DiagnosticIds/DiagnosticIds.cs
@@ -57,6 +57,7 @@ internal static class Experiments
internal const string AIResponseContinuations = AIExperiments;
internal const string AICodeInterpreter = AIExperiments;
internal const string AIRealTime = AIExperiments;
+ internal const string AIIterationCompleted = AIExperiments;
private const string AIExperiments = "MEAI001";
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
index 6cd6e857996..539a999212c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
@@ -3382,4 +3382,379 @@ public async Task LogsFunctionRejected()
// ContinuesWithFailingCallsUntilMaximumConsecutiveErrors test which triggers
// the threshold condition. The logging call is at line 1078 and will execute
// when MaximumConsecutiveErrorsPerRequest is exceeded.
+
+ [Fact]
+ public void IterationCompleted_NullByDefault()
+ {
+ using TestChatClient innerClient = new();
+ using FunctionInvokingChatClient client = new(innerClient);
+
+ Assert.Null(client.IterationCompleted);
+ }
+
+ [Fact]
+ public void IterationCompleted_Roundtrip()
+ {
+ using TestChatClient innerClient = new();
+ using FunctionInvokingChatClient client = new(innerClient);
+
+ Assert.Null(client.IterationCompleted);
+ Func callback = (ctx, ct) => default;
+ client.IterationCompleted = callback;
+ Assert.Same(callback, client.IterationCompleted);
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task IterationCompleted_InvokedAfterEachIteration(bool streaming)
+ {
+ var iterationContexts = new List();
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
+ };
+
+ int callCount = 0;
+ using var innerClient = new TestChatClient
+ {
+ GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+ {
+ await Task.Yield();
+ callCount++;
+
+ if (callCount <= 2)
+ {
+ // Return function calls for first two iterations
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent($"callId{callCount}", "Func1")])])
+ {
+ Usage = new UsageDetails { InputTokenCount = 100 * callCount, OutputTokenCount = 50 * callCount, TotalTokenCount = 150 * callCount }
+ };
+ }
+ else
+ {
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")])
+ {
+ Usage = new UsageDetails { InputTokenCount = 300, OutputTokenCount = 150, TotalTokenCount = 450 }
+ };
+ }
+ },
+ GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+ {
+ callCount++;
+
+ ChatMessage message;
+ UsageDetails? usage;
+ if (callCount <= 2)
+ {
+ message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{callCount}", "Func1")]);
+ usage = new UsageDetails { InputTokenCount = 100 * callCount, OutputTokenCount = 50 * callCount, TotalTokenCount = 150 * callCount };
+ }
+ else
+ {
+ message = new ChatMessage(ChatRole.Assistant, "Done");
+ usage = new UsageDetails { InputTokenCount = 300, OutputTokenCount = 150, TotalTokenCount = 450 };
+ }
+
+ var updates = new ChatResponse(message) { Usage = usage }.ToChatResponseUpdates().ToList();
+ return YieldAsync(updates);
+ }
+ };
+
+ using var client = new FunctionInvokingChatClient(innerClient)
+ {
+ IterationCompleted = (ctx, ct) =>
+ {
+ iterationContexts.Add(new FunctionInvocationIterationContext
+ {
+ Iteration = ctx.Iteration,
+ TotalUsage = ctx.TotalUsage is not null
+ ? new UsageDetails
+ {
+ InputTokenCount = ctx.TotalUsage.InputTokenCount,
+ OutputTokenCount = ctx.TotalUsage.OutputTokenCount,
+ TotalTokenCount = ctx.TotalUsage.TotalTokenCount
+ }
+ : null,
+ Messages = ctx.Messages.ToList(),
+ Response = ctx.Response,
+ IsStreaming = ctx.IsStreaming
+ });
+ return default;
+ }
+ };
+
+ if (streaming)
+ {
+ await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+ }
+ else
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+
+ // Should be called twice (once per iteration with function calls)
+ Assert.Equal(2, iterationContexts.Count);
+
+ // First iteration
+ Assert.Equal(0, iterationContexts[0].Iteration);
+ Assert.Equal(streaming, iterationContexts[0].IsStreaming);
+ Assert.NotNull(iterationContexts[0].Response);
+
+ // Second iteration
+ Assert.Equal(1, iterationContexts[1].Iteration);
+ Assert.Equal(streaming, iterationContexts[1].IsStreaming);
+
+ // Messages should accumulate
+ Assert.True(iterationContexts[1].Messages.Count > iterationContexts[0].Messages.Count);
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task IterationCompleted_CanTerminateLoop(bool streaming)
+ {
+ int functionInvocations = 0;
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => { functionInvocations++; return "Result"; }, "Func1")]
+ };
+
+ int callCount = 0;
+ using var innerClient = new TestChatClient
+ {
+ GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+ {
+ await Task.Yield();
+ callCount++;
+
+ // Always return a function call - loop should be terminated by callback
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent($"callId{callCount}", "Func1")])]);
+ },
+ GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+ {
+ callCount++;
+ var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{callCount}", "Func1")]);
+ return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
+ }
+ };
+
+ using var client = new FunctionInvokingChatClient(innerClient)
+ {
+ IterationCompleted = (ctx, ct) =>
+ {
+ // Terminate after first iteration
+ if (ctx.Iteration == 0)
+ {
+ ctx.Terminate = true;
+ }
+
+ return default;
+ }
+ };
+
+ if (streaming)
+ {
+ await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+ }
+ else
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+
+ // Only one function invocation should occur (first iteration completes, then terminates)
+ Assert.Equal(1, functionInvocations);
+ Assert.Equal(1, callCount);
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task IterationCompleted_ReceivesCorrectUsageDetails(bool streaming)
+ {
+ UsageDetails? capturedUsage = null;
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => "Result", "Func1")]
+ };
+
+ int callCount = 0;
+ using var innerClient = new TestChatClient
+ {
+ GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+ {
+ await Task.Yield();
+ callCount++;
+
+ if (callCount == 1)
+ {
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])])
+ {
+ Usage = new UsageDetails { InputTokenCount = 100, OutputTokenCount = 50, TotalTokenCount = 150 }
+ };
+ }
+ else
+ {
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")])
+ {
+ Usage = new UsageDetails { InputTokenCount = 200, OutputTokenCount = 100, TotalTokenCount = 300 }
+ };
+ }
+ },
+ GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+ {
+ callCount++;
+
+ ChatMessage message;
+ UsageDetails usage;
+ if (callCount == 1)
+ {
+ message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]);
+ usage = new UsageDetails { InputTokenCount = 100, OutputTokenCount = 50, TotalTokenCount = 150 };
+ }
+ else
+ {
+ message = new ChatMessage(ChatRole.Assistant, "Done");
+ usage = new UsageDetails { InputTokenCount = 200, OutputTokenCount = 100, TotalTokenCount = 300 };
+ }
+
+ var updates = new ChatResponse(message) { Usage = usage }.ToChatResponseUpdates().ToList();
+ return YieldAsync(updates);
+ }
+ };
+
+ using var client = new FunctionInvokingChatClient(innerClient)
+ {
+ IterationCompleted = (ctx, ct) =>
+ {
+ capturedUsage = ctx.TotalUsage;
+ return default;
+ }
+ };
+
+ if (streaming)
+ {
+ await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+ }
+ else
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+
+ // Usage should reflect accumulated values from first iteration
+ Assert.NotNull(capturedUsage);
+ Assert.Equal(100, capturedUsage.InputTokenCount);
+ Assert.Equal(50, capturedUsage.OutputTokenCount);
+ Assert.Equal(150, capturedUsage.TotalTokenCount);
+ }
+
+ [Fact]
+ public async Task IterationCompleted_NotCalledWhenNoFunctionCalls()
+ {
+ int callbackInvocations = 0;
+
+ using var innerClient = new TestChatClient
+ {
+ GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+ {
+ await Task.Yield();
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Just a regular response")]);
+ }
+ };
+
+ using var client = new FunctionInvokingChatClient(innerClient)
+ {
+ IterationCompleted = (ctx, ct) =>
+ {
+ callbackInvocations++;
+ return default;
+ }
+ };
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => "Result", "Func1")]
+ };
+
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+
+ // Callback should not be called since there were no function calls
+ Assert.Equal(0, callbackInvocations);
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task IterationCompleted_ReceivesAllMessages(bool streaming)
+ {
+ IReadOnlyList? capturedMessages = null;
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create(() => "FunctionResult", "Func1")]
+ };
+
+ int callCount = 0;
+ using var innerClient = new TestChatClient
+ {
+ GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+ {
+ await Task.Yield();
+ callCount++;
+
+ if (callCount == 1)
+ {
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])]);
+ }
+ else
+ {
+ return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")]);
+ }
+ },
+ GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+ {
+ callCount++;
+
+ ChatMessage message = callCount == 1
+ ? new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])
+ : new ChatMessage(ChatRole.Assistant, "Done");
+
+ return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
+ }
+ };
+
+ using var client = new FunctionInvokingChatClient(innerClient)
+ {
+ IterationCompleted = (ctx, ct) =>
+ {
+ capturedMessages = ctx.Messages.ToList();
+ return default;
+ }
+ };
+
+ if (streaming)
+ {
+ await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+ }
+ else
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+
+ Assert.NotNull(capturedMessages);
+
+ // Should contain: FunctionCallContent message, FunctionResultContent message
+ Assert.Equal(2, capturedMessages.Count);
+
+ // First message should have FunctionCallContent
+ Assert.Contains(capturedMessages[0].Contents, c => c is FunctionCallContent);
+
+ // Second message should have FunctionResultContent
+ Assert.Contains(capturedMessages[1].Contents, c => c is FunctionResultContent frc && frc.Result?.ToString() == "FunctionResult");
+ }
}