From e353cf68237b788730f39cab53e2e1d0eb19bff8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 19 May 2025 15:18:45 +0200 Subject: [PATCH 1/8] Add Toolset support to OllamaChatGenerator --- .../generators/ollama/chat/chat_generator.py | 21 ++++-- .../ollama/tests/test_chat_generator.py | 70 +++++++++++++++++++ 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index f86c75d83..fb4e60ffd 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -2,7 +2,8 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall -from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.tools.toolset import Toolset from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from pydantic.json_schema import JsonSchemaValue @@ -151,7 +152,7 @@ def __init__( timeout: int = 120, keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None, ): """ @@ -177,7 +178,8 @@ def __init__( A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param tools: - A list of tools for which the model can prepare calls. + A list of tools or a Toolset for which the model can prepare calls. + This parameter can accept either a list of `Tool` objects or a `Toolset` instance. Not all models support tools. For a list of models compatible with tools, see the [models page](https://ollama.com/search?c=tools). :param response_format: @@ -207,7 +209,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None - serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None + return default_to_dict( self, model=self.model, @@ -216,7 +218,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, timeout=self.timeout, streaming_callback=callback_name, - tools=serialized_tools, + tools=serialize_tools_or_toolset(self.tools), response_format=self.response_format, ) @@ -280,7 +282,7 @@ def run( self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, *, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -294,7 +296,8 @@ def run( top_p, etc. See the [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). :param tools: - A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + A list of tools or a Toolset for which the model can prepare calls. 
This parameter can accept either a + list of `Tool` objects or a `Toolset` instance. If set, it will override the `tools` parameter set during component initialization. :param streaming_callback: A callback function that is called when a new token is received from the stream. @@ -320,6 +323,10 @@ def run( msg = "Ollama does not support streaming and response_format at the same time. Please choose one." raise ValueError(msg) + # Convert toolset to list of tools if needed + if isinstance(tools, Toolset): + tools = list(tools) + ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages] diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 48d3fc140..24f6d14e2 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -11,6 +11,7 @@ ToolCall, ) from haystack.tools import Tool +from haystack.tools.toolset import Toolset from ollama._types import ChatResponse, ResponseError from haystack_integrations.components.generators.ollama.chat.chat_generator import ( @@ -212,6 +213,34 @@ def test_init_fail_with_duplicate_tool_names(self, tools): with pytest.raises(ValueError): OllamaChatGenerator(tools=duplicate_tools) + def test_init_with_toolset(self, tools): + """Test that the OllamaChatGenerator can be initialized with a Toolset.""" + toolset = Toolset(tools) + generator = OllamaChatGenerator(model="llama3", tools=toolset) + assert generator.tools == toolset + + def test_to_dict_with_toolset(self, tools): + """Test that the OllamaChatGenerator can be serialized to a dictionary with a Toolset.""" + toolset = Toolset(tools) + generator = OllamaChatGenerator(model="llama3", tools=toolset) + data = generator.to_dict() + + assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" + assert "tools" in data["init_parameters"]["tools"]["data"] + assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) + + def test_from_dict_with_toolset(self, tools): + """Test that the OllamaChatGenerator can be deserialized from a dictionary with a Toolset.""" + toolset = Toolset(tools) + component = OllamaChatGenerator(model="llama3", tools=toolset) + data = component.to_dict() + + deserialized_component = OllamaChatGenerator.from_dict(data) + + assert isinstance(deserialized_component.tools, Toolset) + assert len(deserialized_component.tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in deserialized_component.tools) + def test_to_dict(self): tool = Tool( name="name", @@ -620,3 +649,44 @@ def test_run_with_tools_and_format(self, tools): message = ChatMessage.from_user("What's the weather in Paris?") with pytest.raises(ValueError): chat_generator.run([message]) + + @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") + def test_run_with_toolset(self, mock_client, tools): + """Test that the OllamaChatGenerator can run with a Toolset.""" + toolset = Toolset(tools) + generator = OllamaChatGenerator(model="llama3", tools=toolset) + + mock_response = ChatResponse( + model="llama3", + created_at="2023-12-12T14:13:43.416799Z", + message={ + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "weather", + "arguments": {"city": "Paris"}, + } + } + ], + }, + done=True, + total_duration=5191566416, + load_duration=2154458, + prompt_eval_count=26, + 
prompt_eval_duration=383809000, + eval_count=298, + eval_duration=4799921000, + ) + + mock_client_instance = mock_client.return_value + mock_client_instance.chat.return_value = mock_response + + result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")]) + + mock_client_instance.chat.assert_called_once() + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].tool_call.tool_name == "weather" + assert result["replies"][0].tool_call.arguments == {"city": "Paris"} From d4fb6eb4e9123e52c25d5286e82a70b11a30019d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 19 May 2025 16:11:10 +0200 Subject: [PATCH 2/8] Lint --- .../generators/ollama/chat/chat_generator.py | 7 ++++++- integrations/ollama/tests/test_chat_generator.py | 14 +++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index fb4e60ffd..55e313ede 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -2,7 +2,12 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall -from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.tools import ( + Tool, + _check_duplicate_tool_names, + deserialize_tools_or_toolset_inplace, + serialize_tools_or_toolset, +) from haystack.tools.toolset import Toolset from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from pydantic.json_schema import JsonSchemaValue diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 24f6d14e2..7a3fff9f0 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -218,17 +218,17 @@ def test_init_with_toolset(self, tools): toolset = Toolset(tools) generator = OllamaChatGenerator(model="llama3", tools=toolset) assert generator.tools == toolset - + def test_to_dict_with_toolset(self, tools): """Test that the OllamaChatGenerator can be serialized to a dictionary with a Toolset.""" toolset = Toolset(tools) generator = OllamaChatGenerator(model="llama3", tools=toolset) data = generator.to_dict() - + assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" assert "tools" in data["init_parameters"]["tools"]["data"] assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) - + def test_from_dict_with_toolset(self, tools): """Test that the OllamaChatGenerator can be deserialized from a dictionary with a Toolset.""" toolset = Toolset(tools) @@ -655,7 +655,7 @@ def test_run_with_toolset(self, mock_client, tools): """Test that the OllamaChatGenerator can run with a Toolset.""" toolset = Toolset(tools) generator = OllamaChatGenerator(model="llama3", tools=toolset) - + mock_response = ChatResponse( model="llama3", created_at="2023-12-12T14:13:43.416799Z", @@ -679,12 +679,12 @@ def test_run_with_toolset(self, mock_client, tools): eval_count=298, eval_duration=4799921000, ) - + mock_client_instance = mock_client.return_value mock_client_instance.chat.return_value = 
mock_response - + result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")]) - + mock_client_instance.chat.assert_called_once() assert "replies" in result assert len(result["replies"]) == 1 From 9839811467f4bc272fac785adacedb87ae3bdfdd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 19 May 2025 17:58:05 +0200 Subject: [PATCH 3/8] Lambdas are not serializable --- integrations/ollama/tests/test_chat_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 7a3fff9f0..7567e59a0 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -20,6 +20,8 @@ _convert_ollama_response_to_chatmessage, ) +def get_weather(city: str) -> str: + return f"The weather in {city} is sunny" @pytest.fixture def tools(): @@ -32,7 +34,7 @@ def tools(): name="weather", description="useful to determine the weather in a given location", parameters=tool_parameters, - function=lambda x: x, + function=get_weather, ) return [tool] From 33b4b63af52cc5fab02931c5f8e9613f771b7119 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 19 May 2025 17:58:34 +0200 Subject: [PATCH 4/8] Lint --- integrations/ollama/tests/test_chat_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 7567e59a0..58842e32b 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -20,9 +20,11 @@ _convert_ollama_response_to_chatmessage, ) + def get_weather(city: str) -> str: return f"The weather in {city} is sunny" + @pytest.fixture def tools(): tool_parameters = { From 8e22f95941716ab76d328e37d540d95eb6644295 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 21 May 2025 09:43:39 +0200 Subject: [PATCH 5/8] Generate tool call id if not available --- .../components/generators/ollama/chat/chat_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 55e313ede..b98f0b863 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union +import uuid from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall @@ -114,8 +115,9 @@ def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> tool_calls = [] if ollama_tool_calls := ollama_message.get("tool_calls"): for ollama_tc in ollama_tool_calls: + call_id = ollama_tc["id"] if "id" in ollama_tc else str(uuid.uuid4()) tool_calls.append( - ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"]) + ToolCall(id=call_id, tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"]) ) message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls) From c935608beb59312c4885db3bd2fbbb0b500f71d8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 21 May 2025 
09:50:13 +0200
Subject: [PATCH 6/8] Lint

---
 .../components/generators/ollama/chat/chat_generator.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
index b98f0b863..4d8c9a045 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -1,5 +1,5 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, Union import uuid +from typing import Any, Callable, Dict, List, Literal, Optional, Union from haystack import component, default_from_dict, default_to_dict
@@ -117,7 +117,9 @@ def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> for ollama_tc in ollama_tool_calls: call_id = ollama_tc["id"] if "id" in ollama_tc else str(uuid.uuid4()) tool_calls.append( - ToolCall(id=call_id, tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"]) + ToolCall( + id=call_id, tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"] + ) ) message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)

From 7272bfe03cef2655bd74de72920a3ceee835e3c2 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 21 May 2025 10:05:57 +0200
Subject: [PATCH 7/8] Revert back to not using ToolCall id

---
 .../components/generators/ollama/chat/chat_generator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
index 4d8c9a045..fb2ad40d6 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -115,10 +115,9 @@ def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> tool_calls = [] if ollama_tool_calls := ollama_message.get("tool_calls"): for ollama_tc in ollama_tool_calls: - call_id = ollama_tc["id"] if "id" in ollama_tc else str(uuid.uuid4()) tool_calls.append( ToolCall( - id=call_id, tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"] + tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"] ) ) message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)

From 5b5a36ab71c2264704e38e1f6dd063d2acbb3f48 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 21 May 2025 10:09:04 +0200
Subject: [PATCH 8/8] Lint

---
 .../components/generators/ollama/chat/chat_generator.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
index fb2ad40d6..55e313ede 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -1,4 +1,3 @@ -import uuid from typing import Any, Callable, Dict, List, Literal, Optional, Union from haystack import component, default_from_dict, default_to_dict
@@ -116,9 +115,7 @@ def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> if ollama_tool_calls := ollama_message.get("tool_calls"): for ollama_tc in ollama_tool_calls: tool_calls.append( - ToolCall( - tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"] - ) + ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"]) ) message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
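
---

Reviewer note (not part of the series): below is a minimal usage sketch of the Toolset
support added in PATCH 1/8. Only Tool, Toolset, OllamaChatGenerator, and ChatMessage are
taken from the diffs above; the tool JSON schema and the "llama3" model name are
illustrative, and a running, tool-capable Ollama server is assumed.

from haystack.dataclasses import ChatMessage
from haystack.tools import Tool
from haystack.tools.toolset import Toolset

from haystack_integrations.components.generators.ollama.chat.chat_generator import OllamaChatGenerator


def get_weather(city: str) -> str:
    # Toy tool; a module-level function rather than a lambda, since lambdas
    # are not serializable (see PATCH 3/8).
    return f"The weather in {city} is sunny"


weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={  # illustrative JSON schema, not copied from the tests
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

# A Toolset is now accepted anywhere a List[Tool] was accepted before;
# run() flattens it with list(tools) before building the Ollama payload,
# so the request sent to the server is unchanged.
generator = OllamaChatGenerator(model="llama3", tools=Toolset([weather_tool]))
result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])
print(result["replies"][0].tool_call)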
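
A second sketch for the serialization path: to_dict() now delegates to
serialize_tools_or_toolset(), so a Toolset round-trips through to_dict()/from_dict()
as a Toolset instead of being flattened to a list of Tool dicts. This mirrors
test_to_dict_with_toolset and test_from_dict_with_toolset in PATCH 1/8; `generator`
is the object from the sketch above.

data = generator.to_dict()
# The Toolset is serialized with its own type tag...
assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"

# ...and deserialize_tools_or_toolset_inplace (used during from_dict) restores it.
restored = OllamaChatGenerator.from_dict(data)
assert isinstance(restored.tools, Toolset)
assert all(isinstance(tool, Tool) for tool in restored.tools)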