@@ -10,6 +10,7 @@
     Tuple,
     Type,
     Union,
+    cast,
 )

 from ollama import AsyncClient, Client
@@ -33,6 +34,7 @@
     MessageRole,
     TextBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
@@ -58,9 +60,15 @@ def get_additional_kwargs(


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", []) or []
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class Ollama(FunctionCallingLLM):
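
A quick sketch of the new trimming behavior, using the function defined above (tool names and arguments are invented for illustration; ToolCallBlock ships with llama-index-core>=0.14.5, per the dependency bump below):

# Illustrative sketch, not part of the diff: force_single_tool_call now edits
# the typed blocks rather than additional_kwargs.
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    MessageRole,
    TextBlock,
    ToolCallBlock,
)

response = ChatResponse(
    message=ChatMessage(
        role=MessageRole.ASSISTANT,
        blocks=[
            TextBlock(text="Let me check both."),
            ToolCallBlock(tool_name="get_weather", tool_kwargs={"city": "London"}),
            ToolCallBlock(tool_name="get_time", tool_kwargs={"tz": "UTC"}),
        ],
    )
)
force_single_tool_call(response)
# Text blocks survive; only the first ToolCallBlock is kept.
remaining = [b for b in response.message.blocks if isinstance(b, ToolCallBlock)]
assert [b.tool_name for b in remaining] == ["get_weather"]
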
@@ -223,6 +231,7 @@ def get_context_window(self) -> int:

     def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
         ollama_messages = []
+        unique_tool_calls = []
         for message in messages:
             cur_ollama_message = {
                 "role": message.role.value,
@@ -240,13 +249,47 @@ def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
                 elif isinstance(block, ThinkingBlock):
                     if block.content:
                         cur_ollama_message["thinking"] = block.content
+                elif isinstance(block, ToolCallBlock):
+                    if "tool_calls" not in cur_ollama_message:
+                        cur_ollama_message["tool_calls"] = [
+                            {
+                                "function": {
+                                    "name": block.tool_name,
+                                    "arguments": block.tool_kwargs,
+                                }
+                            }
+                        ]
+                    else:
+                        cur_ollama_message["tool_calls"].extend(
+                            [
+                                {
+                                    "function": {
+                                        "name": block.tool_name,
+                                        "arguments": block.tool_kwargs,
+                                    }
+                                }
+                            ]
+                        )
+                    unique_tool_calls.append((block.tool_name, str(block.tool_kwargs)))
                 else:
                     raise ValueError(f"Unsupported block type: {type(block)}")

             # keep this code for compatibility with older chat histories
             if "tool_calls" in message.additional_kwargs:
-                cur_ollama_message["tool_calls"] = message.additional_kwargs[
-                    "tool_calls"
-                ]
+                if (
+                    "tool_calls" not in cur_ollama_message
+                    or cur_ollama_message["tool_calls"] == []
+                ):
+                    cur_ollama_message["tool_calls"] = message.additional_kwargs[
+                        "tool_calls"
+                    ]
+                else:
+                    for tool_call in message.additional_kwargs["tool_calls"]:
+                        if (
+                            tool_call.get("name", ""),
+                            str(tool_call.get("arguments", {})),
+                        ) not in unique_tool_calls:
+                            cur_ollama_message["tool_calls"].append(tool_call)

             ollama_messages.append(cur_ollama_message)

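For reference, a minimal sketch of the Ollama-native payload this conversion is expected to produce for a message carrying a ToolCallBlock (values are invented; _convert_to_ollama_messages is private, so this only illustrates the shape):

# Illustrative sketch, not part of the diff.
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ToolCallBlock

message = ChatMessage(
    role=MessageRole.ASSISTANT,
    blocks=[ToolCallBlock(tool_name="get_weather", tool_kwargs={"city": "London"})],
)
# Roughly the expected Ollama-native dict for this message:
expected = {
    "role": "assistant",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": {"city": "London"}}}
    ],
}
# Legacy histories that stored tool calls in additional_kwargs still convert;
# the (tool_name, str(tool_kwargs)) pairs tracked in unique_tool_calls keep
# the same call from being appended twice.
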
@@ -312,7 +355,11 @@ def get_tool_calls_from_response(
         error_on_no_tool_call: bool = True,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", []) or []
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
         if len(tool_calls) < 1:
             if error_on_no_tool_call:
                 raise ValueError(
@@ -323,14 +370,14 @@

         tool_selections = []
         for tool_call in tool_calls:
-            argument_dict = tool_call["function"]["arguments"]
+            argument_dict = tool_call.tool_kwargs

             tool_selections.append(
                 ToolSelection(
                     # tool ids not provided by Ollama
-                    tool_id=tool_call["function"]["name"],
-                    tool_name=tool_call["function"]["name"],
-                    tool_kwargs=argument_dict,
+                    tool_id=tool_call.tool_name,
+                    tool_name=tool_call.tool_name,
+                    tool_kwargs=cast(Dict[str, Any], argument_dict),
                 )
             )

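Callers that previously read tool calls out of additional_kwargs should now go through get_tool_calls_from_response, which builds ToolSelection objects from the typed blocks. A minimal sketch, assuming a running Ollama server and a tool-capable model (the model name is a placeholder, not mandated by this PR):

# Sketch under assumptions: local Ollama server, tool-capable model.
from llama_index.core.tools import FunctionTool
from llama_index.llms.ollama import Ollama


def get_weather(city: str) -> str:
    """Return a fake weather report for a city."""
    return f"Sunny in {city}"


llm = Ollama(model="llama3.1")  # placeholder model name
tool = FunctionTool.from_defaults(fn=get_weather)
response = llm.chat_with_tools([tool], user_msg="What is the weather in London?")
for selection in llm.get_tool_calls_from_response(
    response, error_on_no_tool_call=False
):
    print(selection.tool_name, selection.tool_kwargs)
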
@@ -357,14 +404,21 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:

         response = dict(response)

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []

         tool_calls = response["message"].get("tool_calls", []) or []
         thinking = response["message"].get("thinking", None)
         if thinking:
             blocks.append(ThinkingBlock(content=thinking))
         blocks.append(TextBlock(text=response["message"].get("content", "")))

+        if tool_calls:
+            for tool_call in tool_calls:
+                blocks.append(
+                    ToolCallBlock(
+                        tool_name=str(tool_call.get("function", {}).get("name", "")),
+                        tool_kwargs=tool_call.get("function", {}).get("arguments", {}),
+                    )
+                )
         token_counts = self._get_response_token_counts(response)
         if token_counts:
             response["usage"] = token_counts
@@ -373,7 +427,6 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             message=ChatMessage(
                 blocks=blocks,
                 role=response["message"].get("role", MessageRole.ASSISTANT),
-                additional_kwargs={"tool_calls": tool_calls},
             ),
             raw=response,
         )
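
Since chat() no longer mirrors tool calls into additional_kwargs, downstream code inspects the typed blocks instead. A short sketch (same assumptions as above, placeholder model name):

# Sketch: tool calls now surface as ToolCallBlock entries in message.blocks.
from llama_index.core.base.llms.types import ChatMessage, ToolCallBlock
from llama_index.llms.ollama import Ollama

llm = Ollama(model="llama3.1")  # placeholder model name
response = llm.chat([ChatMessage(role="user", content="Hello!")])
tool_calls = [b for b in response.message.blocks if isinstance(b, ToolCallBlock)]
print(tool_calls)  # empty for a plain greeting; populated when the model calls tools
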
@@ -432,17 +485,26 @@ def gen() -> ChatResponseGen:
                 if token_counts:
                     r["usage"] = token_counts

-                output_blocks = [TextBlock(text=response_txt)]
+                output_blocks: List[ToolCallBlock | ThinkingBlock | TextBlock] = [
+                    TextBlock(text=response_txt)
+                ]
                 if thinking_txt:
                     output_blocks.insert(0, ThinkingBlock(content=thinking_txt))
+                if all_tool_calls:
+                    for tool_call in all_tool_calls:
+                        output_blocks.append(
+                            ToolCallBlock(
+                                tool_name=tool_call.get("function", {}).get("name", ""),
+                                tool_kwargs=tool_call.get("function", {}).get(
+                                    "arguments", {}
+                                ),
+                            )
+                        )

                 yield ChatResponse(
                     message=ChatMessage(
                         blocks=output_blocks,
                         role=r["message"].get("role", MessageRole.ASSISTANT),
-                        additional_kwargs={
-                            "tool_calls": all_tool_calls,
-                        },
                     ),
                     delta=r["message"].get("content", ""),
                     raw=r,
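
Streaming behaves the same way: accumulated tool calls are re-emitted as ToolCallBlocks on each yielded message rather than stashed in additional_kwargs. A hedged sketch mirroring the new tests (placeholder model name):

# Sketch: collect blocks across the stream and filter for tool calls.
from llama_index.core.base.llms.types import ChatMessage, ToolCallBlock
from llama_index.llms.ollama import Ollama

llm = Ollama(model="llama3.1")  # placeholder model name
blocks = []
for chunk in llm.stream_chat([ChatMessage(role="user", content="Hi there")]):
    blocks.extend(chunk.message.blocks)
tool_calls = [b for b in blocks if isinstance(b, ToolCallBlock)]
print(f"{len(tool_calls)} tool call(s) observed in the streamed response")
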
@@ -507,17 +569,26 @@ async def gen() -> ChatResponseAsyncGen:
                 if token_counts:
                     r["usage"] = token_counts

-                output_blocks = [TextBlock(text=response_txt)]
+                output_blocks: List[ThinkingBlock | ToolCallBlock | TextBlock] = [
+                    TextBlock(text=response_txt)
+                ]
                 if thinking_txt:
                     output_blocks.insert(0, ThinkingBlock(content=thinking_txt))
+                if all_tool_calls:
+                    for tool_call in all_tool_calls:
+                        output_blocks.append(
+                            ToolCallBlock(
+                                tool_name=tool_call.get("function", {}).get("name", ""),
+                                tool_kwargs=tool_call.get("function", {}).get(
+                                    "arguments", {}
+                                ),
+                            )
+                        )

                 yield ChatResponse(
                     message=ChatMessage(
                         blocks=output_blocks,
                         role=r["message"].get("role", MessageRole.ASSISTANT),
-                        additional_kwargs={
-                            "tool_calls": all_tool_calls,
-                        },
                     ),
                     delta=r["message"].get("content", ""),
                     raw=r,
@@ -551,13 +622,21 @@ async def achat(

         response = dict(response)

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []

         tool_calls = response["message"].get("tool_calls", []) or []
         thinking = response["message"].get("thinking", None)
         if thinking:
             blocks.append(ThinkingBlock(content=thinking))
         blocks.append(TextBlock(text=response["message"].get("content", "")))
+        if tool_calls:
+            for tool_call in tool_calls:
+                blocks.append(
+                    ToolCallBlock(
+                        tool_name=tool_call.get("function", {}).get("name", ""),
+                        tool_kwargs=tool_call.get("function", {}).get("arguments", {}),
+                    )
+                )
         token_counts = self._get_response_token_counts(response)
         if token_counts:
             response["usage"] = token_counts
@@ -566,7 +645,6 @@
             message=ChatMessage(
                 blocks=blocks,
                 role=response["message"].get("role", MessageRole.ASSISTANT),
-                additional_kwargs={"tool_calls": tool_calls},
             ),
             raw=response,
         )
@@ -27,15 +27,15 @@ dev = [

 [project]
 name = "llama-index-llms-ollama"
-version = "0.8.0"
+version = "0.9.0"
 description = "llama-index llms ollama integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
 readme = "README.md"
 license = "MIT"
 dependencies = [
     "ollama>=0.5.1",
-    "llama-index-core>=0.14.3,<0.15",
+    "llama-index-core>=0.14.5,<0.15",
 ]

 [tool.codespell]
@@ -4,7 +4,7 @@
 from ollama import Client
 from typing import Annotated

-from llama_index.core.base.llms.types import ThinkingBlock, TextBlock
+from llama_index.core.base.llms.types import ThinkingBlock, TextBlock, ToolCallBlock
 from llama_index.core.base.llms.base import BaseLLM
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.llms import ChatMessage
@@ -358,4 +358,81 @@ async def test_async_chat_with_tools_returns_empty_array_if_no_tools_were_called
             ChatMessage(role="user", content="Hello, how are you?"),
         ],
     )
-    assert response.message.additional_kwargs.get("tool_calls", []) == []
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+
+
+@pytest.mark.skipif(
+    client is None, reason="Ollama client is not available or test model is missing"
+)
+@pytest.mark.asyncio
+async def test_chat_methods_with_tool_input() -> None:
+    llm = Ollama(model=thinking_test_model)
+    input_messages = [
+        ChatMessage(
+            role="user",
+            content="Hello, can you tell me what is the weather today in London?",
+        ),
+        ChatMessage(
+            role="assistant",
+            blocks=[
+                ThinkingBlock(
+                    content="The user is asking for the weather in London, so I should use the get_weather tool"
+                ),
+                ToolCallBlock(
+                    tool_name="get_weather_tool", tool_kwargs={"location": "London"}
+                ),
+                TextBlock(
+                    text="The weather in London is rainy with a temperature of 15°C."
+                ),
+            ],
+        ),
+        ChatMessage(
+            role="user",
+            content="Can you tell me what input did you give to the 'get_weather' tool? (do not call any other tool)",
+        ),
+    ]
+    response = llm.chat(messages=input_messages)
+    assert response.message.content is not None
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+    aresponse = await llm.achat(messages=input_messages)
+    assert aresponse.message.content is not None
+    assert (
+        len(
+            [
+                block
+                for block in aresponse.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        == 0
+    )
+    response_stream = llm.stream_chat(messages=input_messages)
+    blocks = []
+    for r in response_stream:
+        blocks.extend(r.message.blocks)
+    assert len([block for block in blocks if isinstance(block, TextBlock)]) > 0
+    assert len([block for block in blocks if isinstance(block, ToolCallBlock)]) == 0
+    aresponse_stream = await llm.astream_chat(messages=input_messages)
+    ablocks = []
+    async for r in aresponse_stream:
+        ablocks.extend(r.message.blocks)
+    assert len([block for block in ablocks if isinstance(block, TextBlock)]) > 0
+    assert len([block for block in ablocks if isinstance(block, ToolCallBlock)]) == 0
llama-index-integrations/llms/llama-index-llms-ollama/uv.lock: 10 changes (5 additions, 5 deletions; generated file, diff not rendered)