From ddd1bb8e39bd34442c25076d6b0c7f43b00dfde2 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 15 Jul 2025 15:18:57 +0200 Subject: [PATCH 1/7] Add tests --- .../tests/test_chat_generator.py | 51 +++++++++++++ .../tests/test_chat_generator_utils.py | 71 +++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ee651d72b..97e35f9b4 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -364,6 +364,57 @@ def test_live_run_with_multi_tool_calls(self, model_name, tools): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() + @pytest.mark.skip(reason="This fails because we are missing the reasoning content in the second round of messages") + def test_live_run_with_tool_call_and_thinking(self, tools): + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator( + model="arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + tools=tools, + generation_kwargs={ + "maxTokens": 8192, + "thinking": { + "type": "enabled", + "budget_tokens": 1024, + }, + }, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_call_message = None + for message in results["replies"]: + if message.tool_calls: + tool_call_message = message + break + + assert tool_call_message is not None, "No message with tool call found" + assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0].id, "Tool call does not contain value for 'id' key" + assert tool_calls[0].tool_name == "weather" + assert tool_calls[0].arguments["city"] == "Paris" + assert tool_call_message.meta["finish_reason"] == "tool_use" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [ + ChatMessage.from_tool(tool_result="22° C", origin=tool_call) for tool_call in tool_calls + ] + + new_messages = [*initial_messages, tool_call_message, *tool_result_messages] + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower() + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) def test_live_run_with_multi_tool_calls_streaming(self, model_name, tools): """ diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index 83ef6f873..fc982e6d6 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -334,6 +334,77 @@ def test_extract_replies_from_multi_tool_response(self, mock_boto3_session): ) assert replies[0] == expected_message + def test_extract_replies_from_multi_tool_response_with_thinking(self, mock_boto3_session): + model = "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" + response_body = { + "ResponseMetadata": { + "RequestId": "d7be81a1-5d37-40fe-936a-7c96e850cdda", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Tue, 15 Jul 2025 12:49:56 GMT", + "content-type": "application/json", + "content-length": "1107", + "connection": "keep-alive", + "x-amzn-requestid": "d7be81a1-5d37-40fe-936a-7c96e850cdda", + }, + "RetryAttempts": 0, + }, + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "The user wants to know the weather in Paris. I have a `weather` function " + "available that can provide this information. \n\nRequired parameters for " + "the weather function:\n- city: The city to get the weather for\n\nIn this " + 'case, the user has clearly specified "Paris" as the city, so I have all ' + "the required information to make the function call.", + "signature": "...", + } + } + }, + {"text": "I'll check the current weather in Paris for you."}, + { + "toolUse": { + "toolUseId": "tooluse_iUqy8-ypSByLK5zFkka8uA", + "name": "weather", + "input": {"city": "Paris"}, + } + }, + ], + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 412, + "outputTokens": 146, + "totalTokens": 558, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, + "metrics": {"latencyMs": 4811}, + } + replies = _parse_completion_response(response_body, model) + + # TODO We are missing the reasoning content in the ChatMessage + expected_message = ChatMessage.from_assistant( + text="I'll check the current weather in Paris for you.", + tool_calls=[ + ToolCall( + tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA" + ), + ], + meta={ + "model": "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "index": 0, + "finish_reason": "tool_use", + "usage": {"prompt_tokens": 412, "completion_tokens": 146, "total_tokens": 558}, + }, + ) + assert replies[0] == expected_message + def test_process_streaming_response_one_tool_call(self, mock_boto3_session): """ Test that process_streaming_response correctly handles streaming events and accumulates responses From c1418a922a0758734097e6dc75966be0979dcbcb Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 15 Jul 2025 15:21:49 +0200 Subject: [PATCH 2/7] Formatting --- .../amazon_bedrock/tests/test_chat_generator_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index fc982e6d6..80334c6d5 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -392,9 +392,7 @@ def test_extract_replies_from_multi_tool_response_with_thinking(self, mock_boto3 expected_message = ChatMessage.from_assistant( text="I'll check the current weather in Paris for you.", tool_calls=[ - ToolCall( - tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA" - ), + ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA"), ], meta={ "model": "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", From 21b8860772dfbfebd890d4ba1a31c9260b537d79 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 4 Aug 2025 13:42:02 +0200 Subject: [PATCH 3/7] Update reasoning_content to be saved when non-streaming --- .../generators/amazon_bedrock/chat/utils.py | 10 ++- .../tests/test_chat_generator.py | 53 ++++++++++++- .../tests/test_chat_generator_utils.py | 79 ++++++++++++++++--- 3 files changed, 130 insertions(+), 12 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py index 238ef1ae7..97aa351d2 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py @@ -221,7 +221,6 @@ def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], return system_prompts, repaired_bedrock_formatted_messages -# Bedrock to Haystack util method def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]: """ Parse a Bedrock API response into Haystack ChatMessage objects. @@ -267,6 +266,12 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis arguments=tool_use.get("input", {}), ) tool_calls.append(tool_call) + elif "reasoningContent" in content_block: + reasoning_content = content_block["reasoningContent"] + # If reasoningText is present, replace it with reasoning_text + if "reasoningText" in reasoning_content: + reasoning_content["reasoning_text"] = reasoning_content.pop("reasoningText") + base_meta.update({"reasoning_content": reasoning_content}) # Create a single ChatMessage with combined text and tool calls replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta)) @@ -494,11 +499,14 @@ def _parse_streaming_response( :param component_info: ComponentInfo object :return: List of ChatMessage objects """ + aws_chunks = [] chunks: List[StreamingChunk] = [] for event in response_stream: + aws_chunks.append(event) streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model, component_info=component_info) streaming_callback(streaming_chunk) chunks.append(streaming_chunk) + print(aws_chunks) replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)] return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index a694efdc7..b754c1e96 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -283,7 +283,6 @@ def test_prepare_request_params_tool_config(self, top_song_tool_config, mock_bot assert request_params["toolConfig"] == top_song_tool_config -# In the CI, those tests are skipped if AWS Authentication fails @pytest.mark.integration class TestAmazonBedrockChatGeneratorInference: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @@ -400,7 +399,6 @@ def test_live_run_with_multi_tool_calls(self, model_name, tools): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() - @pytest.mark.skip(reason="This fails because we are missing the reasoning content in the second round of messages") def test_live_run_with_tool_call_and_thinking(self, tools): initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] component = AmazonBedrockChatGenerator( @@ -451,6 +449,57 @@ def test_live_run_with_tool_call_and_thinking(self, tools): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() + def test_live_run_with_tool_call_and_thinking_streaming(self, tools): + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator( + model="arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + tools=tools, + generation_kwargs={ + "maxTokens": 8192, + "thinking": { + "type": "enabled", + "budget_tokens": 1024, + }, + }, + streaming_callback=print_streaming_chunk, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_call_message = None + for message in results["replies"]: + if message.tool_calls: + tool_call_message = message + break + + assert tool_call_message is not None, "No message with tool call found" + assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0].id, "Tool call does not contain value for 'id' key" + assert tool_calls[0].tool_name == "weather" + assert tool_calls[0].arguments["city"] == "Paris" + assert tool_call_message.meta["finish_reason"] == "tool_use" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [ + ChatMessage.from_tool(tool_result="22° C", origin=tool_call) for tool_call in tool_calls + ] + + new_messages = [*initial_messages, tool_call_message, *tool_result_messages] + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower() + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) def test_live_run_with_multi_tool_calls_streaming(self, model_name, tools): """ diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index c125ea729..ef298e985 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -1,6 +1,7 @@ import base64 import pytest +from unittest.mock import ANY from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, ImageContent, StreamingChunk, ToolCall from haystack.tools import Tool @@ -380,7 +381,7 @@ def test_extract_replies_from_multi_tool_response(self, mock_boto3_session): ) assert replies[0] == expected_message - def test_extract_replies_from_multi_tool_response_with_thinking(self, mock_boto3_session): + def test_extract_replies_from_one_tool_response_with_thinking(self, mock_boto3_session): model = "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" response_body = { "ResponseMetadata": { @@ -434,7 +435,6 @@ def test_extract_replies_from_multi_tool_response_with_thinking(self, mock_boto3 } replies = _parse_completion_response(response_body, model) - # TODO We are missing the reasoning content in the ChatMessage expected_message = ChatMessage.from_assistant( text="I'll check the current weather in Paris for you.", tool_calls=[ @@ -445,6 +445,16 @@ def test_extract_replies_from_multi_tool_response_with_thinking(self, mock_boto3 "index": 0, "finish_reason": "tool_use", "usage": {"prompt_tokens": 412, "completion_tokens": 146, "total_tokens": 558}, + "reasoning_content": { + "reasoning_text": { + "text": "The user wants to know the weather in Paris. I have a `weather` function " + "available that can provide this information. \n\nRequired parameters for " + "the weather function:\n- city: The city to get the weather for\n\nIn this " + 'case, the user has clearly specified "Paris" as the city, so I have all ' + "the required information to make the function call.", + "signature": "...", + } + }, }, ) assert replies[0] == expected_message @@ -543,6 +553,62 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 1 assert replies == expected_messages + # TODO + # def test_process_streaming_response_one_tool_call_with_thinking(self, mock_boto3_session): + # model = "anthropic.claude-3-5-sonnet-20240620-v1:0" + # type_ = ( + # "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" + # ) + # streaming_chunks = [] + # + # def test_callback(chunk: StreamingChunk): + # streaming_chunks.append(chunk) + # + # events = [ + # {"messageStart": {"role": "assistant"}}, + # {"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " answer your question about the"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " weather in Berlin and Paris, I'll"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " need to use the weather_tool"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " for each city. Let"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " me fetch that information for"}, "contentBlockIndex": 0}}, + # {"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, + # {"contentBlockStop": {"contentBlockIndex": 0}}, + # { + # "contentBlockStart": { + # "start": {"toolUse": {"toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA", "name": "weather_tool"}}, + # "contentBlockIndex": 1, + # } + # }, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"location":'}}, "contentBlockIndex": 1}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Be'}}, "contentBlockIndex": 1}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": 'rlin"}'}}, "contentBlockIndex": 1}}, + # {"contentBlockStop": {"contentBlockIndex": 1}}, + # { + # "contentBlockStart": { + # "start": {"toolUse": {"toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw", "name": "weather_tool"}}, + # "contentBlockIndex": 2, + # } + # }, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"l'}}, "contentBlockIndex": 2}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": "ocati"}}, "contentBlockIndex": 2}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": 'on": "P'}}, "contentBlockIndex": 2}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": "ari"}}, "contentBlockIndex": 2}}, + # {"contentBlockDelta": {"delta": {"toolUse": {"input": 's"}'}}, "contentBlockIndex": 2}}, + # {"contentBlockStop": {"contentBlockIndex": 2}}, + # {"messageStop": {"stopReason": "tool_use"}}, + # { + # "metadata": { + # "usage": {"inputTokens": 366, "outputTokens": 83, "totalTokens": 449}, + # "metrics": {"latencyMs": 3194}, + # } + # }, + # ] + # + # replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) + def test_parse_streaming_response_with_two_tool_calls(self, mock_boto3_session): model = "anthropic.claude-3-5-sonnet-20240620-v1:0" type_ = ( @@ -596,13 +662,7 @@ def test_callback(chunk: StreamingChunk): }, ] - component_info = ComponentInfo( - type=type_, - ) - - replies = _parse_streaming_response(events, test_callback, model, component_info) - # Pop completion_start_time since it will always change - replies[0].meta.pop("completion_start_time") + replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) expected_messages = [ ChatMessage.from_assistant( text="To answer your question about the weather in Berlin and Paris, I'll need to use the " @@ -621,6 +681,7 @@ def test_callback(chunk: StreamingChunk): "index": 0, "finish_reason": "tool_use", "usage": {"prompt_tokens": 366, "completion_tokens": 83, "total_tokens": 449}, + "completion_start_time": ANY, }, ), ] From 7dbe374912b5ef580eda229b2c4e0726d137da45 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 4 Aug 2025 13:57:34 +0200 Subject: [PATCH 4/7] Update _format_messages to add back the reasoning content --- integrations/amazon_bedrock/pyproject.toml | 2 +- .../generators/amazon_bedrock/chat/utils.py | 17 +++++- .../tests/test_chat_generator_utils.py | 61 ++++++++++++++++++- 3 files changed, 77 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 31e82b12e..4bd200efc 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -63,7 +63,7 @@ dependencies = [ [tool.hatch.envs.test.scripts] unit = 'pytest -m "not integration" {args:tests}' -integration = 'pytest -m "integration" {args:tests}' +integration = 'pytest -m "integration" {args:tests} -s -k test_live_run_with_tool_call_and_thinking_streaming' all = 'pytest {args:tests}' cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x' diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py index 97aa351d2..63d9cbf7d 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py @@ -55,6 +55,15 @@ def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]: Dictionary representing the tool call message in Bedrock's expected format """ content: List[Dict[str, Any]] = [] + + # tool call messages can contain reasoning content + if tool_call_message.meta.get("reasoning_content"): + reasoning_content = tool_call_message.meta["reasoning_content"] + # If reasoningText is present, replace it with reasoning_text + if "reasoning_text" in reasoning_content: + reasoning_content["reasoningText"] = reasoning_content.pop("reasoning_text") + content.append({"reasoningContent": reasoning_content}) + # Tool call message can contain text if tool_call_message.text: content.append({"text": tool_call_message.text}) @@ -168,6 +177,13 @@ def _format_text_image_message(message: ChatMessage) -> Dict[str, Any]: content_parts = message._content bedrock_content_blocks: List[Dict[str, Any]] = [] + # Add reasoning content if available as the first content block + if message.meta.get("reasoning_content"): + reasoning_content = message.meta["reasoning_content"] + if "reasoning_text" in reasoning_content: + reasoning_content["reasoningText"] = reasoning_content.pop("reasoning_text") + bedrock_content_blocks.append({"reasoningContent": reasoning_content}) + for part in content_parts: if isinstance(part, TextContent): bedrock_content_blocks.append({"text": part.text}) @@ -279,7 +295,6 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis return replies -# Bedrock streaming to Haystack util methods def _convert_event_to_streaming_chunk( event: Dict[str, Any], model: str, component_info: ComponentInfo ) -> StreamingChunk: diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index ef298e985..22c265687 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -110,6 +110,65 @@ def test_format_messages(self): {"role": "assistant", "content": [{"text": "The weather in Paris is sunny and 25°C."}]}, ] + def test_format_message_thinking(self): + assistant_message = ChatMessage.from_assistant( + "This is a test message.", + meta={"reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the message.", + "signature": "reasoning_signature" + } + }} + ) + formatted_message = _format_messages([assistant_message])[1][0] + assert formatted_message == { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "This is the reasoning behind the message.", + "signature": "reasoning_signature" + } + } + }, + {"text": "This is a test message."}, + ], + } + + tool_call_message = ChatMessage.from_assistant( + "This is a test message with a tool call.", + tool_calls=[ToolCall(id="123", tool_name="test_tool", arguments={"key": "value"})], + meta={"reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the tool call.", + "signature": "reasoning_signature" + } + }} + ) + formatted_message = _format_messages([tool_call_message])[1][0] + assert formatted_message == { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "This is the reasoning behind the tool call.", + "signature": "reasoning_signature" + } + } + }, + {"text": "This is a test message with a tool call."}, + { + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"}, + } + }, + ], + } + def test_format_text_image_message(self): plain_assistant_message = ChatMessage.from_assistant("This is a test message.") formatted_message = _format_text_image_message(plain_assistant_message) @@ -152,7 +211,7 @@ def test_format_text_image_message_errors(self): with pytest.raises(ValueError): _format_text_image_message(image_message) - def test_formate_messages_multi_tool(self): + def test_format_messages_multi_tool(self): messages = [ ChatMessage.from_user("What is the weather in Berlin and Paris?"), ChatMessage.from_assistant( From 1711a14782b590cc44ef25b3ec0feea0464d7c40 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 4 Aug 2025 14:26:01 +0200 Subject: [PATCH 5/7] Add thinking support to streaming as well --- integrations/amazon_bedrock/pyproject.toml | 2 +- .../generators/amazon_bedrock/chat/utils.py | 35 ++- .../tests/test_chat_generator.py | 2 - .../tests/test_chat_generator_utils.py | 208 ++++++++++++------ 4 files changed, 166 insertions(+), 81 deletions(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 4bd200efc..31e82b12e 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -63,7 +63,7 @@ dependencies = [ [tool.hatch.envs.test.scripts] unit = 'pytest -m "not integration" {args:tests}' -integration = 'pytest -m "integration" {args:tests} -s -k test_live_run_with_tool_call_and_thinking_streaming' +integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x' diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py index 63d9cbf7d..317b31117 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py @@ -325,7 +325,6 @@ def _convert_event_to_streaming_chunk( content="", meta={ "model": model, - # This is always 0 b/c it represents the choice index "index": 0, # We follow the same format used in the OpenAIChatGenerator "tool_calls": [ # Optional[List[ChoiceDeltaToolCall]] @@ -355,7 +354,6 @@ def _convert_event_to_streaming_chunk( content=delta["text"], meta={ "model": model, - # This is always 0 b/c it represents the choice index "index": 0, "tool_calls": None, "finish_reason": None, @@ -369,7 +367,6 @@ def _convert_event_to_streaming_chunk( content="", meta={ "model": model, - # This is always 0 b/c it represents the choice index "index": 0, "tool_calls": [ # Optional[List[ChoiceDeltaToolCall]] { @@ -387,6 +384,19 @@ def _convert_event_to_streaming_chunk( "received_at": datetime.now(timezone.utc).isoformat(), }, ) + # This is for accumulating reasoning content deltas + elif "reasoningContent" in delta: + streaming_chunk = StreamingChunk( + content="", + meta={ + "model": model, + "index": 0, + "tool_calls": None, + "finish_reason": None, + "received_at": datetime.now(timezone.utc).isoformat(), + "reasoning_content": delta["reasoningContent"], + }, + ) elif "messageStop" in event: finish_reason = event["messageStop"].get("stopReason") @@ -441,8 +451,23 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C A ChatMessage object constructed from the streaming chunks, containing the aggregated text, processed tool calls, and metadata. """ + # Join all text content from the chunks text = "".join([chunk.content for chunk in chunks]) + # If reasoning content is present in any chunk, accumulate it + reasoning_text = "" + reasoning_signature = None + for chunk in chunks: + if chunk.meta.get("reasoning_content"): + reasoning_content = chunk.meta["reasoning_content"] + if "text" in reasoning_content: + reasoning_text += reasoning_content["text"] + elif "signature" in reasoning_content: + reasoning_signature = reasoning_content["signature"] + reasoning_content = None + if reasoning_text: + reasoning_content = {"reasoning_text": {"text": reasoning_text, "signature": reasoning_signature}} + # Process tool calls if present in any chunk tool_calls = [] tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index @@ -494,6 +519,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C "finish_reason": finish_reason, "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received "usage": usage, + "reasoning_content": reasoning_content if reasoning_content else None, } return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta) @@ -514,14 +540,11 @@ def _parse_streaming_response( :param component_info: ComponentInfo object :return: List of ChatMessage objects """ - aws_chunks = [] chunks: List[StreamingChunk] = [] for event in response_stream: - aws_chunks.append(event) streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model, component_info=component_info) streaming_callback(streaming_chunk) chunks.append(streaming_chunk) - print(aws_chunks) replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)] return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index b754c1e96..4656d1cc7 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -447,7 +447,6 @@ def test_live_run_with_tool_call_and_thinking(self, tools): assert not final_message.tool_call assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() - assert "berlin" in final_message.text.lower() def test_live_run_with_tool_call_and_thinking_streaming(self, tools): initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] @@ -498,7 +497,6 @@ def test_live_run_with_tool_call_and_thinking_streaming(self, tools): assert not final_message.tool_call assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() - assert "berlin" in final_message.text.lower() @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) def test_live_run_with_multi_tool_calls_streaming(self, model_name, tools): diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index 22c265687..0edddaee6 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -1,7 +1,7 @@ import base64 +from unittest.mock import ANY import pytest -from unittest.mock import ANY from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, ImageContent, StreamingChunk, ToolCall from haystack.tools import Tool @@ -113,12 +113,14 @@ def test_format_messages(self): def test_format_message_thinking(self): assistant_message = ChatMessage.from_assistant( "This is a test message.", - meta={"reasoning_content": { - "reasoning_text": { - "text": "This is the reasoning behind the message.", - "signature": "reasoning_signature" + meta={ + "reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the message.", + "signature": "reasoning_signature", + } } - }} + }, ) formatted_message = _format_messages([assistant_message])[1][0] assert formatted_message == { @@ -128,7 +130,7 @@ def test_format_message_thinking(self): "reasoningContent": { "reasoningText": { "text": "This is the reasoning behind the message.", - "signature": "reasoning_signature" + "signature": "reasoning_signature", } } }, @@ -139,12 +141,14 @@ def test_format_message_thinking(self): tool_call_message = ChatMessage.from_assistant( "This is a test message with a tool call.", tool_calls=[ToolCall(id="123", tool_name="test_tool", arguments={"key": "value"})], - meta={"reasoning_content": { - "reasoning_text": { - "text": "This is the reasoning behind the tool call.", - "signature": "reasoning_signature" + meta={ + "reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the tool call.", + "signature": "reasoning_signature", + } } - }} + }, ) formatted_message = _format_messages([tool_call_message])[1][0] assert formatted_message == { @@ -154,7 +158,7 @@ def test_format_message_thinking(self): "reasoningContent": { "reasoningText": { "text": "This is the reasoning behind the tool call.", - "signature": "reasoning_signature" + "signature": "reasoning_signature", } } }, @@ -507,10 +511,10 @@ def test_extract_replies_from_one_tool_response_with_thinking(self, mock_boto3_s "reasoning_content": { "reasoning_text": { "text": "The user wants to know the weather in Paris. I have a `weather` function " - "available that can provide this information. \n\nRequired parameters for " - "the weather function:\n- city: The city to get the weather for\n\nIn this " - 'case, the user has clearly specified "Paris" as the city, so I have all ' - "the required information to make the function call.", + "available that can provide this information. \n\nRequired parameters for " + "the weather function:\n- city: The city to get the weather for\n\nIn this " + 'case, the user has clearly specified "Paris" as the city, so I have all ' + "the required information to make the function call.", "signature": "...", } }, @@ -612,61 +616,121 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 1 assert replies == expected_messages - # TODO - # def test_process_streaming_response_one_tool_call_with_thinking(self, mock_boto3_session): - # model = "anthropic.claude-3-5-sonnet-20240620-v1:0" - # type_ = ( - # "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" - # ) - # streaming_chunks = [] - # - # def test_callback(chunk: StreamingChunk): - # streaming_chunks.append(chunk) - # - # events = [ - # {"messageStart": {"role": "assistant"}}, - # {"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " answer your question about the"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " weather in Berlin and Paris, I'll"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " need to use the weather_tool"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " for each city. Let"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " me fetch that information for"}, "contentBlockIndex": 0}}, - # {"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, - # {"contentBlockStop": {"contentBlockIndex": 0}}, - # { - # "contentBlockStart": { - # "start": {"toolUse": {"toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA", "name": "weather_tool"}}, - # "contentBlockIndex": 1, - # } - # }, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"location":'}}, "contentBlockIndex": 1}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Be'}}, "contentBlockIndex": 1}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": 'rlin"}'}}, "contentBlockIndex": 1}}, - # {"contentBlockStop": {"contentBlockIndex": 1}}, - # { - # "contentBlockStart": { - # "start": {"toolUse": {"toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw", "name": "weather_tool"}}, - # "contentBlockIndex": 2, - # } - # }, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"l'}}, "contentBlockIndex": 2}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": "ocati"}}, "contentBlockIndex": 2}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": 'on": "P'}}, "contentBlockIndex": 2}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": "ari"}}, "contentBlockIndex": 2}}, - # {"contentBlockDelta": {"delta": {"toolUse": {"input": 's"}'}}, "contentBlockIndex": 2}}, - # {"contentBlockStop": {"contentBlockIndex": 2}}, - # {"messageStop": {"stopReason": "tool_use"}}, - # { - # "metadata": { - # "usage": {"inputTokens": 366, "outputTokens": 83, "totalTokens": 449}, - # "metrics": {"latencyMs": 3194}, - # } - # }, - # ] - # - # replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) + def test_process_streaming_response_one_tool_call_with_thinking(self, mock_boto3_session): + model = "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0" + type_ = ( + "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" + ) + streaming_chunks = [] + + def test_callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + events = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": "The user is asking about the weather"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " in Paris. I have"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " access to a"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " weather function that takes"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " a city parameter. Paris"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " is clearly specifie"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": "d as the city, so I have all"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " the required parameters to make the"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " function call."}}, + "contentBlockIndex": 0, + } + }, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "..."}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg", "name": "weather"}}, + "contentBlockIndex": 1, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"ci'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "ty"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '": "P'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "aris"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}}, + {"contentBlockStop": {"contentBlockIndex": 1}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 412, "outputTokens": 104, "totalTokens": 516}, + "metrics": {"latencyMs": 2134}, + } + }, + ] + + replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) + + expected_messages = [ + ChatMessage.from_assistant( + tool_calls=[ + ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_1gPhO4A1RNWgzKbt1PXWLg"), + ], + meta={ + "model": "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0", + "index": 0, + "finish_reason": "tool_use", + "usage": {"prompt_tokens": 412, "completion_tokens": 104, "total_tokens": 516}, + "completion_start_time": ANY, + "reasoning_content": { + "reasoning_text": { + "text": "The user is asking about the weather in Paris. I have access to a weather " + "function that takes a city parameter. Paris is clearly specified as the city, " + "so I have all the required parameters to make the function call.", + "signature": "...", + } + }, + }, + ), + ] + assert replies == expected_messages def test_parse_streaming_response_with_two_tool_calls(self, mock_boto3_session): model = "anthropic.claude-3-5-sonnet-20240620-v1:0" From ed2830b9d0377cc9bd900a4fdb5529d3d58e8604 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 4 Aug 2025 14:31:10 +0200 Subject: [PATCH 6/7] Fix unit tests --- integrations/amazon_bedrock/tests/test_chat_generator_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index 0edddaee6..b109b31c9 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -592,6 +592,7 @@ def test_callback(chunk: StreamingChunk): "index": 0, "finish_reason": "tool_use", "usage": {"prompt_tokens": 364, "completion_tokens": 71, "total_tokens": 435}, + "reasoning_content": None, }, ) ] @@ -805,6 +806,7 @@ def test_callback(chunk: StreamingChunk): "finish_reason": "tool_use", "usage": {"prompt_tokens": 366, "completion_tokens": 83, "total_tokens": 449}, "completion_start_time": ANY, + "reasoning_content": None, }, ), ] From cc8c14acdb5958e0e2b0746c9e856224227455df Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 4 Aug 2025 15:27:15 +0200 Subject: [PATCH 7/7] PR comments --- .../components/generators/amazon_bedrock/chat/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py index 317b31117..00bbcdfa3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py @@ -57,8 +57,7 @@ def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]: content: List[Dict[str, Any]] = [] # tool call messages can contain reasoning content - if tool_call_message.meta.get("reasoning_content"): - reasoning_content = tool_call_message.meta["reasoning_content"] + if reasoning_content := tool_call_message.meta.get("reasoning_content"): # If reasoningText is present, replace it with reasoning_text if "reasoning_text" in reasoning_content: reasoning_content["reasoningText"] = reasoning_content.pop("reasoning_text") @@ -458,8 +457,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C reasoning_text = "" reasoning_signature = None for chunk in chunks: - if chunk.meta.get("reasoning_content"): - reasoning_content = chunk.meta["reasoning_content"] + if reasoning_content := chunk.meta.get("reasoning_content"): if "text" in reasoning_content: reasoning_text += reasoning_content["text"] elif "signature" in reasoning_content: