diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
index 238ef1ae7..00bbcdfa3 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
@@ -55,6 +55,14 @@ def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]:
         Dictionary representing the tool call message in Bedrock's expected format
     """
     content: List[Dict[str, Any]] = []
+
+    # Tool call messages can contain reasoning content
+    if reasoning_content := tool_call_message.meta.get("reasoning_content"):
+        # If reasoning_text is present, rename it to Bedrock's camelCase reasoningText
+        if "reasoning_text" in reasoning_content:
+            reasoning_content["reasoningText"] = reasoning_content.pop("reasoning_text")
+        content.append({"reasoningContent": reasoning_content})
+
     # Tool call message can contain text
     if tool_call_message.text:
         content.append({"text": tool_call_message.text})
@@ -168,6 +176,13 @@ def _format_text_image_message(message: ChatMessage) -> Dict[str, Any]:
     content_parts = message._content
     bedrock_content_blocks: List[Dict[str, Any]] = []
 
+    # Add reasoning content, if available, as the first content block
+    if message.meta.get("reasoning_content"):
+        reasoning_content = message.meta["reasoning_content"]
+        if "reasoning_text" in reasoning_content:
+            reasoning_content["reasoningText"] = reasoning_content.pop("reasoning_text")
+        bedrock_content_blocks.append({"reasoningContent": reasoning_content})
+
     for part in content_parts:
         if isinstance(part, TextContent):
             bedrock_content_blocks.append({"text": part.text})
@@ -221,7 +236,6 @@ def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]],
     return system_prompts, repaired_bedrock_formatted_messages
 
 
-# Bedrock to Haystack util method
 def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]:
     """
     Parse a Bedrock API response into Haystack ChatMessage objects.
@@ -267,6 +281,12 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
                         arguments=tool_use.get("input", {}),
                     )
                     tool_calls.append(tool_call)
+                elif "reasoningContent" in content_block:
+                    reasoning_content = content_block["reasoningContent"]
+                    # If reasoningText is present, rename it to Haystack's snake_case reasoning_text
+                    if "reasoningText" in reasoning_content:
+                        reasoning_content["reasoning_text"] = reasoning_content.pop("reasoningText")
+                    base_meta.update({"reasoning_content": reasoning_content})
 
         # Create a single ChatMessage with combined text and tool calls
         replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta))
@@ -274,7 +294,6 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
     return replies
 
 
-# Bedrock streaming to Haystack util methods
def _convert_event_to_streaming_chunk(
     event: Dict[str, Any], model: str, component_info: ComponentInfo
 ) -> StreamingChunk:
