Merged

llama_index/llms/anthropic/base.py
@@ -13,8 +13,9 @@
     Set,
     Tuple,
     Union,
+    cast,
 )
 
+from llama_index.core.llms.utils import parse_partial_json
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
@@ -23,6 +24,7 @@
     LLMMetadata,
     MessageRole,
     ContentBlock,
+    ToolCallBlock,
 )
 from llama_index.core.base.llms.types import TextBlock as LITextBlock
 from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
@@ -35,7 +37,6 @@
     llm_completion_callback,
 )
 from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
-from llama_index.core.llms.utils import parse_partial_json
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.core.utils import Tokenizer
 from llama_index.llms.anthropic.utils import (
@@ -44,6 +45,7 @@
     is_anthropic_prompt_caching_supported_model,
     is_function_calling_model,
     messages_to_anthropic_messages,
+    update_tool_calls,
 )
 
 import anthropic
@@ -351,8 +353,7 @@ def _completion_response_from_chat_response(
 
     def _get_blocks_and_tool_calls_and_thinking(
         self, response: Any
-    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
-        tool_calls = []
+    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]]]:
         blocks: List[ContentBlock] = []
         citations: List[TextCitation] = []
         tracked_citations: Set[str] = set()
@@ -392,9 +393,15 @@ def _get_blocks_and_tool_calls_and_thinking(
                     )
                 )
             elif isinstance(content_block, ToolUseBlock):
-                tool_calls.append(content_block.model_dump())
+                blocks.append(
+                    ToolCallBlock(
+                        tool_call_id=content_block.id,
+                        tool_kwargs=cast(Dict[str, Any] | str, content_block.input),
+                        tool_name=content_block.name,
+                    )
+                )
 
-        return blocks, tool_calls, [x.model_dump() for x in citations]
+        return blocks, [x.model_dump() for x in citations]
 
     @llm_chat_callback()
     def chat(
@@ -412,17 +419,12 @@ def chat(
             **all_kwargs,
         )
 
-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)
 
         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -536,13 +538,18 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
                     else:
                         tool_calls_to_send = cur_tool_calls
 
+                    for tool_call in tool_calls_to_send:
+                        tc = ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_name=tool_call.name,
+                            tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
+                        )
+                        update_tool_calls(content, tc)
+
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
                             blocks=content,
-                            additional_kwargs={
-                                "tool_calls": [t.dict() for t in tool_calls_to_send]
-                            },
                         ),
                         citations=cur_citations,
                         delta=content_delta,
@@ -560,13 +567,23 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
                 content.append(cur_block)
                 cur_block = None
 
+            if cur_tool_call is not None:
+                tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
+            else:
+                tool_calls_to_send = cur_tool_calls
+
+            for tool_call in tool_calls_to_send:
+                tc = ToolCallBlock(
+                    tool_call_id=tool_call.id,
+                    tool_name=tool_call.name,
+                    tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
+                )
+                update_tool_calls(content, tc)
+
             yield AnthropicChatResponse(
                 message=ChatMessage(
                     role=role,
                     blocks=content,
-                    additional_kwargs={
-                        "tool_calls": [t.dict() for t in tool_calls_to_send]
-                    },
                 ),
                 citations=cur_citations,
                 delta="",
@@ -604,17 +621,12 @@ async def achat(
             **all_kwargs,
         )
 
-        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
-            response
-        )
+        blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)
 
         return AnthropicChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs={
-                    "tool_calls": tool_calls,
-                },
             ),
             citations=citations,
             raw=dict(response),
@@ -728,13 +740,18 @@ async def gen() -> ChatResponseAsyncGen:
                     else:
                         tool_calls_to_send = cur_tool_calls
 
+                    for tool_call in tool_calls_to_send:
+                        tc = ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_name=tool_call.name,
+                            tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
+                        )
+                        update_tool_calls(content, tc)
+
                     yield AnthropicChatResponse(
                         message=ChatMessage(
                             role=role,
                             blocks=content,
-                            additional_kwargs={
-                                "tool_calls": [t.dict() for t in tool_calls_to_send]
-                            },
                         ),
                         citations=cur_citations,
                         delta=content_delta,
@@ -752,13 +769,23 @@ async def gen() -> ChatResponseAsyncGen:
                 content.append(cur_block)
                 cur_block = None
 
+            if cur_tool_call is not None:
+                tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
+            else:
+                tool_calls_to_send = cur_tool_calls
+
+            for tool_call in tool_calls_to_send:
+                tc = ToolCallBlock(
+                    tool_call_id=tool_call.id,
+                    tool_name=tool_call.name,
+                    tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
+                )
+                update_tool_calls(content, tc)
+
             yield AnthropicChatResponse(
                 message=ChatMessage(
                     role=role,
                     blocks=content,
-                    additional_kwargs={
-                        "tool_calls": [t.dict() for t in tool_calls_to_send]
-                    },
                 ),
                 citations=cur_citations,
                 delta="",
@@ -867,7 +894,11 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
 
         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -879,24 +910,16 @@ def get_tool_calls_from_response(
 
         tool_selections = []
         for tool_call in tool_calls:
-            if (
-                "input" not in tool_call
-                or "id" not in tool_call
-                or "name" not in tool_call
-            ):
-                raise ValueError("Invalid tool call.")
-            if tool_call["type"] != "tool_use":
-                raise ValueError("Invalid tool type. Unsupported by Anthropic")
             argument_dict = (
-                json.loads(tool_call["input"])
-                if isinstance(tool_call["input"], str)
-                else tool_call["input"]
+                json.loads(tool_call.tool_kwargs)
+                if isinstance(tool_call.tool_kwargs, str)
+                else tool_call.tool_kwargs
            )
 
             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call["id"],
-                    tool_name=tool_call["name"],
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
                     tool_kwargs=argument_dict,
                 )
             )
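
Note for integrators: tool calls no longer round-trip through response.message.additional_kwargs["tool_calls"]; they now live in the message blocks. A minimal sketch of the new access pattern, mirroring get_tool_calls_from_response above (the response variable is assumed to be an AnthropicChatResponse returned by chat()/achat()):

import json

from llama_index.core.base.llms.types import ToolCallBlock

# `response` is assumed to come from chat()/achat() on the Anthropic LLM.
tool_calls = [
    block for block in response.message.blocks if isinstance(block, ToolCallBlock)
]
for tool_call in tool_calls:
    # tool_kwargs is typed Dict[str, Any] | str, so guard for the str case
    # the same way get_tool_calls_from_response does.
    kwargs = (
        json.loads(tool_call.tool_kwargs)
        if isinstance(tool_call.tool_kwargs, str)
        else tool_call.tool_kwargs
    )
    print(tool_call.tool_call_id, tool_call.tool_name, kwargs)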
llama_index/llms/anthropic/utils.py

@@ -16,6 +16,7 @@
     CitationBlock,
     ThinkingBlock,
     ContentBlock,
+    ToolCallBlock,
 )
 
 from anthropic.types import (
@@ -24,6 +25,7 @@
     DocumentBlockParam,
     ThinkingBlockParam,
     ImageBlockParam,
+    ToolUseBlockParam,
     CacheControlEphemeralParam,
     Base64PDFSourceParam,
 )
@@ -207,6 +209,7 @@ def blocks_to_anthropic_blocks(
 ) -> List[AnthropicContentBlock]:
     anthropic_blocks: List[AnthropicContentBlock] = []
     global_cache_control: Optional[CacheControlEphemeralParam] = None
+    unique_tool_calls = []
 
     if kwargs.get("cache_control"):
         global_cache_control = CacheControlEphemeralParam(**kwargs["cache_control"])
@@ -269,6 +272,19 @@
             if global_cache_control:
                 anthropic_blocks[-1]["cache_control"] = global_cache_control
 
+        elif isinstance(block, ToolCallBlock):
+            unique_tool_calls.append((block.tool_call_id, block.tool_name))
+            anthropic_blocks.append(
+                ToolUseBlockParam(
+                    id=block.tool_call_id or "",
+                    input=block.tool_kwargs,
+                    name=block.tool_name,
+                    type="tool_use",
+                )
+            )
+            if global_cache_control:
+                anthropic_blocks[-1]["cache_control"] = global_cache_control
+
         elif isinstance(block, CachePoint):
             if len(anthropic_blocks) > 0:
                 anthropic_blocks[-1]["cache_control"] = CacheControlEphemeralParam(
@@ -282,20 +298,25 @@
         else:
             raise ValueError(f"Unsupported block type: {type(block)}")
 
+    # keep this code for compatibility with older chat histories
     tool_calls = kwargs.get("tool_calls", [])
     for tool_call in tool_calls:
-        assert "id" in tool_call
-        assert "input" in tool_call
-        assert "name" in tool_call
-
-        anthropic_blocks.append(
-            ToolUseBlockParam(
-                id=tool_call["id"],
-                input=tool_call["input"],
-                name=tool_call["name"],
-                type="tool_use",
-            )
-        )
+        try:
+            assert "id" in tool_call
+            assert "input" in tool_call
+            assert "name" in tool_call
+
+            if (tool_call["id"], tool_call["name"]) not in unique_tool_calls:
+                anthropic_blocks.append(
+                    ToolUseBlockParam(
+                        id=tool_call["id"],
+                        input=tool_call["input"],
+                        name=tool_call["name"],
+                        type="tool_use",
+                    )
+                )
+        except AssertionError:
+            continue
 
     return anthropic_blocks

Member Author, commenting on the unique_tool_calls check: Added this check here to avoid duplicates.
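
To make the dedup concrete, a minimal sketch of the compatibility path. This assumes, per the kwargs.get(...) calls above, that blocks_to_anthropic_blocks takes the block list plus the message's additional_kwargs dict; the ID and tool name are illustrative:

from llama_index.core.base.llms.types import ToolCallBlock
from llama_index.llms.anthropic.utils import blocks_to_anthropic_blocks

# The same tool call expressed both as a new-style block and as a legacy
# additional_kwargs entry, as an older chat history might contain.
blocks = [
    ToolCallBlock(
        tool_call_id="toolu_01",  # illustrative ID
        tool_name="get_weather",  # illustrative tool
        tool_kwargs={"city": "Paris"},
    )
]
legacy_kwargs = {
    "tool_calls": [
        {"id": "toolu_01", "name": "get_weather", "input": {"city": "Paris"}}
    ]
}

anthropic_blocks = blocks_to_anthropic_blocks(blocks, legacy_kwargs)
# Converting the block recorded ("toolu_01", "get_weather") in
# unique_tool_calls, so the legacy entry is skipped: one tool_use param total.
assert len([b for b in anthropic_blocks if b["type"] == "tool_use"]) == 1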

@@ -359,9 +380,15 @@ def messages_to_anthropic_messages(
 
 
 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]
 
 
 # Anthropic models that support prompt caching
@@ -400,6 +427,33 @@ def force_single_tool_call(response: ChatResponse) -> None:
 )
 
 
+def update_tool_calls(blocks: list[ContentBlock], tool_call: ToolCallBlock) -> None:
+    if len([block for block in blocks if isinstance(block, ToolCallBlock)]) == 0:
+        blocks.append(tool_call)
+        return
+    elif not any(
+        block.tool_call_id == tool_call.tool_call_id
+        for block in blocks
+        if isinstance(block, ToolCallBlock)
+    ):
+        blocks.append(tool_call)
+        return
+    elif any(
+        block.tool_call_id == tool_call.tool_call_id
+        and block.tool_kwargs == tool_call.tool_kwargs
+        for block in blocks
+        if isinstance(block, ToolCallBlock)
+    ):
+        return
+    else:
+        for i, block in enumerate(blocks):
+            if isinstance(block, ToolCallBlock):
+                if block.tool_call_id == tool_call.tool_call_id:
+                    blocks[i] = tool_call
+                    break
+        return
+
+
 def is_anthropic_prompt_caching_supported_model(model: str) -> bool:
     """
     Check if the given Anthropic model supports prompt caching.

Member Author, commenting on lines +430 to +454: Added this logic because Anthropic streams partial JSON, so tool calls with the same ID are streamed repeatedly but may carry different tool_kwargs. So:

  • We append the tool call if no tool call is present yet, or if none has the same ID.
  • We skip it if a tool call with the same ID and the same arguments is already there.
  • We update the tool call in place if one has the same ID but different arguments.

I added a test for this and also tested end-to-end with the agent script; everything now runs smoothly (no duplicate calls, correct argument parsing).
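
A quick walkthrough of the three branches described above, as a sketch; it uses the ToolCallBlock fields from this PR, and the ID, tool name, and kwargs are illustrative:

from llama_index.core.base.llms.types import ContentBlock, TextBlock, ToolCallBlock
from llama_index.llms.anthropic.utils import update_tool_calls

content: list[ContentBlock] = [TextBlock(text="Checking the weather...")]

# 1) No tool call with this ID yet -> appended.
update_tool_calls(
    content,
    ToolCallBlock(tool_call_id="toolu_01", tool_name="get_weather", tool_kwargs={}),
)

# 2) Same ID and same kwargs (a re-streamed duplicate) -> skipped.
update_tool_calls(
    content,
    ToolCallBlock(tool_call_id="toolu_01", tool_name="get_weather", tool_kwargs={}),
)

# 3) Same ID, fuller kwargs from a later partial-JSON parse -> replaced in place.
update_tool_calls(
    content,
    ToolCallBlock(
        tool_call_id="toolu_01",
        tool_name="get_weather",
        tool_kwargs={"city": "Paris"},
    ),
)

tool_calls = [b for b in content if isinstance(b, ToolCallBlock)]
assert len(tool_calls) == 1
assert tool_calls[0].tool_kwargs == {"city": "Paris"}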
pyproject.toml

@@ -27,15 +27,15 @@ dev = [
 
 [project]
 name = "llama-index-llms-anthropic"
-version = "0.9.7"
+version = "0.10.0"
 description = "llama-index llms anthropic integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
 readme = "README.md"
 license = "MIT"
 dependencies = [
     "anthropic[bedrock, vertex]>=0.69.0",
-    "llama-index-core>=0.14.3,<0.15",
+    "llama-index-core>=0.14.5,<0.15",
 ]
 
 [tool.codespell]