feat: OllamaChatGenerator - add Toolset support #1765

Merged · 8 commits · May 22, 2025
integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -2,7 +2,13 @@

 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
-from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace
+from haystack.tools import (
+    Tool,
+    _check_duplicate_tool_names,
+    deserialize_tools_or_toolset_inplace,
+    serialize_tools_or_toolset,
+)
+from haystack.tools.toolset import Toolset
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from pydantic.json_schema import JsonSchemaValue

@@ -151,7 +157,7 @@ def __init__(
         timeout: int = 120,
         keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None,
     ):
         """
@@ -177,7 +183,8 @@ def __init__(
             A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
             Not all models support tools. For a list of models compatible with tools, see the
             [models page](https://ollama.com/search?c=tools).
         :param response_format:
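
For context on the new signature, a quick sketch of both accepted forms — assuming the standard `ollama-haystack` import path and a locally pulled, tool-capable model (`llama3`, matching the tests below); this snippet is illustrative and not part of the diff:

from haystack.tools import Tool
from haystack.tools.toolset import Toolset

from haystack_integrations.components.generators.ollama import OllamaChatGenerator


def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny"


# The JSON schema here is illustrative, not copied from this PR's fixture.
weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

# Both forms are now accepted by the same parameter:
gen_from_list = OllamaChatGenerator(model="llama3", tools=[weather_tool])
gen_from_toolset = OllamaChatGenerator(model="llama3", tools=Toolset([weather_tool]))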
@@ -207,7 +214,7 @@ def to_dict(self) -> Dict[str, Any]:
             Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
+
         return default_to_dict(
             self,
             model=self.model,
@@ -216,7 +223,7 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             timeout=self.timeout,
             streaming_callback=callback_name,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
             response_format=self.response_format,
         )

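The practical effect of switching to `serialize_tools_or_toolset` shows up in `to_dict()`/`from_dict()`: both input forms round-trip. A minimal sketch, reusing the names from the snippet above:

generator = OllamaChatGenerator(model="llama3", tools=Toolset([weather_tool]))

data = generator.to_dict()
# Per the tests below, the Toolset serializes as a typed envelope:
# data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"

restored = OllamaChatGenerator.from_dict(data)
assert isinstance(restored.tools, Toolset)  # the original form survives the round-trip
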
@@ -280,7 +287,7 @@ def run(
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         *,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
@@ -294,7 +301,8 @@ def run(
             top_p, etc. See the
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance. If set, it will override the `tools` parameter set
             during component initialization.
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
@@ -320,6 +328,10 @@ def run(
             msg = "Ollama does not support streaming and response_format at the same time. Please choose one."
             raise ValueError(msg)

+        # Convert toolset to list of tools if needed
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None

         ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages]
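
The `list(tools)` normalization works because `Toolset` is iterable over its `Tool` objects, so the payload-building line below it runs unchanged for either input form. A standalone sketch of that step (the helper name is ours, not the component's):

from typing import List, Optional, Union

from haystack.tools import Tool
from haystack.tools.toolset import Toolset


def _normalize_tools(tools: Optional[Union[List[Tool], Toolset]]) -> Optional[List[Tool]]:
    # A Toolset iterates over its Tool objects, so list() flattens it;
    # a plain list (or None) passes through untouched.
    if isinstance(tools, Toolset):
        return list(tools)
    return tools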
76 changes: 75 additions & 1 deletion integrations/ollama/tests/test_chat_generator.py
@@ -11,6 +11,7 @@
     ToolCall,
 )
 from haystack.tools import Tool
+from haystack.tools.toolset import Toolset
 from ollama._types import ChatResponse, ResponseError

 from haystack_integrations.components.generators.ollama.chat.chat_generator import (
@@ -20,6 +21,10 @@
 )


+def get_weather(city: str) -> str:
+    return f"The weather in {city} is sunny"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {
@@ -31,7 +36,7 @@ def tools():
         name="weather",
         description="useful to determine the weather in a given location",
         parameters=tool_parameters,
-        function=lambda x: x,
+        function=get_weather,
     )

     return [tool]
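
A side note on the fixture change: replacing `lambda x: x` with the module-level `get_weather` is presumably what makes the new serialization tests workable, since Haystack's callable serialization stores a dotted import path that a lambda lacks. A hedged illustration:

from haystack.utils.callable_serialization import serialize_callable

serialize_callable(get_weather)   # e.g. "tests.test_chat_generator.get_weather"
# A lambda has no importable dotted path, so it cannot survive
# the to_dict()/from_dict() round-trip exercised below.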
@@ -212,6 +217,34 @@ def test_init_fail_with_duplicate_tool_names(self, tools):
         with pytest.raises(ValueError):
             OllamaChatGenerator(tools=duplicate_tools)

+    def test_init_with_toolset(self, tools):
+        """Test that the OllamaChatGenerator can be initialized with a Toolset."""
+        toolset = Toolset(tools)
+        generator = OllamaChatGenerator(model="llama3", tools=toolset)
+        assert generator.tools == toolset
+
+    def test_to_dict_with_toolset(self, tools):
+        """Test that the OllamaChatGenerator can be serialized to a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        generator = OllamaChatGenerator(model="llama3", tools=toolset)
+        data = generator.to_dict()
+
+        assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"
+        assert "tools" in data["init_parameters"]["tools"]["data"]
+        assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools)
+
+    def test_from_dict_with_toolset(self, tools):
+        """Test that the OllamaChatGenerator can be deserialized from a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        component = OllamaChatGenerator(model="llama3", tools=toolset)
+        data = component.to_dict()
+
+        deserialized_component = OllamaChatGenerator.from_dict(data)
+
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
     def test_to_dict(self):
         tool = Tool(
             name="name",
@@ -620,3 +653,44 @@ def test_run_with_tools_and_format(self, tools):
         message = ChatMessage.from_user("What's the weather in Paris?")
         with pytest.raises(ValueError):
             chat_generator.run([message])
+
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_with_toolset(self, mock_client, tools):
+        """Test that the OllamaChatGenerator can run with a Toolset."""
+        toolset = Toolset(tools)
+        generator = OllamaChatGenerator(model="llama3", tools=toolset)
+
+        mock_response = ChatResponse(
+            model="llama3",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "weather",
+                            "arguments": {"city": "Paris"},
+                        }
+                    }
+                ],
+            },
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
+
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.return_value = mock_response
+
+        result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])
+
+        mock_client_instance.chat.assert_called_once()
+        assert "replies" in result
+        assert len(result["replies"]) == 1
+        assert result["replies"][0].tool_call.tool_name == "weather"
+        assert result["replies"][0].tool_call.arguments == {"city": "Paris"}
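
Finally, an end-to-end sketch of the same flow against a live Ollama server instead of the mocked client; it assumes `llama3` is pulled locally and reuses `weather_tool` and `Toolset` from the earlier sketch:

from haystack.dataclasses import ChatMessage

generator = OllamaChatGenerator(model="llama3", tools=Toolset([weather_tool]))
result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])

reply = result["replies"][0]
if reply.tool_call:  # the model chose to call the tool
    print(reply.tool_call.tool_name, reply.tool_call.arguments)
else:  # the model answered in plain text instead
    print(reply.text)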