diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs
new file mode 100644
index 00000000000..5fdfae1be2b
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationIterationContext.cs
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.DiagnosticIds;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// <summary>Provides context for an iteration within the function invocation loop.</summary>
+/// <remarks>
+/// This context is provided to the <see cref="FunctionInvokingChatClient.IterationCompleted"/>
+/// callback after each iteration of the function invocation loop completes.
+/// </remarks>
+[Experimental(DiagnosticIds.Experiments.AIIterationCompleted, UrlFormat = DiagnosticIds.UrlFormat)]
+public class FunctionInvocationIterationContext
+{
+    /// <summary>Gets or sets the current iteration number (0-based).</summary>
+    /// <remarks>
+    /// The initial request to the client that passes along the chat contents provided to the
+    /// <see cref="FunctionInvokingChatClient"/> is iteration 0. If the client responds with
+    /// a function call request that is processed, the next iteration is 1, and so on.
+    /// </remarks>
+    public int Iteration { get; set; }
+
+    /// <summary>Gets or sets the aggregated usage details across all iterations so far.</summary>
+    /// <remarks>
+    /// This includes usage from all inner client calls and is updated after each iteration.
+    /// May be <see langword="null"/> if the inner client doesn't provide usage information.
+    /// </remarks>
+    public UsageDetails? TotalUsage { get; set; }
+
+    /// <summary>Gets or sets the messages accumulated during the function-calling loop.</summary>
+    /// <remarks>
+    /// This includes all messages from all iterations, including function call and result contents.
+    /// </remarks>
+    public IReadOnlyList<ChatMessage> Messages
+    {
+        get;
+        set => field = Throw.IfNull(value);
+    } = [];
+
+    /// <summary>Gets or sets the response from the most recent inner client call.</summary>
+    /// <remarks>
+    /// This is the response that triggered the current iteration's function invocations.
+    /// </remarks>
+    public ChatResponse Response
+    {
+        get;
+        set => field = Throw.IfNull(value);
+    } = new([]);
+
+    /// <summary>Gets or sets a value indicating whether to terminate the loop after this iteration.</summary>
+    /// <remarks>
+    /// <para>
+    /// Setting this to <see langword="true"/> will cause the function invocation loop to exit
+    /// after the current iteration completes. The function calls for this iteration will have
+    /// already been processed before this callback is invoked.
+    /// </para>
+    /// <para>
+    /// This is similar to setting <see cref="FunctionInvocationContext.Terminate"/> from within
+    /// a function, but can be triggered based on external criteria like usage thresholds.
+    /// </para>
+    /// </remarks>
+    public bool Terminate { get; set; }
+
+    /// <summary>Gets or sets a value indicating whether the iteration is part of a streaming operation.</summary>
+    public bool IsStreaming { get; set; }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index cd449248867..2b4b84e461f 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
@@ -12,6 +12,7 @@
 using System.Threading.Tasks;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Shared.DiagnosticIds;
 using Microsoft.Shared.Diagnostics;
 
 #pragma warning disable CA2213 // Disposable fields should be disposed
@@ -254,6 +255,39 @@ public int MaximumConsecutiveErrorsPerRequest
     /// </remarks>
     public bool TerminateOnUnknownCalls { get; set; }
 
+    /// <summary>Gets or sets a delegate called after each iteration of the function invocation loop completes.</summary>
+    /// <remarks>
+    /// <para>
+    /// This delegate is invoked after each iteration completes - specifically, after function invocations
+    /// are processed and before the next request is made to the inner client. This timing ensures that:
+    /// </para>
+    /// <list type="bullet">
+    /// <item><description>Function calls are not lost mid-execution</description></item>
+    /// <item><description>Function results are available for inspection</description></item>
+    /// <item><description>Aggregated usage information is up-to-date</description></item>
+    /// </list>
+    /// <para>
+    /// Common use cases include:
+    /// </para>
+    /// <list type="bullet">
+    /// <item><description>Token usage monitoring and context compaction triggers</description></item>
+    /// <item><description>Cost limit enforcement</description></item>
+    /// <item><description>Content guardrails and moderation</description></item>
+    /// <item><description>Time limit enforcement</description></item>
+    /// <item><description>Custom iteration-level logging or metrics</description></item>
+    /// </list>
+    /// <para>
+    /// Set <see cref="FunctionInvocationIterationContext.Terminate"/> to <see langword="true"/>
+    /// to stop the loop. The function calls for the current iteration will have already been processed.
+    /// </para>
+    /// <para>
+    /// Changing the value of this property while the client is in use might result in inconsistencies
+    /// as to whether the callback is invoked for an in-flight request.
+    /// </para>
+    /// </remarks>
+    [Experimental(DiagnosticIds.Experiments.AIIterationCompleted, UrlFormat = DiagnosticIds.UrlFormat)]
+    public Func<FunctionInvocationIterationContext, CancellationToken, ValueTask>? IterationCompleted { get; set; }
+
     /// <summary>Gets or sets a delegate used to invoke <see cref="AIFunction"/> instances.</summary>
     /// <remarks>
     /// By default, the protected <see cref="InvokeFunctionAsync"/> method is called for each <see cref="AIFunction"/> to be invoked,
@@ -399,6 +433,27 @@ public override async Task<ChatResponse> GetResponseAsync(
                 break;
             }
 
+            // Call the iteration completed hook if configured
+            if (IterationCompleted is { } iterationCallback)
+            {
+                var iterationContext = new FunctionInvocationIterationContext
+                {
+                    Iteration = iteration,
+                    TotalUsage = CloneUsageDetails(totalUsage),
+                    Messages = responseMessages,
+                    Response = response,
+                    IsStreaming = false
+                };
+
+                await iterationCallback(iterationContext, cancellationToken);
+
+                if (iterationContext.Terminate)
+                {
+                    LogIterationCallbackRequestedTermination(iteration);
+                    break;
+                }
+            }
+
             UpdateOptionsForNextIteration(ref options, response.ConversationId);
         }
 
