diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index 4704e8b2023b..f042979ea290 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -206,6 +206,10 @@ + + + + diff --git a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs index fee03fc84fa8..d3463eec9d27 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs @@ -1159,6 +1159,10 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden public virtual Azure.Response CreateThreadAndRun(string assistantId, Azure.AI.Agents.Persistent.ThreadAndRunOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> CreateThreadAndRunAsync(string assistantId, Azure.AI.Agents.Persistent.ThreadAndRunOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } + public static partial class PersistentAgentsClientExtensions + { + public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; } + } public static partial class PersistentAgentsExtensions { public static Azure.AI.Agents.Persistent.PersistentAgentsClient GetPersistentAgentsClient(this System.ClientModel.Primitives.ClientConnectionProvider provider) { throw null; } diff --git a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs index d34c42b9eab7..f0887f272c48 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs @@ -1159,6 +1159,10 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden public virtual Azure.Response CreateThreadAndRun(string assistantId, Azure.AI.Agents.Persistent.ThreadAndRunOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> CreateThreadAndRunAsync(string assistantId, Azure.AI.Agents.Persistent.ThreadAndRunOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } + public static partial class PersistentAgentsClientExtensions + { + public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; } + } public static partial class PersistentAgentsExtensions { public static Azure.AI.Agents.Persistent.PersistentAgentsClient GetPersistentAgentsClient(this System.ClientModel.Primitives.ClientConnectionProvider provider) { throw null; } diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Azure.AI.Agents.Persistent.csproj b/sdk/ai/Azure.AI.Agents.Persistent/src/Azure.AI.Agents.Persistent.csproj index f4ca45b7112b..b59d74f0e552 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/src/Azure.AI.Agents.Persistent.csproj +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Azure.AI.Agents.Persistent.csproj @@ -16,6 +16,7 @@ + diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsAdministrationClient.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsAdministrationClient.cs index 3feb39566db3..9efb13472b23 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsAdministrationClient.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsAdministrationClient.cs @@ -20,8 +20,8 @@ namespace Azure.AI.Agents.Persistent public partial class PersistentAgentsAdministrationClient { private static readonly bool s_is_test_run = AppContextSwitchHelper.GetConfigValue( - PersistantAgensConstants.UseOldConnectionString, - PersistantAgensConstants.UseOldConnectionStringEnvVar); + PersistentAgentsConstants.UseOldConnectionString, + PersistentAgentsConstants.UseOldConnectionStringEnvVar); /// The ClientDiagnostics is used to provide tracing support for the client library. internal virtual ClientDiagnostics ClientDiagnostics { get; } // TODO: Replace project connections string by PROJECT_ENDPOINT when 1DP will be available. diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs new file mode 100644 index 000000000000..16a406842241 --- /dev/null +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Azure.AI.Agents.Persistent +{ + /// Represents an for an Azure.AI.Agents.Persistent . + internal partial class PersistentAgentsChatClient : IChatClient + { + /// The name of the chat client provider. + private const string ProviderName = "azure"; + + /// The underlying . + private readonly PersistentAgentsClient? _client; + + /// Metadata for the client. + private readonly ChatClientMetadata? _metadata; + + /// The ID of the agent to use. + private readonly string? _agentId; + + /// The thread ID to use if none is supplied in . + private readonly string? _defaultThreadId; + + /// List of tools associated with the agent. + private IReadOnlyList? _agentTools; + + /// Initializes a new instance of the class for the specified . + public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null) + { + Argument.AssertNotNull(client, nameof(client)); + Argument.AssertNotNullOrWhiteSpace(agentId, nameof(agentId)); + + _client = client; + _agentId = agentId; + _defaultThreadId = defaultThreadId; + + _metadata = new(ProviderName); + } + + protected PersistentAgentsChatClient() { } + + /// + public virtual object? GetService(Type serviceType, object? serviceKey = null) => + serviceType is null ? throw new ArgumentNullException(nameof(serviceType)) : + serviceKey is not null ? null : + serviceType == typeof(ChatClientMetadata) ? _metadata : + serviceType == typeof(PersistentAgentsClient) ? _client : + serviceType.IsInstanceOfType(this) ? this : + null; + + /// + public virtual Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + GetStreamingResponseAsync(messages, options, cancellationToken).ToChatResponseAsync(cancellationToken); + + /// + public virtual async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(messages, nameof(messages)); + + // Extract necessary state from messages and options. + (ThreadAndRunOptions runOptions, List? toolResults) = + await CreateRunOptionsAsync(messages, options, cancellationToken).ConfigureAwait(false); + + // Get the thread ID. + string? threadId = options?.ConversationId ?? _defaultThreadId; + if (threadId is null && toolResults is not null) + { + throw new ArgumentException("No thread ID was provided, but chat messages includes tool results.", nameof(messages)); + } + + // Get any active run ID for this thread. + ThreadRun? threadRun = null; + if (threadId is not null) + { + await foreach (ThreadRun? run in _client!.Runs.GetRunsAsync(threadId, limit: 1, ListSortOrder.Descending, cancellationToken: cancellationToken).ConfigureAwait(false)) + { + if (run.Status != RunStatus.Completed && run.Status != RunStatus.Cancelled && run.Status != RunStatus.Failed && run.Status != RunStatus.Expired) + { + threadRun = run; + break; + } + } + } + + // Submit the request. + IAsyncEnumerable updates; + if (threadRun is not null && + ConvertFunctionResultsToToolOutput(toolResults, out List? toolOutputs) is { } toolRunId && + toolRunId == threadRun.Id) + { + // There's an active run and we have tool results to submit, so submit the results and continue streaming. + // This is going to ignore any additional messages in the run options, as we are only submitting tool outputs, + // but there doesn't appear to be a way to submit additional messages, and having such additional messages is rare. + updates = _client!.Runs.SubmitToolOutputsToStreamAsync(threadRun, toolOutputs, cancellationToken); + } + else + { + if (threadId is null) + { + // No thread ID was provided, so create a new thread. + PersistentAgentThread thread = await _client!.Threads.CreateThreadAsync(runOptions.ThreadOptions.Messages, runOptions.ToolResources, runOptions.Metadata, cancellationToken).ConfigureAwait(false); + runOptions.ThreadOptions.Messages.Clear(); + threadId = thread.Id; + } + else if (threadRun is not null) + { + // There was an active run; we need to cancel it before starting a new run. + await _client!.Runs.CancelRunAsync(threadId, threadRun.Id, cancellationToken).ConfigureAwait(false); + threadRun = null; + } + + // Now create a new run and stream the results. + updates = _client!.Runs.CreateRunStreamingAsync( + threadId: threadId, + agentId: _agentId, + overrideModelName: runOptions.OverrideModelName, + overrideInstructions: runOptions.OverrideInstructions, + additionalInstructions: null, + additionalMessages: runOptions.ThreadOptions.Messages, + overrideTools: runOptions.OverrideTools, + temperature: runOptions.Temperature, + topP: runOptions.TopP, + maxPromptTokens: runOptions.MaxPromptTokens, + maxCompletionTokens: runOptions.MaxCompletionTokens, + truncationStrategy: runOptions.TruncationStrategy, + toolChoice: runOptions.ToolChoice, + responseFormat: runOptions.ResponseFormat, + parallelToolCalls: runOptions.ParallelToolCalls, + metadata: runOptions.Metadata, + cancellationToken); + } + + // Process each update. + string? responseId = null; + await foreach (StreamingUpdate? update in updates.ConfigureAwait(false)) + { + switch (update) + { + case ThreadUpdate tu: + threadId ??= tu.Value.Id; + goto default; + + case RunUpdate ru: + threadId ??= ru.Value.ThreadId; + responseId ??= ru.Value.Id; + + ChatResponseUpdate ruUpdate = new() + { + AuthorName = ru.Value.AssistantId, + ConversationId = threadId, + CreatedAt = ru.Value.CreatedAt, + MessageId = responseId, + ModelId = ru.Value.Model, + RawRepresentation = ru, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + + if (ru.Value.Usage is { } usage) + { + ruUpdate.Contents.Add(new UsageContent(new() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens, + })); + } + + if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) + { + ruUpdate.Contents.Add( + new FunctionCallContent( + JsonSerializer.Serialize([ru.Value.Id, toolCallId], AgentsChatClientJsonContext.Default.StringArray), + functionName, + JsonSerializer.Deserialize(rau.FunctionArguments, AgentsChatClientJsonContext.Default.IDictionaryStringObject)!)); + } + + yield return ruUpdate; + break; + + case MessageContentUpdate mcu: + yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) + { + ConversationId = threadId, + MessageId = responseId, + RawRepresentation = mcu, + ResponseId = responseId, + }; + break; + + default: + yield return new ChatResponseUpdate + { + ConversationId = threadId, + MessageId = responseId, + RawRepresentation = update, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + break; + } + } + } + + /// + public void Dispose() { } + + /// + /// Creates the to use for the request and extracts any function result contents + /// that need to be submitted as tool results. + /// + private async ValueTask<(ThreadAndRunOptions RunOptions, List? ToolResults)> CreateRunOptionsAsync( + IEnumerable messages, ChatOptions? options, CancellationToken cancellationToken) + { + // Create the options instance to populate, either a fresh or using one the caller provides. + ThreadAndRunOptions runOptions = + options?.RawRepresentationFactory?.Invoke(this) as ThreadAndRunOptions ?? + new(); + + // Populate the run options from the ChatOptions, if provided. + if (options is not null) + { + runOptions.MaxCompletionTokens ??= options.MaxOutputTokens; + runOptions.OverrideModelName ??= options.ModelId; + runOptions.TopP ??= options.TopP; + runOptions.Temperature ??= options.Temperature; + runOptions.ParallelToolCalls ??= options.AllowMultipleToolCalls; + // Ignored: options.TopK, options.FrequencyPenalty, options.Seed, options.StopSequences + + if (options.Tools is { Count: > 0 } tools) + { + List toolDefinitions = []; + + // If the caller has provided any tool overrides, we'll assume they don't want to use the agent's tools. + // But if they haven't, the only way we can provide our tools is via an override, whereas we'd really like to + // just add them. To handle that, we'll get all of the agent's tools and add them to the override list + // along with our tools. + if (runOptions.OverrideTools is null || !runOptions.OverrideTools.Any()) + { + if (_agentTools is null) + { + PersistentAgent agent = await _client!.Administration.GetAgentAsync(_agentId, cancellationToken).ConfigureAwait(false); + _agentTools = agent.Tools; + } + + toolDefinitions.AddRange(_agentTools); + } + + // The caller can provide tools in the supplied ThreadAndRunOptions. + if (runOptions.OverrideTools is not null) + { + toolDefinitions.AddRange(runOptions.OverrideTools); + } + + // Now add the tools from ChatOptions.Tools. + foreach (AITool tool in tools) + { + switch (tool) + { + case AIFunction aiFunction: + toolDefinitions.Add(new FunctionToolDefinition( + aiFunction.Name, + aiFunction.Description, + BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(aiFunction.JsonSchema, AgentsChatClientJsonContext.Default.JsonElement)))); + break; + + case HostedCodeInterpreterTool: + toolDefinitions.Add(new CodeInterpreterToolDefinition()); + break; + + case HostedWebSearchTool webSearch when webSearch.AdditionalProperties?.TryGetValue("connectionId", out object? connectionId) is true: + toolDefinitions.Add(new BingGroundingToolDefinition(new BingGroundingSearchToolParameters([new BingGroundingSearchConfiguration(connectionId!.ToString())]))); + break; + } + } + + if (toolDefinitions.Count > 0) + { + runOptions.OverrideTools = toolDefinitions; + } + } + + // Store the tool mode, if relevant. + if (runOptions.ToolChoice is null) + { + switch (options.ToolMode) + { + case NoneChatToolMode: + runOptions.ToolChoice = BinaryData.FromString("none"); + break; + + case RequiredChatToolMode required: + runOptions.ToolChoice = required.RequiredFunctionName is string functionName ? + BinaryData.FromString($$"""{"type": "function", "function": {"name": "{{functionName}}"} }""") : + BinaryData.FromString("required"); + break; + } + } + + // Store the response format, if relevant. + if (runOptions.ResponseFormat is null) + { + if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + runOptions.ResponseFormat = jsonFormat.Schema is { } schema ? + BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(new() + { + ["type"] = "json_schema", + ["json_schema"] = JsonSerializer.SerializeToNode(schema, AgentsChatClientJsonContext.Default.JsonNode), + }, AgentsChatClientJsonContext.Default.JsonObject)) : + BinaryData.FromString("""{ "type": "json_object" }"""); + } + } + } + + // Process ChatMessages. System messages are turned into additional instructions. + // All other messages are added 1:1, treating assistant messages as agent messages + // and everything else as user messages. + StringBuilder? instructions = null; + List? functionResults = null; + + runOptions.ThreadOptions ??= new(); + + foreach (ChatMessage chatMessage in messages) + { + List messageContents = []; + + if (chatMessage.Role == ChatRole.System || + chatMessage.Role == new ChatRole("developer")) + { + instructions ??= new(); + foreach (TextContent textContent in chatMessage.Contents.OfType()) + { + _ = instructions.Append(textContent); + } + + continue; + } + + foreach (AIContent content in chatMessage.Contents) + { + switch (content) + { + case TextContent text: + messageContents.Add(new MessageInputTextBlock(text.Text)); + break; + + case DataContent image when image.HasTopLevelMediaType("image"): + messageContents.Add(new MessageInputImageUriBlock(new MessageImageUriParam(image.Uri))); + break; + + case UriContent image when image.HasTopLevelMediaType("image"): + messageContents.Add(new MessageInputImageUriBlock(new MessageImageUriParam(image.Uri.AbsoluteUri))); + break; + + case FunctionResultContent result: + (functionResults ??= []).Add(result); + break; + + default: + if (content.RawRepresentation is MessageInputContentBlock rawContent) + { + messageContents.Add(rawContent); + } + break; + } + } + + if (messageContents.Count > 0) + { + runOptions.ThreadOptions.Messages.Add(new ThreadMessageOptions( + chatMessage.Role == ChatRole.Assistant ? MessageRole.Agent : MessageRole.User, + messageContents)); + } + } + + if (instructions is not null) + { + runOptions.OverrideInstructions = instructions.ToString(); + } + + return (runOptions, functionResults); + } + + /// Convert instances to instances. + /// The tool results to process. + /// The generated list of tool outputs, if any could be created. + /// The run ID associated with the corresponding function call requests. + private static string? ConvertFunctionResultsToToolOutput(List? toolResults, out List? toolOutputs) + { + string? runId = null; + toolOutputs = null; + if (toolResults?.Count > 0) + { + foreach (FunctionResultContent frc in toolResults) + { + // When creating the FunctionCallContext, we created it with a CallId == [runId, callId]. + // We need to extract the run ID and ensure that the ToolOutput we send back to Azure + // is only the call ID. + string[]? runAndCallIDs; + try + { + runAndCallIDs = JsonSerializer.Deserialize(frc.CallId, AgentsChatClientJsonContext.Default.StringArray); + } + catch + { + continue; + } + + if (runAndCallIDs is null || + runAndCallIDs.Length != 2 || + string.IsNullOrWhiteSpace(runAndCallIDs[0]) || // run ID + string.IsNullOrWhiteSpace(runAndCallIDs[1]) || // call ID + (runId is not null && runId != runAndCallIDs[0])) + { + continue; + } + + runId = runAndCallIDs[0]; + (toolOutputs ??= []).Add(new(runAndCallIDs[1], frc.Result?.ToString() ?? string.Empty)); + } + } + + return runId; + } + + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(JsonNode))] + [JsonSerializable(typeof(JsonObject))] + [JsonSerializable(typeof(string[]))] + [JsonSerializable(typeof(IDictionary))] + private sealed partial class AgentsChatClientJsonContext : JsonSerializerContext; + } +} diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs new file mode 100644 index 000000000000..f4320e3b9f47 --- /dev/null +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using Microsoft.Extensions.AI; + +namespace Azure.AI.Agents.Persistent +{ + /// + /// Provides extension methods for . + /// + public static class PersistentAgentsClientExtensions + { + /// + /// Creates an for a client for interacting with a specific agent. + /// + /// The instance to be accessed as an . + /// The unique identifier of the agent with which to interact. + /// + /// An optional existing thread identifier for the chat session. This serves as a default, and may be overridden per call to + /// or via the + /// property. If no thread ID is provided via either mechanism, a new thread will be created for the request. + /// + /// An instance configured to interact with the specified agent and thread. + public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null) => + new PersistentAgentsChatClient(client, agentId, defaultThreadId); + } +} diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistantAgensConstants.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsConstants.cs similarity index 88% rename from sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistantAgensConstants.cs rename to sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsConstants.cs index d0bb692acb80..f93503fc41de 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistantAgensConstants.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsConstants.cs @@ -3,7 +3,7 @@ namespace Azure.AI.Agents.Persistent { - internal class PersistantAgensConstants + internal class PersistentAgentsConstants { public const string UseOldConnectionString = "Azure.AI.Agents.Persistent.Internal.UseConnectionString"; public const string UseOldConnectionStringEnvVar = "_IS_TEST_RUN"; diff --git a/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs new file mode 100644 index 000000000000..68b0343b5927 --- /dev/null +++ b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Azure.Identity; +using Microsoft.Extensions.AI; +using NUnit.Framework; + +namespace Azure.AI.Agents.Persistent.Tests +{ + public class PersistentAgentsChatClientTests : RecordedTestBase + { + private const string AGENT_NAME = "cs_e2e_tests_chat_client"; + private const string STREAMING_CONSTRAINT = "The test framework does not support iteration of stream in Sync mode."; + + private string _agentId; + private string _threadId; + + public PersistentAgentsChatClientTests(bool isAsync) : base(isAsync) + { + TestDiagnostics = false; + } + + #region Enumerations + public enum ChatOptionsTestType + { + Default, + WithTools, + WithResponseFormat + } + #endregion + + [SetUp] + public async Task Setup() + { + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + PersistentAgent agent = await client.Administration.CreateAgentAsync( + model: "gpt-4.1", + name: AGENT_NAME, + instructions: "You are a helpful chat agent." + ); + + _agentId = agent.Id; + + PersistentAgentThread thread = await client.Threads.CreateThreadAsync(); + + _threadId = thread.Id; + } + + [RecordedTest] + public async Task TestGetResponseAsync() + { + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + PersistentAgentsChatClient chatClient = new(client, _agentId, _threadId); + + List messages = []; + messages.Add(new ChatMessage(ChatRole.User, [new TextContent("Hello, tell me a joke")])); + + ChatResponse response = await chatClient.GetResponseAsync(messages); + + Assert.IsNotNull(response); + Assert.IsNotNull(response.Messages); + Assert.GreaterOrEqual(response.Messages.Count, 1); + Assert.AreEqual(ChatRole.Assistant, response.Messages[0].Role); + Assert.IsNotNull(response.ConversationId); + } + + [RecordedTest] + [TestCase(ChatOptionsTestType.Default)] + [TestCase(ChatOptionsTestType.WithTools)] + [TestCase(ChatOptionsTestType.WithResponseFormat)] + public async Task TestGetStreamingResponseAsync(ChatOptionsTestType optionsType) + { + if (!IsAsync) + { + Assert.Inconclusive(STREAMING_CONSTRAINT); + } + + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + PersistentAgentsChatClient chatClient = new(client, _agentId, _threadId); + + ChatOptions options = null; + if (optionsType == ChatOptionsTestType.WithTools) + { + options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "It's 80 degrees and sunny.", "GetWeather")], + ToolMode = ChatToolMode.Auto + }; + } + else if (optionsType == ChatOptionsTestType.WithResponseFormat) + { + options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json + }; + } + + List messages = [new ChatMessage(ChatRole.User, [new TextContent("What's the weather like? Respond in JSON.")])]; + bool receivedUpdate = false; + + await foreach (ChatResponseUpdate update in chatClient.GetStreamingResponseAsync(messages, options)) + { + Assert.IsNotNull(update); + Assert.IsNotNull(update.ConversationId); + if (update.Contents.Any(c => (optionsType == ChatOptionsTestType.WithTools && c is FunctionCallContent) || c is TextContent)) + { + receivedUpdate = true; + } + } + + Assert.IsTrue(receivedUpdate, "No valid streaming update received."); + } + + [RecordedTest] + public async Task TestSubmitToolOutputs() + { + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + FunctionToolDefinition tool = new( + name: "GetFavouriteWord", + description: "Gets the favourite word of a person.", + parameters: BinaryData.FromObjectAsJson(new + { + Type = "object", + Properties = new { Name = new { Type = "string", Description = "Person's name" } }, + Required = new[] { "name" } + }, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }) + ); + + PersistentAgent agent = await client.Administration.CreateAgentAsync( + model: "gpt-4.1", + name: AGENT_NAME, + instructions: "Use the provided function to answer questions.", + tools: [tool] + ); + + PersistentAgentThread thread = await client.Threads.CreateThreadAsync(); + + PersistentAgentsChatClient chatClient = new(client, agent.Id, thread.Id); + + await client.Messages.CreateMessageAsync(thread.Id, MessageRole.User, "What's Mike's favourite word?"); + + ThreadRun run = await client.Runs.CreateRunAsync(thread.Id, agent.Id); + do + { + await Task.Delay(500); + run = await client.Runs.GetRunAsync(thread.Id, run.Id); + } while (run.Status == RunStatus.Queued || run.Status == RunStatus.InProgress); + + if (run.Status == RunStatus.RequiresAction && run.RequiredAction is SubmitToolOutputsAction action) + { + List messages = []; + foreach (RequiredToolCall toolCall in action.ToolCalls) + { + if (toolCall is RequiredFunctionToolCall functionCall) + { + string[] callIds = [run.Id, functionCall.Id]; + messages.Add(new ChatMessage(ChatRole.Tool, [new FunctionResultContent(JsonSerializer.Serialize(callIds), "bar")])); + } + } + + ChatResponse response = await chatClient.GetResponseAsync(messages, new ChatOptions { ConversationId = thread.Id }); + Assert.IsNotNull(response); + Assert.GreaterOrEqual(response.Messages.Count, 1); + Assert.IsTrue(response.Messages[0].Contents.Any(c => c is TextContent tc && tc.Text.Contains("bar"))); + } + else + { + Assert.Fail("Run did not require tool action."); + } + } + + [RecordedTest] + public async Task TestChatOptionsTools() + { + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + + FunctionToolDefinition wordTool = new( + name: "GetFavouriteWord", + description: "Gets the favourite word of a person.", + parameters: BinaryData.FromObjectAsJson(new + { + Type = "object", + Properties = new { Name = new { Type = "string", Description = "Person's name" } }, + Required = new[] { "name" } + }, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }) + ); + + // First tool is registered on agent level. + PersistentAgent agent = await client.Administration.CreateAgentAsync( + model: "gpt-4.1", + name: AGENT_NAME, + instructions: "Use the provided function to answer questions.", + tools: [wordTool] + ); + + // Second tool is registered per request. + ChatOptions chatOptions = new() + { + Tools = [AIFunctionFactory.Create(() => "It's 80 degrees and sunny.", "GetWeather")], + ToolMode = ChatToolMode.Auto + }; + + PersistentAgentThread thread = await client.Threads.CreateThreadAsync(); + + PersistentAgentsChatClient chatClient = new(client, agent.Id, thread.Id); + + List messages = []; + messages.Add(new ChatMessage(ChatRole.User, [new TextContent("What's Mike's favourite word and current weather in Seattle?")])); + + ChatResponse response = await chatClient.GetResponseAsync(messages, chatOptions); + + Assert.IsNotNull(response); + Assert.IsNotNull(response.Messages); + Assert.GreaterOrEqual(response.Messages.Count, 1); + Assert.AreEqual(ChatRole.Assistant, response.Messages[0].Role); + + List functionNames = [.. response.Messages[0].Contents + .OfType() + .Select(c => c.Name)]; + + Assert.Contains("GetFavouriteWord", functionNames); + Assert.Contains("GetWeather", functionNames); + } + + [RecordedTest] + public void TestGetService() + { + using IDisposable _ = SetTestSwitch(); + PersistentAgentsClient client = GetClient(); + PersistentAgentsChatClient chatClient = new(client, _agentId, _threadId); + + Assert.IsNotNull(chatClient.GetService(typeof(ChatClientMetadata))); + Assert.IsNotNull(chatClient.GetService(typeof(PersistentAgentsClient))); + Assert.IsNotNull(chatClient.GetService(typeof(PersistentAgentsChatClient))); + Assert.IsNull(chatClient.GetService(typeof(string))); + Assert.Throws(() => chatClient.GetService(null)); + } + + #region Helpers + private class CompositeDisposable : IDisposable + { + private readonly List _disposables = []; + + public CompositeDisposable(params IDisposable[] disposables) + { + for (int i = 0; i < disposables.Length; i++) + { + _disposables.Add(disposables[i]); + } + } + + public void Dispose() + { + foreach (IDisposable d in _disposables) + { + d?.Dispose(); + } + } + } + + private static CompositeDisposable SetTestSwitch() + { + return new CompositeDisposable( + new TestAppContextSwitch(new() + { + { PersistentAgentsConstants.UseOldConnectionString, true.ToString() } + })); + } + + private PersistentAgentsClient GetClient() + { + var connectionString = TestEnvironment.PROJECT_CONNECTION_STRING; + PersistentAgentsAdministrationClientOptions opts = InstrumentClientOptions(new PersistentAgentsAdministrationClientOptions()); + PersistentAgentsAdministrationClient admClient; + + if (Mode == RecordedTestMode.Playback) + { + admClient = InstrumentClient(new PersistentAgentsAdministrationClient(connectionString, new MockCredential(), opts)); + return new PersistentAgentsClient(admClient); + } + + var cli = Environment.GetEnvironmentVariable("USE_CLI_CREDENTIAL"); + if (!string.IsNullOrEmpty(cli) && string.Compare(cli, "true", StringComparison.OrdinalIgnoreCase) == 0) + { + admClient = InstrumentClient(new PersistentAgentsAdministrationClient(connectionString, new AzureCliCredential(), opts)); + } + else + { + admClient = InstrumentClient(new PersistentAgentsAdministrationClient(connectionString, new DefaultAzureCredential(), opts)); + } + + return new PersistentAgentsClient(admClient); + } + + #endregion + + #region Cleanup + [TearDown] + public void Cleanup() + { + DirectoryInfo tempDir = new(Path.Combine(Path.GetTempPath(), "cs_e2e_temp_dir")); + if (tempDir.Exists) + { + tempDir.Delete(true); + } + + if (Mode == RecordedTestMode.Playback) + return; + + PersistentAgentsClient client; + var cli = Environment.GetEnvironmentVariable("USE_CLI_CREDENTIAL"); + if (!string.IsNullOrEmpty(cli) && string.Compare(cli, "true", StringComparison.OrdinalIgnoreCase) == 0) + { + client = new PersistentAgentsClient(TestEnvironment.PROJECT_ENDPOINT, new AzureCliCredential()); + } + else + { + client = new PersistentAgentsClient(TestEnvironment.PROJECT_ENDPOINT, new DefaultAzureCredential()); + } + + // Remove agent + Pageable agents = client.Administration.GetAgents(); + foreach (PersistentAgent agent in agents) + { + if (agent.Name.StartsWith(AGENT_NAME)) + client.Administration.DeleteAgent(agent.Id); + } + + // Remove thread + Pageable threads = client.Threads.GetThreads(); + foreach (PersistentAgentThread thread in threads) + { + if (thread.Id == _threadId) + client.Threads.DeleteThread(thread.Id); + } + } + #endregion + } +} diff --git a/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsTests.cs b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsTests.cs index cead4a3bb908..0a5152c3080c 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsTests.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsTests.cs @@ -56,7 +56,7 @@ public IDisposable SetTestSwitch() { return new CompositeDisposable( new TestAppContextSwitch(new() { - { PersistantAgensConstants.UseOldConnectionString, true.ToString() }, + { PersistentAgentsConstants.UseOldConnectionString, true.ToString() }, })); } diff --git a/sdk/ai/Azure.AI.Agents.Persistent/tests/Samples/Sample_PersistentAgents_As_IChatClient.cs b/sdk/ai/Azure.AI.Agents.Persistent/tests/Samples/Sample_PersistentAgents_As_IChatClient.cs new file mode 100644 index 000000000000..09f2bce0ccc1 --- /dev/null +++ b/sdk/ai/Azure.AI.Agents.Persistent/tests/Samples/Sample_PersistentAgents_As_IChatClient.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable disable + +using System; +using System.Linq; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using Microsoft.Extensions.AI; +using NUnit.Framework; + +namespace Azure.AI.Agents.Persistent.Tests; + +public partial class Sample_PersistentAgents_As_IChatClient : SamplesBase +{ + [Test] + [AsyncOnly] + public async Task PersistentAgentsAsIChatClient() + { + #region Snippet:PersistentAgentsAsIChatClient_CreateClient +#if SNIPPET + var projectEndpoint = System.Environment.GetEnvironmentVariable("PROJECT_ENDPOINT"); + var modelDeploymentName = System.Environment.GetEnvironmentVariable("MODEL_DEPLOYMENT_NAME"); +#else + var projectEndpoint = TestEnvironment.PROJECT_ENDPOINT; + var modelDeploymentName = TestEnvironment.MODELDEPLOYMENTNAME; +#endif + PersistentAgentsClient client = new(projectEndpoint, new DefaultAzureCredential()); + #endregion + #region Snippet:PersistentAgentsAsIChatClient_CreateAgentAsIChatClient + PersistentAgent agent = await client.Administration.CreateAgentAsync( + model: modelDeploymentName, + name: "my-agent", + instructions: "You are a helpful agent."); + + PersistentAgentThread thread = await client.Threads.CreateThreadAsync(); + + IChatClient chatClient = client.AsIChatClient(agent.Id, thread.Id); + #endregion + #region Snippet:PersistentAgentsAsIChatClient_GetResponseAsync + ChatResponse response = await chatClient.GetResponseAsync([new ChatMessage(ChatRole.User, [new TextContent("Hello, tell me a joke")])]); + + Console.WriteLine(string.Join(Environment.NewLine, response.Messages.Select(c => c.Text))); + #endregion + #region Snippet:PersistentAgentsAsIChatClient_Cleanup + await client.Threads.DeleteThreadAsync(thread.Id); + await client.Administration.DeleteAgentAsync(agent.Id); + #endregion + } +}