This repository was archived by the owner on Jul 22, 2025. It is now read-only.

Commit 3993c68

FEATURE: anthropic function calling (#654)

Adds support for native tool calling (both streaming and non-streaming) for Anthropic. This improves general tool support for the Anthropic models.
1 parent: 564d2de

File tree: 9 files changed, +517 −242 lines
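Before the per-file diffs, a rough sketch of the payload shapes this change deals in (illustrative only: the field names are the ones the new code reads and writes, while the tool name, id, and argument values are invented):

    # Request: tool definitions are now sent natively in the payload rather than
    # being rendered into the prompt text.
    request_tools = [
      {
        name: "search",
        description: "Search the forum",
        input_schema: {
          type: "object",
          properties: { query: { type: "string", description: "search terms" } },
          required: ["query"],
        },
      },
    ]

    # Non-streaming response: a tool invocation arrives as a "tool_use" content block.
    non_streaming_response = {
      content: [{ type: "tool_use", id: "toolu_123", name: "search", input: { query: "ruby" } }],
      usage: { input_tokens: 10, output_tokens: 25 },
    }

    # Streaming response: the same invocation arrives as a content_block_start event
    # followed by content_block_delta events whose partial_json fragments are buffered.
    streaming_events = [
      { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_123" } },
      { type: "content_block_delta", delta: { partial_json: "{\"query\":" } },
      { type: "content_block_delta", delta: { partial_json: " \"ruby\"}" } },
    ]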
Lines changed: 98 additions & 0 deletions (new file; defines DiscourseAi::Completions::AnthropicMessageProcessor)

# frozen_string_literal: true

class DiscourseAi::Completions::AnthropicMessageProcessor
  # Accumulates the streamed (or complete) JSON arguments for a single tool call.
  class AnthropicToolCall
    attr_reader :name, :raw_json, :id

    def initialize(name, id)
      @name = name
      @id = id
      @raw_json = +""
    end

    def append(json)
      @raw_json << json
    end
  end

  attr_reader :tool_calls, :input_tokens, :output_tokens

  def initialize(streaming_mode:)
    @streaming_mode = streaming_mode
    @tool_calls = []
  end

  # Renders any collected tool calls into the XML <function_calls> format the
  # rest of the completions pipeline expects.
  def to_xml_tool_calls(function_buffer)
    return function_buffer if @tool_calls.blank?

    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
      <function_calls>
      </function_calls>
    TEXT

    @tool_calls.each do |tool_call|
      node =
        function_buffer.at("function_calls").add_child(
          Nokogiri::HTML5::DocumentFragment.parse(
            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
          ),
        )

      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
      xml = params.map { |name, value| "<#{name}>#{value}</#{name}>" }.join("\n")

      node.at("tool_name").content = tool_call.name
      node.at("tool_id").content = tool_call.id
      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
    end

    function_buffer
  end

  # Extracts text (and any tool calls) from a raw Anthropic payload, in either
  # streaming or non-streaming mode.
  def process_message(payload)
    result = ""
    parsed = JSON.parse(payload, symbolize_names: true)

    if @streaming_mode
      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
        tool_name = parsed.dig(:content_block, :name)
        tool_id = parsed.dig(:content_block, :id)
        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
        if @tool_calls.present?
          result = parsed.dig(:delta, :partial_json).to_s
          @tool_calls.last.append(result)
        else
          result = parsed.dig(:delta, :text).to_s
        end
      elsif parsed[:type] == "message_start"
        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
      elsif parsed[:type] == "message_delta"
        @output_tokens =
          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
      elsif parsed[:type] == "message_stop"
        # Bedrock attaches invocation metrics to the final event
        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
        end
      end
    else
      content = parsed.dig(:content)
      if content.is_a?(Array)
        tool_call = content.find { |c| c[:type] == "tool_use" }
        if tool_call
          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
          @tool_calls.last.append(tool_call[:input].to_json)
        else
          result = parsed.dig(:content, 0, :text).to_s
        end
      end

      @input_tokens = parsed.dig(:usage, :input_tokens)
      @output_tokens = parsed.dig(:usage, :output_tokens)
    end

    result
  end
end
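A minimal usage sketch of the processor in non-streaming mode (hypothetical payload values; it assumes the plugin environment is loaded, since the class leans on ActiveSupport's blank?/present?, Nokogiri, and DiscourseAi::Completions::Endpoints::Base):

    processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: false)

    # Hypothetical non-streaming Anthropic response containing a single tool call.
    response = {
      content: [{ type: "tool_use", id: "toolu_123", name: "search", input: { query: "ruby" } }],
      usage: { input_tokens: 10, output_tokens: 25 },
    }.to_json

    processor.process_message(response)  # => "" (no plain-text content in this response)
    processor.tool_calls.map(&:name)     # => ["search"]
    processor.input_tokens               # => 10

    # The buffered calls can then be rendered into the XML format the rest of the
    # pipeline already understands (this is what the endpoints' add_to_function_buffer does):
    buffer = Nokogiri::HTML5.fragment("<function_calls></function_calls>")
    puts processor.to_xml_tool_calls(buffer).to_s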

lib/completions/dialects/claude.rb

Lines changed: 20 additions & 2 deletions

@@ -15,10 +15,12 @@ def can_translate?(model_name)
   class ClaudePrompt
     attr_reader :system_prompt
     attr_reader :messages
+    attr_reader :tools

-    def initialize(system_prompt, messages)
+    def initialize(system_prompt, messages, tools)
       @system_prompt = system_prompt
       @messages = messages
+      @tools = tools
     end
   end

@@ -46,7 +48,11 @@ def translate
       previous_message = message
     end

-    ClaudePrompt.new(system_prompt.presence, interleving_messages)
+    ClaudePrompt.new(
+      system_prompt.presence,
+      interleving_messages,
+      tools_dialect.translated_tools,
+    )
   end

   def max_prompt_tokens
@@ -58,6 +64,18 @@ def max_prompt_tokens

   private

+  def tools_dialect
+    @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
+  end
+
+  def tool_call_msg(msg)
+    tools_dialect.from_raw_tool_call(msg)
+  end
+
+  def tool_msg(msg)
+    tools_dialect.from_raw_tool(msg)
+  end
+
   def model_msg(msg)
     { role: "assistant", content: msg[:content] }
   end
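To make the new wiring concrete, this is roughly what the two delegating methods above produce when a prompt replays an earlier tool invocation (the shapes come from ClaudeTools#from_raw_tool_call and #from_raw_tool in the next file; the id, name, and values are invented):

    # tool_call_msg: the assistant turn that invoked the tool, replayed natively
    { role: "assistant",
      content: [{ type: "tool_use", id: "toolu_123", name: "search", input: { query: "ruby" } }] }

    # tool_msg: the tool's result, sent back as a user turn
    { role: "user",
      content: [{ type: "tool_result", tool_use_id: "toolu_123", content: "{\"rows\": []}" }] }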
Lines changed: 81 additions & 0 deletions (new file; defines DiscourseAi::Completions::Dialects::ClaudeTools)

# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class ClaudeTools
        def initialize(tools)
          @raw_tools = tools
        end

        def translated_tools
          # Transform the raw tools into the required Anthropic Claude API format
          raw_tools.map do |t|
            properties = {}
            required = []

            if t[:parameters]
              properties =
                t[:parameters].each_with_object({}) do |param, h|
                  h[param[:name]] = {
                    type: param[:type],
                    description: param[:description],
                  }.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
                end
              required =
                t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
            end

            {
              name: t[:name],
              description: t[:description],
              input_schema: {
                type: "object",
                properties: properties,
                required: required,
              },
            }
          end
        end

        def instructions
          "" # Noop. Tools are listed separately.
        end

        def from_raw_tool_call(raw_message)
          call_details = JSON.parse(raw_message[:content], symbolize_names: true)
          tool_call_id = raw_message[:id]

          {
            role: "assistant",
            content: [
              {
                type: "tool_use",
                id: tool_call_id,
                name: raw_message[:name],
                input: call_details[:arguments],
              },
            ],
          }
        end

        def from_raw_tool(raw_message)
          {
            role: "user",
            content: [
              {
                type: "tool_result",
                tool_use_id: raw_message[:id],
                content: raw_message[:content],
              },
            ],
          }
        end

        private

        attr_reader :raw_tools
      end
    end
  end
end
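A sketch of the translation performed by translated_tools, using an invented tool definition in the plugin's internal shape (a name, description, and a parameters array with name/type/description/item_type/required, which is what the code above reads):

    tools = [
      {
        name: "search",
        description: "Search the forum",
        parameters: [
          { name: "query", type: "string", description: "search terms", required: true },
          { name: "tags", type: "array", item_type: "string", description: "tags to filter by" },
        ],
      },
    ]

    DiscourseAi::Completions::Dialects::ClaudeTools.new(tools).translated_tools
    # =>
    # [
    #   {
    #     name: "search",
    #     description: "Search the forum",
    #     input_schema: {
    #       type: "object",
    #       properties: {
    #         "query" => { type: "string", description: "search terms" },
    #         "tags" => { type: "array", description: "tags to filter by", items: { type: "string" } },
    #       },
    #       required: ["query"],
    #     },
    #   },
    # ]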

lib/completions/endpoints/anthropic.rb

Lines changed: 22 additions & 24 deletions

@@ -45,10 +45,7 @@ def default_options(dialect)
       raise "Unsupported model: #{model}"
     end

-    options = { model: mapped_model, max_tokens: 3_000 }
-
-    options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
-    options
+    { model: mapped_model, max_tokens: 3_000 }
   end

   def provider_id
@@ -73,6 +70,7 @@ def prepare_payload(prompt, model_params, dialect)

     payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
     payload[:stream] = true if @streaming_mode
+    payload[:tools] = prompt.tools if prompt.tools.present?

     payload
   end
@@ -87,30 +85,30 @@ def prepare_request(payload)
     Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
   end

-  def final_log_update(log)
-    log.request_tokens = @input_tokens if @input_tokens
-    log.response_tokens = @output_tokens if @output_tokens
+  def processor
+    @processor ||=
+      DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+  end
+
+  def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+    processor.to_xml_tool_calls(function_buffer) if !partial
   end

   def extract_completion_from(response_raw)
-    result = ""
-    parsed = JSON.parse(response_raw, symbolize_names: true)
-
-    if @streaming_mode
-      if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        result = parsed.dig(:delta, :text).to_s
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
-      end
-    else
-      result = parsed.dig(:content, 0, :text).to_s
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
-    end
+    processor.process_message(response_raw)
+  end
+
+  def has_tool?(_response_data)
+    processor.tool_calls.present?
+  end
+
+  def final_log_update(log)
+    log.request_tokens = processor.input_tokens if processor.input_tokens
+    log.response_tokens = processor.output_tokens if processor.output_tokens
+  end

-    result
+  def native_tool_support?
+    true
   end

   def partials_from(decoded_chunk)

lib/completions/endpoints/aws_bedrock.rb

Lines changed: 21 additions & 20 deletions

@@ -36,7 +36,6 @@ def normalize_model_params(model_params)

   def default_options(dialect)
     options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
-    options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
     options
   end

@@ -82,6 +81,8 @@ def model_uri
   def prepare_payload(prompt, model_params, dialect)
     payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
     payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+    payload[:tools] = prompt.tools if prompt.tools.present?
+
     payload
   end

@@ -142,35 +143,35 @@ def decode(chunk)
   end

   def final_log_update(log)
-    log.request_tokens = @input_tokens if @input_tokens
-    log.response_tokens = @output_tokens if @output_tokens
+    log.request_tokens = processor.input_tokens if processor.input_tokens
+    log.response_tokens = processor.output_tokens if processor.output_tokens
+  end
+
+  def processor
+    @processor ||=
+      DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+  end
+
+  def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+    processor.to_xml_tool_calls(function_buffer) if !partial
   end

   def extract_completion_from(response_raw)
-    result = ""
-    parsed = JSON.parse(response_raw, symbolize_names: true)
-
-    if @streaming_mode
-      if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        result = parsed.dig(:delta, :text).to_s
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
-      end
-    else
-      result = parsed.dig(:content, 0, :text).to_s
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
-    end
+    processor.process_message(response_raw)
+  end

-    result
+  def has_tool?(_response_data)
+    processor.tool_calls.present?
   end

   def partials_from(decoded_chunks)
     decoded_chunks
   end

+  def native_tool_support?
+    true
+  end
+
   def chunk_to_string(chunk)
     joined = +chunk.join("\n")
     joined << "\n" if joined.length > 0
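One Bedrock-specific detail is handled by the shared processor rather than here: the final streamed event carries invocation metrics, and those override the token counts when present. A hypothetical final event, shown as the decoded Ruby hash (the field names are the ones the processor reads; the numbers are invented):

    # Decoded "message_stop" event as delivered via Bedrock
    {
      type: "message_stop",
      "amazon-bedrock-invocationMetrics": {
        inputTokenCount: 12,
        outputTokenCount: 34,
      },
    }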

0 commit comments