Skip to content

.Net: Update Bedrock Agent to support conversation state and improve tests. #11737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions dotnet/src/Agents/AzureAI/AzureAIAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ public async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> InvokeAsync
azureAIAgentThread.StateParts.RegisterPlugins(kernel);
#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext);
var extensionsContextOptions = options is null ?
new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } :
new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext };
new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } :
new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions };

var invokeResults = ActivityExtensions.RunWithActivityAsync(
() => ModelDiagnostics.StartAgentInvocationActivity(this.Id, this.GetDisplayName(), this.Description),
Expand Down Expand Up @@ -323,9 +324,10 @@ public async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> In
azureAIAgentThread.StateParts.RegisterPlugins(kernel);
#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext);
var extensionsContextOptions = options is null ?
new AzureAIAgentInvokeOptions() { AdditionalInstructions = extensionsContext } :
new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = extensionsContext };
new AzureAIAgentInvokeOptions() { AdditionalInstructions = mergedAdditionalInstructions } :
new AzureAIAgentInvokeOptions(options) { AdditionalInstructions = mergedAdditionalInstructions };

#pragma warning disable CS0618 // Type or member is obsolete
// Invoke the Agent with the thread that we already added our message to.
Expand Down Expand Up @@ -461,4 +463,19 @@ protected override async Task<AgentChannel> RestoreChannelAsync(string channelSt

return new AzureAIChannel(this.Client, thread.Id);
}

private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) =>
(optionsAdditionalInstructions, extensionsContext) switch
{
(string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat(
ai,
Environment.NewLine,
Environment.NewLine,
ec),
(string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec,
(string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai,
(null, string ec) => ec,
(string ai, null) => ai,
_ => string.Empty
};
}
1 change: 1 addition & 0 deletions dotnet/src/Agents/Bedrock/Agents.Bedrock.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

<ItemGroup>
<ProjectReference Include="..\Abstractions\Agents.Abstractions.csproj" />
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
</ItemGroup>

<ItemGroup>
Expand Down
53 changes: 49 additions & 4 deletions dotnet/src/Agents/Bedrock/BedrockAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Amazon.BedrockAgent;
Expand Down Expand Up @@ -117,11 +118,18 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
() => new BedrockAgentThread(this.RuntimeClient),
cancellationToken).ConfigureAwait(false);

// Get the conversation state extensions context contributions and register plugins from the extensions.
#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false);
#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

// Ensure that the last message provided is a user message
string? message = this.ExtractUserMessage(messages.Last());

// Build session state with conversation history if needed
// Build session state with conversation history and override instructions if needed
SessionState sessionState = this.ExtractSessionState(messages);
var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext);
sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions };

// Configure the agent request with the provided options
var invokeAgentRequest = this.ConfigureAgentRequest(options, () =>
Expand Down Expand Up @@ -346,11 +354,18 @@ public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageCon
() => new BedrockAgentThread(this.RuntimeClient),
cancellationToken).ConfigureAwait(false);

// Get the conversation state extensions context contributions and register plugins from the extensions.
#pragma warning disable SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
var extensionsContext = await bedrockThread.StateParts.OnModelInvokeAsync(messages, cancellationToken).ConfigureAwait(false);
#pragma warning restore SKEXP0110 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.

// Ensure that the last message provided is a user message
string? message = this.ExtractUserMessage(messages.Last());

// Build session state with conversation history if needed
// Build session state with conversation history and override instructions if needed
SessionState sessionState = this.ExtractSessionState(messages);
var mergedAdditionalInstructions = MergeAdditionalInstructions(options?.AdditionalInstructions, extensionsContext);
sessionState.PromptSessionAttributes = new() { ["AdditionalInstructions"] = mergedAdditionalInstructions };