@@ -421,7 +476,10 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
         // Create an activity to group them together for better observability. If there's already a genai "invoke_agent"
         // span that's current, however, we just consider that the group and don't add a new one.
         using Activity? activity = CurrentActivityIsInvokeAgent ? null : _activitySource?.StartActivity(OpenTelemetryConsts.GenAI.OrchestrateToolsName);
-        UsageDetails? totalUsage = activity is { IsAllDataRequested: true } ? new() : null; // tracked usage across all turns, to be used for activity purposes
+
+        // Track usage if needed for telemetry or for the iteration callback
+        bool needsUsageTracking = activity is { IsAllDataRequested: true } || IterationCompleted is not null;
+        UsageDetails? totalUsage = needsUsageTracking ? new() : null;
 
         // Copy the original messages in order to avoid enumerating the original messages multiple times.
         // The IEnumerable can represent an arbitrary amount of work.
@@ -640,6 +698,27 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
                 break;
             }
 
+            // Call the iteration completed hook if configured
+            if (IterationCompleted is { } iterationCallback)
+            {
+                var iterationContext = new FunctionInvocationIterationContext
+                {
+                    Iteration = iteration,
+                    TotalUsage = CloneUsageDetails(totalUsage),
+                    Messages = responseMessages,
+                    Response = response,
+                    IsStreaming = true
+                };
+
+                await iterationCallback(iterationContext, cancellationToken);
+
+                if (iterationContext.Terminate)
+                {
+                    LogIterationCallbackRequestedTermination(iteration);
+                    break;
+                }
+            }
+
             UpdateOptionsForNextIteration(ref options, response.ConversationId);
         }
 
@@ -677,6 +756,31 @@ private static void AddUsageTags(Activity? activity, UsageDetails? usage)
         }
     }
 
+    /// <summary>Creates a defensive copy of usage details to prevent mutation after callback returns.</summary>
+    private static UsageDetails? CloneUsageDetails(UsageDetails? usage)
+    {
+        if (usage is null)
+        {
+            return null;
+        }
+
+        var clone = new UsageDetails
+        {
+            InputTokenCount = usage.InputTokenCount,
+            OutputTokenCount = usage.OutputTokenCount,
+            TotalTokenCount = usage.TotalTokenCount,
+            CachedInputTokenCount = usage.CachedInputTokenCount,
+            ReasoningTokenCount = usage.ReasoningTokenCount,
+        };
+
+        if (usage.AdditionalCounts is { } additionalCounts)
+        {
+            clone.AdditionalCounts = new(additionalCounts);
+        }
+
+        return clone;
+    }
+
     /// <summary>Prepares the various chat message lists after a response from the inner client and before invoking functions.</summary>
     /// <param name="originalMessages">The original messages provided by the caller.</param>
     /// <param name="messages">The messages reference passed to the inner client.</param>
