diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
index 9df88a3b5a..a75c9e98db 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
@@ -10,6 +10,7 @@
     Tuple,
     Type,
     Union,
+    cast,
 )

 from ollama import AsyncClient, Client
@@ -33,6 +34,7 @@
     MessageRole,
     TextBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
@@ -58,9 +60,15 @@ def get_additional_kwargs(


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", []) or []
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class Ollama(FunctionCallingLLM):
@@ -223,6 +231,7 @@ def get_context_window(self) -> int:

     def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
         ollama_messages = []
+        unique_tool_calls = []
         for message in messages:
             cur_ollama_message = {
                 "role": message.role.value,
@@ -240,13 +249,47 @@ def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
                 elif isinstance(block, ThinkingBlock):
                     if block.content:
                         cur_ollama_message["thinking"] = block.content
+                elif isinstance(block, ToolCallBlock):
+                    if "tool_calls" not in cur_ollama_message:
+                        cur_ollama_message["tool_calls"] = [
+                            {
+                                "function": {
+                                    "name": block.tool_name,
+                                    "arguments": block.tool_kwargs,
+                                }
+                            }
+                        ]
+                    else:
+                        cur_ollama_message["tool_calls"].extend(
+                            [
+                                {
+                                    "function": {
+                                        "name": block.tool_name,
+                                        "arguments": block.tool_kwargs,
+                                    }
+                                }
+                            ]
+                        )
+                    unique_tool_calls.append((block.tool_name, str(block.tool_kwargs)))
                 else:
                     raise ValueError(f"Unsupported block type: {type(block)}")

+            # keep this code for compatibility with older chat histories
             if "tool_calls" in message.additional_kwargs:
-                cur_ollama_message["tool_calls"] = message.additional_kwargs[
-                    "tool_calls"
-                ]
+                if (
+                    "tool_calls" not in cur_ollama_message
+                    or cur_ollama_message["tool_calls"] == []
+                ):
+                    cur_ollama_message["tool_calls"] = message.additional_kwargs[
+                        "tool_calls"
+                    ]
+                else:
+                    for tool_call in message.additional_kwargs["tool_calls"]:
+                        if (
+                            tool_call.get("function", {}).get("name", ""),
+                            str(tool_call.get("function", {}).get("arguments", {})),
+                        ) not in unique_tool_calls:
+                            cur_ollama_message["tool_calls"].append(tool_call)

             ollama_messages.append(cur_ollama_message)

@@ -312,7 +355,11 @@ def get_tool_calls_from_response(
         error_on_no_tool_call: bool = True,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", []) or []
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
         if len(tool_calls) < 1:
             if error_on_no_tool_call:
                 raise ValueError(
@@ -323,14 +370,14 @@

         tool_selections = []
         for tool_call in tool_calls:
-            argument_dict = tool_call["function"]["arguments"]
+            argument_dict = tool_call.tool_kwargs
             tool_selections.append(
                 ToolSelection(
                     # tool ids not provided by Ollama
-                    tool_id=tool_call["function"]["name"],
-                    tool_name=tool_call["function"]["name"],
-                    tool_kwargs=argument_dict,
+                    tool_id=tool_call.tool_name,
+                    tool_name=tool_call.tool_name,
+                    tool_kwargs=cast(Dict[str, Any], argument_dict),
                 )
             )

@@ -357,14 +404,21 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:

         response = dict(response)

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
         tool_calls = response["message"].get("tool_calls", []) or []
         thinking = response["message"].get("thinking", None)
         if thinking:
             blocks.append(ThinkingBlock(content=thinking))
         blocks.append(TextBlock(text=response["message"].get("content", "")))
-
+        if tool_calls:
+            for tool_call in tool_calls:
+                blocks.append(
+                    ToolCallBlock(
+                        tool_name=str(tool_call.get("function", {}).get("name", "")),
+                        tool_kwargs=tool_call.get("function", {}).get("arguments", {}),
+                    )
+                )
         token_counts = self._get_response_token_counts(response)
         if token_counts:
             response["usage"] = token_counts
@@ -373,7 +427,6 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             message=ChatMessage(
                 blocks=blocks,
                 role=response["message"].get("role", MessageRole.ASSISTANT),
-                additional_kwargs={"tool_calls": tool_calls},
             ),
             raw=response,
         )
@@ -432,17 +485,26 @@ def gen() -> ChatResponseGen:
                 if token_counts:
                     r["usage"] = token_counts

-                output_blocks = [TextBlock(text=response_txt)]
+                output_blocks: List[ToolCallBlock | ThinkingBlock | TextBlock] = [
+                    TextBlock(text=response_txt)
+                ]
                 if thinking_txt:
                     output_blocks.insert(0, ThinkingBlock(content=thinking_txt))
+                if all_tool_calls:
+                    for tool_call in all_tool_calls:
+                        output_blocks.append(
+                            ToolCallBlock(
+                                tool_name=tool_call.get("function", {}).get("name", ""),
+                                tool_kwargs=tool_call.get("function", {}).get(
+                                    "arguments", {}
+                                ),
+                            )
+                        )

                 yield ChatResponse(
                     message=ChatMessage(
                         blocks=output_blocks,
                         role=r["message"].get("role", MessageRole.ASSISTANT),
-                        additional_kwargs={
-                            "tool_calls": all_tool_calls,
-                        },
                     ),
                     delta=r["message"].get("content", ""),
                     raw=r,
@@ -507,17 +569,26 @@ async def gen() -> ChatResponseAsyncGen:
                 if token_counts:
                     r["usage"] = token_counts

-                output_blocks = [TextBlock(text=response_txt)]
+                output_blocks: List[ThinkingBlock | ToolCallBlock | TextBlock] = [
+                    TextBlock(text=response_txt)
+                ]
                 if thinking_txt:
                     output_blocks.insert(0, ThinkingBlock(content=thinking_txt))
+                if all_tool_calls:
+                    for tool_call in all_tool_calls:
+                        output_blocks.append(
+                            ToolCallBlock(
+                                tool_name=tool_call.get("function", {}).get("name", ""),
+                                tool_kwargs=tool_call.get("function", {}).get(
+                                    "arguments", {}
+                                ),
+                            )
+                        )

                 yield ChatResponse(
                     message=ChatMessage(
                         blocks=output_blocks,
                         role=r["message"].get("role", MessageRole.ASSISTANT),
-                        additional_kwargs={
-                            "tool_calls": all_tool_calls,
-                        },
                     ),
                     delta=r["message"].get("content", ""),
                     raw=r,
@@ -551,13 +622,21 @@ async def achat(

         response = dict(response)

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
         tool_calls = response["message"].get("tool_calls", []) or []
         thinking = response["message"].get("thinking", None)
         if thinking:
             blocks.append(ThinkingBlock(content=thinking))
         blocks.append(TextBlock(text=response["message"].get("content", "")))
+        if tool_calls:
+            for tool_call in tool_calls:
+                blocks.append(
+                    ToolCallBlock(
+                        tool_name=tool_call.get("function", {}).get("name", ""),
+                        tool_kwargs=tool_call.get("function", {}).get("arguments", {}),
+                    )
+                )

         token_counts = self._get_response_token_counts(response)
         if token_counts:
             response["usage"] = token_counts
@@ -566,7 +645,6 @@ async def achat(
             message=ChatMessage(
                 blocks=blocks,
                 role=response["message"].get("role", MessageRole.ASSISTANT),
-                additional_kwargs={"tool_calls": tool_calls},
             ),
             raw=response,
         )
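With the base.py changes above, tool calls requested by the model are carried as typed ToolCallBlock entries on the response message's blocks instead of response.message.additional_kwargs["tool_calls"]. A minimal usage sketch of the new surface — the model name, tool, and prompt below are illustrative assumptions, not taken from this diff:

    # Sketch only: "llama3.1" and get_weather are illustrative, not from this PR.
    from llama_index.core.base.llms.types import ToolCallBlock
    from llama_index.core.tools import FunctionTool
    from llama_index.llms.ollama import Ollama


    def get_weather(location: str) -> str:
        """Return a canned weather report for a location."""
        return f"Rainy, 15C in {location}"


    weather_tool = FunctionTool.from_defaults(fn=get_weather)
    llm = Ollama(model="llama3.1")  # assumes a locally pulled tool-capable model

    response = llm.chat_with_tools(tools=[weather_tool], user_msg="Weather in London?")

    # Tool calls now surface as ToolCallBlock entries on the message blocks
    # rather than in response.message.additional_kwargs["tool_calls"].
    for block in response.message.blocks:
        if isinstance(block, ToolCallBlock):
            print(block.tool_name, block.tool_kwargs)

    # get_tool_calls_from_response reads those same blocks and returns
    # ToolSelection objects, so FunctionCallingLLM-based agents keep working.
    selections = llm.get_tool_calls_from_response(response, error_on_no_tool_call=False)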
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
index b567a95299..5ae638183f 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
@@ -27,7 +27,7 @@ dev = [

 [project]
 name = "llama-index-llms-ollama"
-version = "0.8.0"
+version = "0.9.0"
 description = "llama-index llms ollama integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
@@ -35,7 +35,7 @@ readme = "README.md"
 license = "MIT"
 dependencies = [
     "ollama>=0.5.1",
-    "llama-index-core>=0.14.3,<0.15",
+    "llama-index-core>=0.14.5,<0.15",
 ]

 [tool.codespell]
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py
index 600bb4888c..a50d2241c9 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_llms_ollama.py
@@ -4,7 +4,7 @@
 from ollama import Client
 from typing import Annotated

-from llama_index.core.base.llms.types import ThinkingBlock, TextBlock
+from llama_index.core.base.llms.types import ThinkingBlock, TextBlock, ToolCallBlock
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.llms import ChatMessage
@@ -358,4 +358,81 @@ async def test_async_chat_with_tools_returns_empty_array_if_no_tools_were_called
             ChatMessage(role="user", content="Hello, how are you?"),
         ],
     )
-    assert response.message.additional_kwargs.get("tool_calls", []) == []
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+
+
+@pytest.mark.skipif(
+    client is None, reason="Ollama client is not available or test model is missing"
+)
+@pytest.mark.asyncio
+async def test_chat_methods_with_tool_input() -> None:
+    llm = Ollama(model=thinking_test_model)
+    input_messages = [
+        ChatMessage(
+            role="user",
+            content="Hello, can you tell me what is the weather today in London?",
+        ),
+        ChatMessage(
+            role="assistant",
+            blocks=[
+                ThinkingBlock(
+                    content="The user is asking for the weather in London, so I should use the get_weather tool"
+                ),
+                ToolCallBlock(
+                    tool_name="get_weather_tool", tool_kwargs={"location": "London"}
+                ),
+                TextBlock(
+                    text="The weather in London is rainy with a temperature of 15°C."
+                ),
+            ],
+        ),
+        ChatMessage(
+            role="user",
+            content="Can you tell me what input did you give to the 'get_weather' tool? (do not call any other tool)",
+        ),
+    ]
+    response = llm.chat(messages=input_messages)
+    assert response.message.content is not None
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+    aresponse = await llm.achat(messages=input_messages)
+    assert aresponse.message.content is not None
+    assert (
+        len(
+            [
+                block
+                for block in aresponse.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+    response_stream = llm.stream_chat(messages=input_messages)
+    blocks = []
+    for r in response_stream:
+        blocks.extend(r.message.blocks)
+    assert len([block for block in blocks if isinstance(block, TextBlock)]) > 0
+    assert len([block for block in blocks if isinstance(block, ToolCallBlock)]) == 0
+    aresponse_stream = await llm.astream_chat(messages=input_messages)
+    ablocks = []
+    async for r in aresponse_stream:
+        ablocks.extend(r.message.blocks)
+    assert len([block for block in ablocks if isinstance(block, TextBlock)]) > 0
+    assert len([block for block in ablocks if isinstance(block, ToolCallBlock)]) == 0
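The compatibility branch in _convert_to_ollama_messages is what keeps the older histories exercised above working: messages that carry tool calls only in additional_kwargs are still forwarded to Ollama, and the unique_tool_calls bookkeeping skips any (name, arguments) pair already emitted by a ToolCallBlock. A rough sketch of such a legacy message — the tool name and arguments are invented, and _convert_to_ollama_messages is a private helper called here only to show the effect:

    # Sketch of the backward-compat path; tool name/arguments are illustrative.
    from llama_index.core.llms import ChatMessage
    from llama_index.llms.ollama import Ollama

    legacy_message = ChatMessage(
        role="assistant",
        content="",
        additional_kwargs={
            "tool_calls": [
                {"function": {"name": "get_weather", "arguments": {"location": "London"}}}
            ]
        },
    )

    llm = Ollama(model="llama3.1")  # illustrative model name
    # The compat branch copies the legacy tool call through unchanged, so the
    # outgoing assistant message still carries {"function": {"name": ..., "arguments": ...}}
    # without duplicating any tool call already contributed by a ToolCallBlock.
    ollama_messages = llm._convert_to_ollama_messages([legacy_message])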
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/uv.lock b/llama-index-integrations/llms/llama-index-llms-ollama/uv.lock
index a28fae7b26..fb3b2925c4 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/uv.lock
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/uv.lock
@@ -1584,7 +1584,7 @@ wheels = [

 [[package]]
 name = "llama-index-core"
-version = "0.14.3"
+version = "0.14.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "aiohttp" },
@@ -1618,9 +1618,9 @@ dependencies = [
     { name = "typing-inspect" },
     { name = "wrapt" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/c5/e4/6a4ab9465b66c9d31b74ed0221293aeebe9072ec9db3b3b229f96028af78/llama_index_core-0.14.3.tar.gz", hash = "sha256:ca8a473ac92fe54f2849175f6510655999852c83fa8b7d75fd3908a8863da05a", size = 11577791, upload-time = "2025-09-24T18:21:03.653Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/cf/42/e1de7d6a390dcd67b0754fd24e0d0acb56c1d0838a68e30671dd79fd5521/llama_index_core-0.14.5.tar.gz", hash = "sha256:913ebc3ad895d381eaab0f10dc405101c5bec5a70c09909ef2493ddc115f8552", size = 11578206, upload-time = "2025-10-15T19:10:09.746Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/b0/5a/de1002b10109a0dfa122ba84a3b640124cf2418a78e00ac0b382574f2b3f/llama_index_core-0.14.3-py3-none-any.whl", hash = "sha256:fc4291fbae8c6609e3367da39a85a453099476685d5a3e97b766d82d4bcce5a4", size = 11918952, upload-time = "2025-09-24T18:21:00.744Z" },
+    { url = "https://files.pythonhosted.org/packages/0f/64/c02576991efcefd30a65971e87ece7494d6bbf3739b7bffeeb56c86b5a76/llama_index_core-0.14.5-py3-none-any.whl", hash = "sha256:5445aa322b83a9d48baa608c3b920df4f434ed5d461a61e6bccb36d99348bddf", size = 11919461, upload-time = "2025-10-15T19:10:06.92Z" },
 ]

 [[package]]
@@ -1638,7 +1638,7 @@ wheels = [

 [[package]]
 name = "llama-index-llms-ollama"
-version = "0.8.0"
+version = "0.9.0"
 source = { editable = "." }
 dependencies = [
     { name = "llama-index-core" },
@@ -1671,7 +1671,7 @@ dev = [

 [package.metadata]
 requires-dist = [
-    { name = "llama-index-core", specifier = ">=0.14.3,<0.15" },
+    { name = "llama-index-core", specifier = ">=0.14.5,<0.15" },
     { name = "ollama", specifier = ">=0.5.1" },
 ]
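Finally, force_single_tool_call now trims extra tool calls at the block level rather than in additional_kwargs. A small sketch of the reworked behavior, using made-up block contents:

    # Minimal sketch of force_single_tool_call on block-based messages.
    from llama_index.core.base.llms.types import (
        ChatMessage,
        ChatResponse,
        TextBlock,
        ToolCallBlock,
    )
    from llama_index.llms.ollama.base import force_single_tool_call

    resp = ChatResponse(
        message=ChatMessage(
            role="assistant",
            blocks=[
                TextBlock(text="calling tools"),
                ToolCallBlock(tool_name="a", tool_kwargs={}),
                ToolCallBlock(tool_name="b", tool_kwargs={}),
            ],
        )
    )

    # Only the first ToolCallBlock survives; non-tool-call blocks are kept,
    # and the retained tool call is moved to the end of the block list.
    force_single_tool_call(resp)
    kept = [b.tool_name for b in resp.message.blocks if isinstance(b, ToolCallBlock)]
    assert kept == ["a"]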