// Configure the agent request with the provided options
var invokeAgentRequest = this.ConfigureAgentRequest(options, () =>
Expand Down Expand Up @@ -639,20 +654,35 @@ private IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingInternalAsy

async IAsyncEnumerable<StreamingChatMessageContent> InvokeInternal()
{
var combinedResponseMessageBuilder = new StringBuilder();
StreamingChatMessageContent? lastMessage = null;

// The Bedrock agent service has the same API for both streaming and non-streaming responses.
// We are invoking the same method as the non-streaming response with the streaming configuration set,
// and converting the chat message content to streaming chat message content.
await foreach (var chatMessageContent in this.InternalInvokeAsync(invokeAgentRequest, arguments, cancellationToken).ConfigureAwait(false))
{
await this.NotifyThreadOfNewMessage(thread, chatMessageContent, cancellationToken).ConfigureAwait(false);
yield return new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content)
lastMessage = new StreamingChatMessageContent(chatMessageContent.Role, chatMessageContent.Content)
{
AuthorName = chatMessageContent.AuthorName,
ModelId = chatMessageContent.ModelId,
InnerContent = chatMessageContent.InnerContent,
Metadata = chatMessageContent.Metadata,
};
yield return lastMessage;

combinedResponseMessageBuilder.Append(chatMessageContent.Content);
}

// Build a combined message containing the text from all response parts
// to send to the thread.
var combinedMessage = new ChatMessageContent(AuthorRole.Assistant, combinedResponseMessageBuilder.ToString())
{
AuthorName = lastMessage?.AuthorName,
ModelId = lastMessage?.ModelId,
Metadata = lastMessage?.Metadata,
};
await this.NotifyThreadOfNewMessage(thread, combinedMessage, cancellationToken).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -726,4 +756,19 @@ private Amazon.BedrockAgentRuntime.ConversationRole MapBedrockAgentUser(AuthorRo
}

#endregion

private static string MergeAdditionalInstructions(string? optionsAdditionalInstructions, string extensionsContext) =>
(optionsAdditionalInstructions, extensionsContext) switch
{
(string ai, string ec) when !string.IsNullOrWhiteSpace(ai) && !string.IsNullOrWhiteSpace(ec) => string.Concat(
ai,
Environment.NewLine,
Environment.NewLine,
ec),
(string ai, string ec) when string.IsNullOrWhiteSpace(ai) => ec,
(string ai, string ec) when string.IsNullOrWhiteSpace(ec) => ai,
(null, string ec) => ec,
(string ai, null) => ai,
_ => string.Empty
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,40 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Agen

public class BedrockAgentThreadTests() : AgentThreadTests(() => new BedrockAgentFixture())
{
[Fact(Skip = "Manual verification only")]
private const string ManualVerificationSkipReason = "This test is for manual verification.";

[Fact(Skip = ManualVerificationSkipReason)]
public override Task OnNewMessageWithServiceFailureThrowsAgentOperationExceptionAsync()
{
// The Bedrock agent does not support writing to a thread with OnNewMessage.
return Task.CompletedTask;
}

[Fact(Skip = "Manual verification only")]
[Fact(Skip = ManualVerificationSkipReason)]
public override Task DeletingThreadTwiceDoesNotThrowAsync()
{
return base.DeletingThreadTwiceDoesNotThrowAsync();
}

[Fact(Skip = "Manual verification only")]
[Fact(Skip = ManualVerificationSkipReason)]
public override Task UsingThreadAfterDeleteThrowsAsync()
{
return base.UsingThreadAfterDeleteThrowsAsync();
}

[Fact(Skip = "Manual verification only")]
[Fact(Skip = ManualVerificationSkipReason)]
public override Task DeleteThreadBeforeCreateThrowsAsync()
{
return base.DeleteThreadBeforeCreateThrowsAsync();
}

[Fact(Skip = "Manual verification only")]
[Fact(Skip = ManualVerificationSkipReason)]
public override Task UsingThreadBeforeCreateCreatesAsync()
{
return base.UsingThreadBeforeCreateCreatesAsync();
}

[Fact(Skip = "Manual verification only")]
[Fact(Skip = ManualVerificationSkipReason)]
public override Task DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync()
{
return base.DeleteThreadWithServiceFailureThrowsAgentOperationExceptionAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,39 @@ public virtual async Task StatePartReceivesMessagesFromAgentAsync()
}
}

[Fact]
public virtual async Task StatePartReceivesMessagesFromAgentWhenStreamingAsync()
{
// Arrange
var mockStatePart = new Mock<ConversationStatePart>() { CallBase = true };
mockStatePart.Setup(x => x.OnNewMessageAsync(It.IsAny<string>(), It.IsAny<ChatMessage>(), It.IsAny<CancellationToken>()));

var agent = this.Fixture.Agent;

var agentThread = this.Fixture.GetNewThread();

try
{
agentThread.StateParts.Add(mockStatePart.Object);

// Act
var inputMessage = "What is the capital of France?";
var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread);
var results = await asyncResults1.ToListAsync();

// Assert
var responseMessage = string.Concat(results.Select(x => x.Message.Content));
Assert.Contains("Paris", responseMessage);
mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny<string>(), It.Is<ChatMessage>(cm => cm.Text == inputMessage), It.IsAny<CancellationToken>()), Times.Once);
mockStatePart.Verify(x => x.OnNewMessageAsync(It.IsAny<string>(), It.Is<ChatMessage>(cm => cm.Text == responseMessage), It.IsAny<CancellationToken>()), Times.Once);
}
finally
{
// Cleanup
await this.Fixture.DeleteThread(agentThread);
}
}

[Fact]
public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync()
{
Expand Down Expand Up @@ -83,6 +116,37 @@ public virtual async Task StatePartPreInvokeStateIsUsedByAgentAsync()
}
}