@@ -305,7 +324,6 @@ def _convert_event_to_streaming_chunk(
             content="",
             meta={
                 "model": model,
-                # This is always 0 b/c it represents the choice index
                 "index": 0,
                 # We follow the same format used in the OpenAIChatGenerator
                 "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
@@ -335,7 +353,6 @@ def _convert_event_to_streaming_chunk(
                 content=delta["text"],
                 meta={
                     "model": model,
-                    # This is always 0 b/c it represents the choice index
                     "index": 0,
                     "tool_calls": None,
                     "finish_reason": None,
@@ -349,7 +366,6 @@ def _convert_event_to_streaming_chunk(
                 content="",
                 meta={
                     "model": model,
-                    # This is always 0 b/c it represents the choice index
                     "index": 0,
                     "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
                         {
@@ -367,6 +383,19 @@ def _convert_event_to_streaming_chunk(
                     "received_at": datetime.now(timezone.utc).isoformat(),
                 },
             )
+        # Reasoning content arrives as deltas; they are accumulated downstream in
+        # _convert_streaming_chunks_to_chat_message
+        elif "reasoningContent" in delta:
+            streaming_chunk = StreamingChunk(
+                content="",
+                meta={
+                    "model": model,
+                    "index": 0,
+                    "tool_calls": None,
+                    "finish_reason": None,
+                    "received_at": datetime.now(timezone.utc).isoformat(),
+                    "reasoning_content": delta["reasoningContent"],
+                },
+            )
 
     elif "messageStop" in event:
         finish_reason = event["messageStop"].get("stopReason")
@@ -421,8 +450,22 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
         A ChatMessage object constructed from the streaming chunks, containing the aggregated text,
         processed tool calls, and metadata.
""" + # Join all text content from the chunks text = "".join([chunk.content for chunk in chunks]) + # If reasoning content is present in any chunk, accumulate it + reasoning_text = "" + reasoning_signature = None + for chunk in chunks: + if reasoning_content := chunk.meta.get("reasoning_content"): + if "text" in reasoning_content: + reasoning_text += reasoning_content["text"] + elif "signature" in reasoning_content: + reasoning_signature = reasoning_content["signature"] + reasoning_content = None + if reasoning_text: + reasoning_content = {"reasoning_text": {"text": reasoning_text, "signature": reasoning_signature}} + # Process tool calls if present in any chunk tool_calls = [] tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index @@ -474,6 +517,7 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C "finish_reason": finish_reason, "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received "usage": usage, + "reasoning_content": reasoning_content if reasoning_content else None, } return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index e3daf2c3a..4656d1cc7 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -283,7 +283,6 @@ def test_prepare_request_params_tool_config(self, top_song_tool_config, mock_bot assert request_params["toolConfig"] == top_song_tool_config -# In the CI, those tests are skipped if AWS Authentication fails @pytest.mark.integration class TestAmazonBedrockChatGeneratorInference: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @@ -400,6 +399,105 @@ def test_live_run_with_multi_tool_calls(self, model_name, tools): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() + def test_live_run_with_tool_call_and_thinking(self, tools): + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator( + model="arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + tools=tools, + generation_kwargs={ + "maxTokens": 8192, + "thinking": { + "type": "enabled", + "budget_tokens": 1024, + }, + }, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_call_message = None + for message in results["replies"]: + if message.tool_calls: + tool_call_message = message + break + + assert tool_call_message is not None, "No message with tool call found" + assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0].id, "Tool call does not contain value for 'id' key" + assert tool_calls[0].tool_name == "weather" + assert tool_calls[0].arguments["city"] == "Paris" + assert tool_call_message.meta["finish_reason"] == "tool_use" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [ + ChatMessage.from_tool(tool_result="22° C", origin=tool_call) for tool_call in tool_calls + ] + + new_messages = [*initial_messages, tool_call_message, *tool_result_messages] + results = 
component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + def test_live_run_with_tool_call_and_thinking_streaming(self, tools): + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator( + model="arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + tools=tools, + generation_kwargs={ + "maxTokens": 8192, + "thinking": { + "type": "enabled", + "budget_tokens": 1024, + }, + }, + streaming_callback=print_streaming_chunk, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_call_message = None + for message in results["replies"]: + if message.tool_calls: + tool_call_message = message + break + + assert tool_call_message is not None, "No message with tool call found" + assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0].id, "Tool call does not contain value for 'id' key" + assert tool_calls[0].tool_name == "weather" + assert tool_calls[0].arguments["city"] == "Paris" + assert tool_call_message.meta["finish_reason"] == "tool_use" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [ + ChatMessage.from_tool(tool_result="22° C", origin=tool_call) for tool_call in tool_calls + ] + + new_messages = [*initial_messages, tool_call_message, *tool_result_messages] + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) def test_live_run_with_multi_tool_calls_streaming(self, model_name, tools): """ diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index 87496bd51..b109b31c9 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -1,4 +1,5 @@ import base64 +from unittest.mock import ANY import pytest from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, ImageContent, StreamingChunk, ToolCall @@ -109,6 +110,69 @@ def test_format_messages(self): {"role": "assistant", "content": [{"text": "The weather in Paris is sunny and 25°C."}]}, ] + def test_format_message_thinking(self): + assistant_message = ChatMessage.from_assistant( + "This is a test message.", + meta={ + "reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the message.", + "signature": "reasoning_signature", + } + } + }, + ) + formatted_message = _format_messages([assistant_message])[1][0] + assert formatted_message == { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "This is the reasoning behind the message.", + "signature": "reasoning_signature", + } + } + }, + {"text": "This is a test message."}, + ], + } + + tool_call_message = ChatMessage.from_assistant( + "This is a 
test message with a tool call.", + tool_calls=[ToolCall(id="123", tool_name="test_tool", arguments={"key": "value"})], + meta={ + "reasoning_content": { + "reasoning_text": { + "text": "This is the reasoning behind the tool call.", + "signature": "reasoning_signature", + } + } + }, + ) + formatted_message = _format_messages([tool_call_message])[1][0] + assert formatted_message == { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "This is the reasoning behind the tool call.", + "signature": "reasoning_signature", + } + } + }, + {"text": "This is a test message with a tool call."}, + { + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"}, + } + }, + ], + } + def test_format_text_image_message(self): plain_assistant_message = ChatMessage.from_assistant("This is a test message.") formatted_message = _format_text_image_message(plain_assistant_message) @@ -151,7 +215,7 @@ def test_format_text_image_message_errors(self): with pytest.raises(ValueError): _format_text_image_message(image_message) - def test_formate_messages_multi_tool(self): + def test_format_messages_multi_tool(self): messages = [ ChatMessage.from_user("What is the weather in Berlin and Paris?"), ChatMessage.from_assistant( @@ -380,6 +444,84 @@ def test_extract_replies_from_multi_tool_response(self, mock_boto3_session): ) assert replies[0] == expected_message + def test_extract_replies_from_one_tool_response_with_thinking(self, mock_boto3_session): + model = "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" + response_body = { + "ResponseMetadata": { + "RequestId": "d7be81a1-5d37-40fe-936a-7c96e850cdda", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Tue, 15 Jul 2025 12:49:56 GMT", + "content-type": "application/json", + "content-length": "1107", + "connection": "keep-alive", + "x-amzn-requestid": "d7be81a1-5d37-40fe-936a-7c96e850cdda", + }, + "RetryAttempts": 0, + }, + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "The user wants to know the weather in Paris. I have a `weather` function " + "available that can provide this information. 
\n\nRequired parameters for " + "the weather function:\n- city: The city to get the weather for\n\nIn this " + 'case, the user has clearly specified "Paris" as the city, so I have all ' + "the required information to make the function call.", + "signature": "...", + } + } + }, + {"text": "I'll check the current weather in Paris for you."}, + { + "toolUse": { + "toolUseId": "tooluse_iUqy8-ypSByLK5zFkka8uA", + "name": "weather", + "input": {"city": "Paris"}, + } + }, + ], + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 412, + "outputTokens": 146, + "totalTokens": 558, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 0, + }, + "metrics": {"latencyMs": 4811}, + } + replies = _parse_completion_response(response_body, model) + + expected_message = ChatMessage.from_assistant( + text="I'll check the current weather in Paris for you.", + tool_calls=[ + ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA"), + ], + meta={ + "model": "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "index": 0, + "finish_reason": "tool_use", + "usage": {"prompt_tokens": 412, "completion_tokens": 146, "total_tokens": 558}, + "reasoning_content": { + "reasoning_text": { + "text": "The user wants to know the weather in Paris. I have a `weather` function " + "available that can provide this information. \n\nRequired parameters for " + "the weather function:\n- city: The city to get the weather for\n\nIn this " + 'case, the user has clearly specified "Paris" as the city, so I have all ' + "the required information to make the function call.", + "signature": "...", + } + }, + }, + ) + assert replies[0] == expected_message + def test_process_streaming_response_one_tool_call(self, mock_boto3_session): """ Test that process_streaming_response correctly handles streaming events and accumulates responses @@ -450,6 +592,7 @@ def test_callback(chunk: StreamingChunk): "index": 0, "finish_reason": "tool_use", "usage": {"prompt_tokens": 364, "completion_tokens": 71, "total_tokens": 435}, + "reasoning_content": None, }, ) ] @@ -474,6 +617,122 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 1 assert replies == expected_messages + def test_process_streaming_response_one_tool_call_with_thinking(self, mock_boto3_session): + model = "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0" + type_ = ( + "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" + ) + streaming_chunks = [] + + def test_callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + events = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": "The user is asking about the weather"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " in Paris. I have"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " access to a"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " weather function that takes"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " a city parameter. 
Paris"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " is clearly specifie"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": "d as the city, so I have all"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " the required parameters to make the"}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " function call."}}, + "contentBlockIndex": 0, + } + }, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "..."}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg", "name": "weather"}}, + "contentBlockIndex": 1, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"ci'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "ty"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '": "P'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "aris"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}}, + {"contentBlockStop": {"contentBlockIndex": 1}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 412, "outputTokens": 104, "totalTokens": 516}, + "metrics": {"latencyMs": 2134}, + } + }, + ] + + replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) + + expected_messages = [ + ChatMessage.from_assistant( + tool_calls=[ + ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_1gPhO4A1RNWgzKbt1PXWLg"), + ], + meta={ + "model": "arn:aws:bedrock:us-east-1::inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0", + "index": 0, + "finish_reason": "tool_use", + "usage": {"prompt_tokens": 412, "completion_tokens": 104, "total_tokens": 516}, + "completion_start_time": ANY, + "reasoning_content": { + "reasoning_text": { + "text": "The user is asking about the weather in Paris. I have access to a weather " + "function that takes a city parameter. Paris is clearly specified as the city, " + "so I have all the required parameters to make the function call.", + "signature": "...", + } + }, + }, + ), + ] + assert replies == expected_messages + def test_parse_streaming_response_with_two_tool_calls(self, mock_boto3_session): model = "anthropic.claude-3-5-sonnet-20240620-v1:0" type_ = ( @@ -527,13 +786,7 @@ def test_callback(chunk: StreamingChunk): }, ] - component_info = ComponentInfo( - type=type_, - ) - - replies = _parse_streaming_response(events, test_callback, model, component_info) - # Pop completion_start_time since it will always change - replies[0].meta.pop("completion_start_time") + replies = _parse_streaming_response(events, test_callback, model, ComponentInfo(type=type_)) expected_messages = [ ChatMessage.from_assistant( text="To answer your question about the weather in Berlin and Paris, I'll need to use the " @@ -552,6 +805,8 @@ def test_callback(chunk: StreamingChunk): "index": 0, "finish_reason": "tool_use", "usage": {"prompt_tokens": 366, "completion_tokens": 83, "total_tokens": 449}, + "completion_start_time": ANY, + "reasoning_content": None, }, ), ]