Skip to content

Commit 9bff94b

Browse files
committed
FIX: coerce values for XML tool calls
1 parent 0534011 commit 9bff94b

File tree

6 files changed

+68
-31
lines changed

6 files changed

+68
-31
lines changed

lib/completions/endpoints/base.rb

+1
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def perform_completion!(
166166
xml_tool_processor =
167167
XmlToolProcessor.new(
168168
partial_tool_calls: partial_tool_calls,
169+
tool_definitions: dialect.prompt.tools,
169170
) if xml_tools_enabled? && dialect.prompt.has_tools?
170171

171172
to_strip = xml_tags_to_strip(dialect)

lib/completions/tool_definition.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def coerce_parameters(params)
165165
return result if !params.is_a?(Hash)
166166

167167
@parameters.each do |param_def|
168-
param_name = param_def.name
168+
param_name = param_def.name.to_sym
169169

170170
# Skip if parameter is not provided and not required
171171
next if !params.key?(param_name) && !param_def.required

lib/completions/xml_tool_processor.rb

+11-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
module DiscourseAi
88
module Completions
99
class XmlToolProcessor
10-
def initialize(partial_tool_calls: false)
10+
def initialize(partial_tool_calls: false, tool_definitions: nil)
1111
@buffer = +""
1212
@function_buffer = +""
1313
@should_cancel = false
1414
@in_tool = false
1515
@partial_tool_calls = partial_tool_calls
1616
@partial_tools = [] if @partial_tool_calls
17+
@tool_definitions = tool_definitions
1718
end
1819

1920
def <<(text)
@@ -71,7 +72,7 @@ def finish
7172

7273
idx = -1
7374
parse_malformed_xml(@function_buffer).map do |tool|
74-
ToolCall.new(
75+
new_tool_call(
7576
id: "tool_#{idx += 1}",
7677
name: tool[:tool_name],
7778
parameters: tool[:parameters],
@@ -85,6 +86,13 @@ def should_cancel?
8586

8687
private
8788

89+
def new_tool_call(id:, name:, parameters:)
90+
if tool_def = @tool_definitions&.find { |d| d.name == name }
91+
parameters = tool_def.coerce_parameters(parameters)
92+
end
93+
ToolCall.new(id:, name:, parameters:)
94+
end
95+
8896
def add_to_function_buffer(text)
8997
@function_buffer << text
9098
detect_partial_tool_calls(@function_buffer, text) if @partial_tool_calls
@@ -119,7 +127,7 @@ def parse_partial_tool_call(buffer)
119127
current_tool = @partial_tools.last
120128
if !current_tool || current_tool.name != match[0].strip
121129
current_tool =
122-
ToolCall.new(
130+
new_tool_call(
123131
id: "tool_#{@partial_tools.length}",
124132
name: match[0].strip,
125133
parameters: params,

spec/lib/completions/endpoints/open_ai_spec.rb

+15-2
Original file line numberDiff line numberDiff line change
@@ -542,11 +542,24 @@ def request_body(prompt, stream: false, tool_call: false)
542542
<parameters>
543543
<location>Sydney</location>
544544
<unit>c</unit>
545+
<is_it_hot>true</is_it_hot>
545546
</parameters>
546547
</invoke>
547548
</function_calls>
548549
XML
549550

551+
let(:weather_tool) do
552+
{
553+
name: "get_weather",
554+
description: "get weather",
555+
parameters: [
556+
{ name: "location", type: "string", description: "location", required: true },
557+
{ name: "unit", type: "string", description: "unit", required: true, enum: %w[c f] },
558+
{ name: "is_it_hot", type: "boolean", description: "is it hot" },
559+
],
560+
}
561+
end
562+
550563
it "parses XML tool calls" do
551564
response = {
552565
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
@@ -574,7 +587,7 @@ def request_body(prompt, stream: false, tool_call: false)
574587
body = nil
575588
open_ai_mock.stub_raw(response, body_blk: proc { |inner_body| body = inner_body })
576589

577-
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
590+
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: [weather_tool]))
578591
tool_call = endpoint.perform_completion!(dialect, user)
579592

580593
body_parsed = JSON.parse(body, symbolize_names: true)
@@ -583,7 +596,7 @@ def request_body(prompt, stream: false, tool_call: false)
583596
expect(body_parsed[:messages][0][:content]).to include("<function_calls>")
584597

585598
expect(tool_call.name).to eq("get_weather")
586-
expect(tool_call.parameters).to eq({ location: "Sydney", unit: "c" })
599+
expect(tool_call.parameters).to eq({ location: "Sydney", unit: "c", is_it_hot: true })
587600
end
588601
end
589602

spec/lib/completions/tool_definition_spec.rb

+25-25
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@
132132
end
133133

134134
it "converts numbers to strings" do
135-
result = tool.coerce_parameters({ "name" => 123 })
136-
expect(result["name"]).to eq("123")
135+
result = tool.coerce_parameters(name: 123)
136+
expect(result[:name]).to eq("123")
137137
end
138138

139139
it "converts booleans to strings" do
140-
result = tool.coerce_parameters({ "name" => true })
141-
expect(result["name"]).to eq("true")
140+
result = tool.coerce_parameters(name: true)
141+
expect(result[:name]).to eq("true")
142142
end
143143
end
144144

@@ -156,18 +156,18 @@
156156
end
157157

158158
it "converts string numbers to floats" do
159-
result = tool.coerce_parameters({ "price" => "42.99" })
160-
expect(result["price"]).to eq(42.99)
159+
result = tool.coerce_parameters(price: "42.99")
160+
expect(result[:price]).to eq(42.99)
161161
end
162162

163163
it "converts integers to floats" do
164-
result = tool.coerce_parameters({ "price" => 42 })
165-
expect(result["price"]).to eq(42.0)
164+
result = tool.coerce_parameters(price: 42)
165+
expect(result[:price]).to eq(42.0)
166166
end
167167

168168
it "returns nil for invalid number strings" do
169-
result = tool.coerce_parameters({ "price" => "not a number" })
170-
expect(result["price"]).to be_nil
169+
result = tool.coerce_parameters(price: "not a number")
170+
expect(result[:price]).to be_nil
171171
end
172172
end
173173

@@ -190,18 +190,18 @@
190190
end
191191

192192
it "converts string elements to integers" do
193-
result = tool.coerce_parameters({ "numbers" => %w[1 2 3] })
194-
expect(result["numbers"]).to eq([1, 2, 3])
193+
result = tool.coerce_parameters(numbers: %w[1 2 3])
194+
expect(result[:numbers]).to eq([1, 2, 3])
195195
end
196196

197197
it "parses JSON strings into arrays and converts elements" do
198-
result = tool.coerce_parameters({ "numbers" => "[1, 2, 3]" })
199-
expect(result["numbers"]).to eq([1, 2, 3])
198+
result = tool.coerce_parameters(numbers: "[1, 2, 3]")
199+
expect(result[:numbers]).to eq([1, 2, 3])
200200
end
201201

202202
it "handles mixed type arrays appropriately" do
203-
result = tool.coerce_parameters({ "numbers" => [1, "two", 3.5] })
204-
expect(result["numbers"]).to eq([1, nil, 3])
203+
result = tool.coerce_parameters(numbers: [1, "two", 3.5])
204+
expect(result[:numbers]).to eq([1, nil, 3])
205205
end
206206
end
207207

@@ -231,9 +231,9 @@
231231
end
232232

233233
it "includes missing required parameters as nil" do
234-
result = tool.coerce_parameters({ "optional_param" => "value" })
235-
expect(result["required_param"]).to be_nil
236-
expect(result["optional_param"]).to eq("value")
234+
result = tool.coerce_parameters(optional_param: "value")
235+
expect(result[:required_param]).to be_nil
236+
expect(result[:optional_param]).to eq("value")
237237
end
238238

239239
it "skips missing optional parameters" do
@@ -257,16 +257,16 @@
257257
end
258258

259259
it "preserves true/false values" do
260-
result = tool.coerce_parameters({ "flag" => true })
261-
expect(result["flag"]).to be true
260+
result = tool.coerce_parameters(flag: true)
261+
expect(result[:flag]).to be true
262262
end
263263

264264
it "converts 'true'/'false' strings to booleans" do
265-
result = tool.coerce_parameters({ "flag" => "true" })
266-
expect(result["flag"]).to be true
265+
result = tool.coerce_parameters({ flag: true })
266+
expect(result[:flag]).to be true
267267

268-
result = tool.coerce_parameters({ "flag" => "False" })
269-
expect(result["flag"]).to be false
268+
result = tool.coerce_parameters({ flag: "False" })
269+
expect(result[:flag]).to be false
270270
end
271271

272272
it "returns nil for invalid boolean strings" do

spec/lib/completions/xml_tool_processor_spec.rb

+15
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,24 @@
9393
<parameters>
9494
<hello>world</hello>
9595
<test>value</test>
96+
<bool>true</bool>
9697
</parameters>
9798
</invoke>
9899
XML
99100

101+
tool_definition =
102+
DiscourseAi::Completions::ToolDefinition.from_hash(
103+
name: "hello",
104+
description: "hello world",
105+
parameters: [
106+
{ name: "hello", type: "string", description: "hello" },
107+
{ name: "test", type: "string", description: "test" },
108+
{ name: "bool", type: "boolean", description: "bool" },
109+
],
110+
)
111+
112+
processor = DiscourseAi::Completions::XmlToolProcessor.new(tool_definitions: [tool_definition])
113+
100114
result = []
101115
result << (processor << "hello")
102116
result << (processor << xml)
@@ -109,6 +123,7 @@
109123
parameters: {
110124
hello: "world",
111125
test: "value",
126+
bool: true,
112127
},
113128
)
114129
expect(result).to eq([["hello"], [" world"], [tool_call]])

0 commit comments

Comments
 (0)