Skip to content

Commit 9c3b13c

Browse files
authored
feat: OllamaChatGenerator - add Toolset support (#1765)
* Add Toolset support to OllamaChatGenerator
* Lint
* Lambdas are not serializable
* Lint
* Generate tool call id if not available
* Lint
* Revert back to not using ToolCall id
* Lint
1 parent 2441048 commit 9c3b13c

File tree

2 files changed

+94
-8
lines changed

2 files changed

+94
-8
lines changed

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22

33
from haystack import component, default_from_dict, default_to_dict
44
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
5-
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace
5+
from haystack.tools import (
6+
Tool,
7+
_check_duplicate_tool_names,
8+
deserialize_tools_or_toolset_inplace,
9+
serialize_tools_or_toolset,
10+
)
11+
from haystack.tools.toolset import Toolset
612
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
713
from pydantic.json_schema import JsonSchemaValue
814

@@ -151,7 +157,7 @@ def __init__(
151157
timeout: int = 120,
152158
keep_alive: Optional[Union[float, str]] = None,
153159
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
154-
tools: Optional[List[Tool]] = None,
160+
tools: Optional[Union[List[Tool], Toolset]] = None,
155161
response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None,
156162
):
157163
"""
@@ -177,7 +183,8 @@ def __init__(
177183
A callback function that is called when a new token is received from the stream.
178184
The callback function accepts StreamingChunk as an argument.
179185
:param tools:
180-
A list of tools for which the model can prepare calls.
186+
A list of tools or a Toolset for which the model can prepare calls.
187+
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
181188
Not all models support tools. For a list of models compatible with tools, see the
182189
[models page](https://ollama.com/search?c=tools).
183190
:param response_format:
@@ -207,7 +214,7 @@ def to_dict(self) -> Dict[str, Any]:
207214
Dictionary with serialized data.
208215
"""
209216
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
210-
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
217+
211218
return default_to_dict(
212219
self,
213220
model=self.model,
@@ -216,7 +223,7 @@ def to_dict(self) -> Dict[str, Any]:
216223
generation_kwargs=self.generation_kwargs,
217224
timeout=self.timeout,
218225
streaming_callback=callback_name,
219-
tools=serialized_tools,
226+
tools=serialize_tools_or_toolset(self.tools),
220227
response_format=self.response_format,
221228
)
222229

@@ -280,7 +287,7 @@ def run(
280287
self,
281288
messages: List[ChatMessage],
282289
generation_kwargs: Optional[Dict[str, Any]] = None,
283-
tools: Optional[List[Tool]] = None,
290+
tools: Optional[Union[List[Tool], Toolset]] = None,
284291
*,
285292
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
286293
):
@@ -294,7 +301,8 @@ def run(
294301
top_p, etc. See the
295302
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
296303
:param tools:
297-
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
304+
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
305+
list of `Tool` objects or a `Toolset` instance. If set, it will override the `tools` parameter set
298306
during component initialization.
299307
:param streaming_callback:
300308
A callback function that is called when a new token is received from the stream.
@@ -320,6 +328,10 @@ def run(
320328
msg = "Ollama does not support streaming and response_format at the same time. Please choose one."
321329
raise ValueError(msg)
322330

331+
# Convert toolset to list of tools if needed
332+
if isinstance(tools, Toolset):
333+
tools = list(tools)
334+
323335
ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None
324336

325337
ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages]

integrations/ollama/tests/test_chat_generator.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ToolCall,
1212
)
1313
from haystack.tools import Tool
14+
from haystack.tools.toolset import Toolset
1415
from ollama._types import ChatResponse, ResponseError
1516

1617
from haystack_integrations.components.generators.ollama.chat.chat_generator import (
@@ -20,6 +21,10 @@
2021
)
2122

2223

24+
def get_weather(city: str) -> str:
25+
return f"The weather in {city} is sunny"
26+
27+
2328
@pytest.fixture
2429
def tools():
2530
tool_parameters = {
@@ -31,7 +36,7 @@ def tools():
3136
name="weather",
3237
description="useful to determine the weather in a given location",
3338
parameters=tool_parameters,
34-
function=lambda x: x,
39+
function=get_weather,
3540
)
3641

3742
return [tool]
@@ -212,6 +217,34 @@ def test_init_fail_with_duplicate_tool_names(self, tools):
212217
with pytest.raises(ValueError):
213218
OllamaChatGenerator(tools=duplicate_tools)
214219

220+
def test_init_with_toolset(self, tools):
221+
"""Test that the OllamaChatGenerator can be initialized with a Toolset."""
222+
toolset = Toolset(tools)
223+
generator = OllamaChatGenerator(model="llama3", tools=toolset)
224+
assert generator.tools == toolset
225+
226+
def test_to_dict_with_toolset(self, tools):
227+
"""Test that the OllamaChatGenerator can be serialized to a dictionary with a Toolset."""
228+
toolset = Toolset(tools)
229+
generator = OllamaChatGenerator(model="llama3", tools=toolset)
230+
data = generator.to_dict()
231+
232+
assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"
233+
assert "tools" in data["init_parameters"]["tools"]["data"]
234+
assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools)
235+
236+
def test_from_dict_with_toolset(self, tools):
237+
"""Test that the OllamaChatGenerator can be deserialized from a dictionary with a Toolset."""
238+
toolset = Toolset(tools)
239+
component = OllamaChatGenerator(model="llama3", tools=toolset)
240+
data = component.to_dict()
241+
242+
deserialized_component = OllamaChatGenerator.from_dict(data)
243+
244+
assert isinstance(deserialized_component.tools, Toolset)
245+
assert len(deserialized_component.tools) == len(tools)
246+
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
247+
215248
def test_to_dict(self):
216249
tool = Tool(
217250
name="name",
@@ -620,3 +653,44 @@ def test_run_with_tools_and_format(self, tools):
620653
message = ChatMessage.from_user("What's the weather in Paris?")
621654
with pytest.raises(ValueError):
622655
chat_generator.run([message])
656+
657+
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
658+
def test_run_with_toolset(self, mock_client, tools):
659+
"""Test that the OllamaChatGenerator can run with a Toolset."""
660+
toolset = Toolset(tools)
661+
generator = OllamaChatGenerator(model="llama3", tools=toolset)
662+
663+
mock_response = ChatResponse(
664+
model="llama3",
665+
created_at="2023-12-12T14:13:43.416799Z",
666+
message={
667+
"role": "assistant",
668+
"content": "",
669+
"tool_calls": [
670+
{
671+
"function": {
672+
"name": "weather",
673+
"arguments": {"city": "Paris"},
674+
}
675+
}
676+
],
677+
},
678+
done=True,
679+
total_duration=5191566416,
680+
load_duration=2154458,
681+
prompt_eval_count=26,
682+
prompt_eval_duration=383809000,
683+
eval_count=298,
684+
eval_duration=4799921000,
685+
)
686+
687+
mock_client_instance = mock_client.return_value
688+
mock_client_instance.chat.return_value = mock_response
689+
690+
result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])
691+
692+
mock_client_instance.chat.assert_called_once()
693+
assert "replies" in result
694+
assert len(result["replies"]) == 1
695+
assert result["replies"][0].tool_call.tool_name == "weather"
696+
assert result["replies"][0].tool_call.arguments == {"city": "Paris"}

0 commit comments

Comments
 (0)