[Fact]
public virtual async Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync()
{
// Arrange
var mockStatePart = new Mock<ConversationStatePart>() { CallBase = true };
mockStatePart.Setup(x => x.OnModelInvokeAsync(It.IsAny<ICollection<ChatMessage>>(), It.IsAny<CancellationToken>())).ReturnsAsync("User name is Caoimhe");

var agent = this.Fixture.Agent;

var agentThread = this.Fixture.GetNewThread();

try
{
agentThread.StateParts.Add(mockStatePart.Object);

// Act
var inputMessage = "What is my name?.";
var asyncResults1 = agent.InvokeStreamingAsync(inputMessage, agentThread);
var results = await asyncResults1.ToListAsync();

// Assert
var responseMessage = string.Concat(results.Select(x => x.Message.Content));
Assert.Contains("Caoimhe", responseMessage);
}
finally
{
// Cleanup
await this.Fixture.DeleteThread(agentThread);
}
}

public Task InitializeAsync()
{
this._agentFixture = createAgentFixture();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Threading.Tasks;
using Xunit;

namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.AgentWithStatePartConformance;

public class BedrockAgentWithStatePartTests() : AgentWithStatePartTests<BedrockAgentFixture>(() => new BedrockAgentFixture())
{
private const string ManualVerificationSkipReason = "This test is for manual verification.";

[Fact(Skip = ManualVerificationSkipReason)]
public override Task StatePartReceivesMessagesFromAgentAsync()
{
return base.StatePartReceivesMessagesFromAgentAsync();
}

[Fact(Skip = ManualVerificationSkipReason)]
public override Task StatePartReceivesMessagesFromAgentWhenStreamingAsync()
{
return base.StatePartReceivesMessagesFromAgentWhenStreamingAsync();
}

[Fact(Skip = ManualVerificationSkipReason)]
public override Task StatePartPreInvokeStateIsUsedByAgentAsync()
{
return base.StatePartPreInvokeStateIsUsedByAgentAsync();
}

[Fact(Skip = ManualVerificationSkipReason)]
public override Task StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync()
{
return base.StatePartPreInvokeStateIsUsedByAgentWhenStreamingAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance;

internal sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable
public sealed class BedrockAgentFixture : AgentFixture, IAsyncDisposable
{
private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
.AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
Expand All @@ -11,7 +10,9 @@ namespace SemanticKernel.IntegrationTests.Agents.CommonInterfaceConformance.Invo

public class BedrockAgentInvokeTests() : InvokeTests(() => new BedrockAgentFixture())
{
[Fact(Skip = "This test is for manual verification.")]
private const string ManualVerificationSkipReason = "This test is for manual verification.";

[Fact(Skip = ManualVerificationSkipReason)]
public override async Task ConversationMaintainsHistoryAsync()
{
var q1 = "What is the capital of France.";
Expand All @@ -32,7 +33,7 @@ public override async Task ConversationMaintainsHistoryAsync()
//Assert.Contains("Eiffel", result2.Message.Content);
}

[Fact(Skip = "This test is for manual verification.")]
[Fact(Skip = ManualVerificationSkipReason)]
public override async Task InvokeReturnsResultAsync()
{
var agent = this.Fixture.Agent;
Expand All @@ -48,7 +49,7 @@ public override async Task InvokeReturnsResultAsync()
//Assert.Contains("Paris", firstResult.Message.Content);
}

[Fact(Skip = "This test is for manual verification.")]
[Fact(Skip = ManualVerificationSkipReason)]
public override async Task InvokeWithoutThreadCreatesThreadAsync()
{
var agent = this.Fixture.Agent;
Expand All @@ -66,20 +67,19 @@ public override async Task InvokeWithoutThreadCreatesThreadAsync()
await this.Fixture.DeleteThread(firstResult.Thread);
}

[Fact(Skip = "This test is for manual verification.")]
[Fact(Skip = "The BedrockAgent does not support invoking without a message.")]
public override Task InvokeWithoutMessageCreatesThreadAsync()
{
// The Bedrock agent does not support invoking without a message.
return Assert.ThrowsAsync<InvalidOperationException>(async () => await base.InvokeWithoutThreadCreatesThreadAsync());
return base.InvokeWithoutMessageCreatesThreadAsync();
}

[Fact(Skip = "This test is for manual verification.")]
[Fact(Skip = "The BedrockAgent does not yet support plugins")]
public override Task MultiStepInvokeWithPluginAndArgOverridesAsync()
{
return base.MultiStepInvokeWithPluginAndArgOverridesAsync();
}

[Fact(Skip = "This test is for manual verification.")]
[Fact(Skip = "The BedrockAgent does not yet support plugins")]
public override Task InvokeWithPluginNotifiesForAllMessagesAsync()
{
return base.InvokeWithPluginNotifiesForAllMessagesAsync();
Expand Down
Loading