@@ -1851,6 +1955,9 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
     [LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' requested termination of the processing loop.")]
     private partial void LogFunctionRequestedTermination(string functionName);
 
+    [LoggerMessage(LogLevel.Debug, "Iteration {Iteration}: IterationCompleted callback requested termination.")]
+    private partial void LogIterationCallbackRequestedTermination(int iteration);
+
     /// <summary>Provides information about the invocation of a function call.</summary>
     public sealed class FunctionInvocationResult
     {
diff --git a/src/Shared/DiagnosticIds/DiagnosticIds.cs b/src/Shared/DiagnosticIds/DiagnosticIds.cs
index 4cc736a7252..0a8b3e24a11 100644
--- a/src/Shared/DiagnosticIds/DiagnosticIds.cs
+++ b/src/Shared/DiagnosticIds/DiagnosticIds.cs
@@ -57,6 +57,7 @@ internal static class Experiments
         internal const string AIResponseContinuations = AIExperiments;
         internal const string AICodeInterpreter = AIExperiments;
         internal const string AIRealTime = AIExperiments;
+        internal const string AIIterationCompleted = AIExperiments;
 
         private const string AIExperiments = "MEAI001";
     }
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
index 6cd6e857996..539a999212c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
@@ -3382,4 +3382,379 @@ public async Task LogsFunctionRejected()
     // ContinuesWithFailingCallsUntilMaximumConsecutiveErrors test which triggers
     // the threshold condition. The logging call is at line 1078 and will execute
     // when MaximumConsecutiveErrorsPerRequest is exceeded.
+
+    [Fact]
+    public void IterationCompleted_NullByDefault()
+    {
+        using TestChatClient innerClient = new();
+        using FunctionInvokingChatClient client = new(innerClient);
+
+        Assert.Null(client.IterationCompleted);
+    }
+
+    [Fact]
+    public void IterationCompleted_Roundtrip()
+    {
+        using TestChatClient innerClient = new();
+        using FunctionInvokingChatClient client = new(innerClient);
+
+        Assert.Null(client.IterationCompleted);
+        Func<FunctionInvocationIterationContext, CancellationToken, ValueTask> callback = (ctx, ct) => default;
+        client.IterationCompleted = callback;
+        Assert.Same(callback, client.IterationCompleted);
+    }
+
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task IterationCompleted_InvokedAfterEachIteration(bool streaming)
+    {
+        var iterationContexts = new List<FunctionInvocationIterationContext>();
+
+        var options = new ChatOptions
+        {
+            Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
+        };
+
+        int callCount = 0;
+        using var innerClient = new TestChatClient
+        {
+            GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+            {
+                await Task.Yield();
+                callCount++;
+
+                if (callCount <= 2)
+                {
+                    // Return function calls for first two iterations
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant,
+                        [new FunctionCallContent($"callId{callCount}", "Func1")])])
+                    {
+                        Usage = new UsageDetails { InputTokenCount = 100 * callCount, OutputTokenCount = 50 * callCount, TotalTokenCount = 150 * callCount }
+                    };
+                }
+                else
+                {
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")])
+                    {
+                        Usage = new UsageDetails { InputTokenCount = 300, OutputTokenCount = 150, TotalTokenCount = 450 }
+                    };
+                }
+            },
+            GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+            {
+                callCount++;
+
+                ChatMessage message;
+                UsageDetails? usage;
+                if (callCount <= 2)
+                {
+                    message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{callCount}", "Func1")]);
+                    usage = new UsageDetails { InputTokenCount = 100 * callCount, OutputTokenCount = 50 * callCount, TotalTokenCount = 150 * callCount };
+                }
+                else
+                {
+                    message = new ChatMessage(ChatRole.Assistant, "Done");
+                    usage = new UsageDetails { InputTokenCount = 300, OutputTokenCount = 150, TotalTokenCount = 450 };
+                }
+
+                var updates = new ChatResponse(message) { Usage = usage }.ToChatResponseUpdates().ToList();
+                return YieldAsync(updates);
+            }
+        };
+
+        using var client = new FunctionInvokingChatClient(innerClient)
+        {
+            IterationCompleted = (ctx, ct) =>
+            {
+                iterationContexts.Add(new FunctionInvocationIterationContext
+                {
+                    Iteration = ctx.Iteration,
+                    TotalUsage = ctx.TotalUsage is not null
+                        ? new UsageDetails
+                        {
+                            InputTokenCount = ctx.TotalUsage.InputTokenCount,
+                            OutputTokenCount = ctx.TotalUsage.OutputTokenCount,
+                            TotalTokenCount = ctx.TotalUsage.TotalTokenCount
+                        }
+                        : null,
+                    Messages = ctx.Messages.ToList(),
+                    Response = ctx.Response,
+                    IsStreaming = ctx.IsStreaming
+                });
+                return default;
+            }
+        };
+
+        if (streaming)
+        {
+            await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+        }
+        else
+        {
+            await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+        }
+
+        // Should be called twice (once per iteration with function calls)
+        Assert.Equal(2, iterationContexts.Count);
+
+        // First iteration
+        Assert.Equal(0, iterationContexts[0].Iteration);
+        Assert.Equal(streaming, iterationContexts[0].IsStreaming);
+        Assert.NotNull(iterationContexts[0].Response);
+
+        // Second iteration
+        Assert.Equal(1, iterationContexts[1].Iteration);
+        Assert.Equal(streaming, iterationContexts[1].IsStreaming);
+
+        // Messages should accumulate
+        Assert.True(iterationContexts[1].Messages.Count > iterationContexts[0].Messages.Count);
+    }
+
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task IterationCompleted_CanTerminateLoop(bool streaming)
+    {
+        int functionInvocations = 0;
+
+        var options = new ChatOptions
+        {
+            Tools = [AIFunctionFactory.Create(() => { functionInvocations++; return "Result"; }, "Func1")]
+        };
+
+        int callCount = 0;
+        using var innerClient = new TestChatClient
+        {
+            GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+            {
+                await Task.Yield();
+                callCount++;
+
+                // Always return a function call - loop should be terminated by callback
+                return new ChatResponse([new ChatMessage(ChatRole.Assistant,
+                    [new FunctionCallContent($"callId{callCount}", "Func1")])]);
+            },
+            GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+            {
+                callCount++;
+                var message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent($"callId{callCount}", "Func1")]);
+                return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
+            }
+        };
+
+        using var client = new FunctionInvokingChatClient(innerClient)
+        {
+            IterationCompleted = (ctx, ct) =>
+            {
+                // Terminate after first iteration
+                if (ctx.Iteration == 0)
+                {
+                    ctx.Terminate = true;
+                }
+
+                return default;
+            }
+        };
+
+        if (streaming)
+        {
+            await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+        }
+        else
+        {
+            await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+        }
+
+        // Only one function invocation should occur (first iteration completes, then terminates)
+        Assert.Equal(1, functionInvocations);
+        Assert.Equal(1, callCount);
+    }
+
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task IterationCompleted_ReceivesCorrectUsageDetails(bool streaming)
+    {
+        UsageDetails? capturedUsage = null;
+
+        var options = new ChatOptions
+        {
+            Tools = [AIFunctionFactory.Create(() => "Result", "Func1")]
+        };
+
+        int callCount = 0;
+        using var innerClient = new TestChatClient
+        {
+            GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+            {
+                await Task.Yield();
+                callCount++;
+
+                if (callCount == 1)
+                {
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])])
+                    {
+                        Usage = new UsageDetails { InputTokenCount = 100, OutputTokenCount = 50, TotalTokenCount = 150 }
+                    };
+                }
+                else
+                {
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")])
+                    {
+                        Usage = new UsageDetails { InputTokenCount = 200, OutputTokenCount = 100, TotalTokenCount = 300 }
+                    };
+                }
+            },
+            GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+            {
+                callCount++;
+
+                ChatMessage message;
+                UsageDetails usage;
+                if (callCount == 1)
+                {
+                    message = new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]);
+                    usage = new UsageDetails { InputTokenCount = 100, OutputTokenCount = 50, TotalTokenCount = 150 };
+                }
+                else
+                {
+                    message = new ChatMessage(ChatRole.Assistant, "Done");
+                    usage = new UsageDetails { InputTokenCount = 200, OutputTokenCount = 100, TotalTokenCount = 300 };
+                }
+
+                var updates = new ChatResponse(message) { Usage = usage }.ToChatResponseUpdates().ToList();
+                return YieldAsync(updates);
+            }
+        };
+
+        using var client = new FunctionInvokingChatClient(innerClient)
+        {
+            IterationCompleted = (ctx, ct) =>
+            {
+                capturedUsage = ctx.TotalUsage;
+                return default;
+            }
+        };
+
+        if (streaming)
+        {
+            await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+        }
+        else
+        {
+            await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+        }
+
+        // Usage should reflect accumulated values from first iteration
+        Assert.NotNull(capturedUsage);
+        Assert.Equal(100, capturedUsage.InputTokenCount);
+        Assert.Equal(50, capturedUsage.OutputTokenCount);
+        Assert.Equal(150, capturedUsage.TotalTokenCount);
+    }
+
+    [Fact]
+    public async Task IterationCompleted_NotCalledWhenNoFunctionCalls()
+    {
+        int callbackInvocations = 0;
+
+        using var innerClient = new TestChatClient
+        {
+            GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+            {
+                await Task.Yield();
+                return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Just a regular response")]);
+            }
+        };
+
+        using var client = new FunctionInvokingChatClient(innerClient)
+        {
+            IterationCompleted = (ctx, ct) =>
+            {
+                callbackInvocations++;
+                return default;
+            }
+        };
+
+        var options = new ChatOptions
+        {
+            Tools = [AIFunctionFactory.Create(() => "Result", "Func1")]
+        };
+
+        await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+
+        // Callback should not be called since there were no function calls
+        Assert.Equal(0, callbackInvocations);
+    }
+
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public async Task IterationCompleted_ReceivesAllMessages(bool streaming)
+    {
+        IReadOnlyList<ChatMessage>? capturedMessages = null;
+
+        var options = new ChatOptions
+        {
+            Tools = [AIFunctionFactory.Create(() => "FunctionResult", "Func1")]
+        };
+
+        int callCount = 0;
+        using var innerClient = new TestChatClient
+        {
+            GetResponseAsyncCallback = async (contents, chatOptions, ct) =>
+            {
+                await Task.Yield();
+                callCount++;
+
+                if (callCount == 1)
+                {
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])]);
+                }
+                else
+                {
+                    return new ChatResponse([new ChatMessage(ChatRole.Assistant, "Done")]);
+                }
+            },
+            GetStreamingResponseAsyncCallback = (contents, chatOptions, ct) =>
+            {
+                callCount++;
+
+                ChatMessage message = callCount == 1
+                    ? new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")])
+                    : new ChatMessage(ChatRole.Assistant, "Done");
+
+                return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
+            }
+        };
+
+        using var client = new FunctionInvokingChatClient(innerClient)
+        {
+            IterationCompleted = (ctx, ct) =>
+            {
+                capturedMessages = ctx.Messages.ToList();
+                return default;
+            }
+        };
+
+        if (streaming)
+        {
+            await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "test")], options).ToChatResponseAsync();
+        }
+        else
+        {
+            await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+        }
+
+        Assert.NotNull(capturedMessages);
+
+        // Should contain: FunctionCallContent message, FunctionResultContent message
+        Assert.Equal(2, capturedMessages.Count);
+
+        // First message should have FunctionCallContent
+        Assert.Contains(capturedMessages[0].Contents, c => c is FunctionCallContent);
+
+        // Second message should have FunctionResultContent
+        Assert.Contains(capturedMessages[1].Contents, c => c is FunctionResultContent frc && frc.Result?.ToString() == "FunctionResult");
+    }
 }