Skip to content

Commit e6f6304

Browse files
committed
DEV: use a proper object for tool definition
This moves away from using a loose hash to define tools, which is error prone. Instead given a proper object we will also be able to coerce the return values to match tool definition correctly
1 parent c34fcc8 commit e6f6304

File tree

4 files changed

+611
-19
lines changed

4 files changed

+611
-19
lines changed

lib/completions/dialects/open_ai_tools.rb

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,29 @@ def initialize(tools)
99
end
1010

1111
def translated_tools
12-
raw_tools.map do |t|
13-
tool = t.dup
12+
raw_tools.map do |tool|
13+
properties = {}
14+
required = []
1415

15-
tool[:parameters] = t[:parameters]
16-
.to_a
17-
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
18-
name = p[:name]
19-
memo[:required] << name if p[:required]
16+
result = {
17+
name: tool.name,
18+
description: tool.description,
19+
parameters: {
20+
type: "object",
21+
properties: properties,
22+
required: required,
23+
},
24+
}
2025

21-
except = %i[name required item_type]
22-
except << :enum if p[:enum].blank?
26+
tool.parameters.each do |param|
27+
name = param.name
28+
required << name if param.required
29+
properties[name] = { type: param.type, description: param.description }
30+
properties[name][:items] = { type: param.item_type } if param.item_type
31+
properties[name][:enum] = param.enum if param.enum
32+
end
2333

24-
memo[:properties][name] = p.except(*except)
25-
26-
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
27-
memo
28-
end
29-
30-
{ type: "function", function: tool }
34+
{ type: "function", function: result }
3135
end
3236
end
3337

lib/completions/prompt.rb

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ module Completions
55
class Prompt
66
INVALID_TURN = Class.new(StandardError)
77

8-
attr_reader :messages
9-
attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
8+
attr_reader :messages, :tools
9+
attr_accessor :topic_id, :post_id, :max_pixels, :tool_choice
1010

1111
def initialize(
1212
system_message_text = nil,
@@ -37,10 +37,25 @@ def initialize(
3737
@messages.each { |message| validate_message(message) }
3838
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
3939

40-
@tools = tools
40+
self.tools = tools
4141
@tool_choice = tool_choice
4242
end
4343

44+
def tools=(tools)
45+
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) && !tools.nil?
46+
47+
@tools =
48+
tools.map do |tool|
49+
if tool.is_a?(Hash)
50+
ToolDefinition.from_hash(tool)
51+
elsif tool.is_a?(ToolDefinition)
52+
tool
53+
else
54+
raise ArgumentError, "tool must be a hash or a ToolDefinition was #{tool.class}"
55+
end
56+
end
57+
end
58+
4459
# this new api tries to create symmetry between responses and prompts
4560
# this means anything we get back from the model via endpoint can be easily appended
4661
def push_model_response(response)

lib/completions/tool_definition.rb

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# frozen_string_literal: true
2+
3+
module DiscourseAi
4+
module Completions
5+
class ToolDefinition
6+
class ParameterDefinition
7+
ALLOWED_TYPES = %i[string boolean integer array number].freeze
8+
ALLOWED_KEYS = %i[name description type required enum item_type].freeze
9+
10+
attr_reader :name, :description, :type, :required, :enum, :item_type
11+
12+
def self.from_hash(hash)
13+
extra_keys = hash.keys - ALLOWED_KEYS
14+
if !extra_keys.empty?
15+
raise ArgumentError, "Unexpected keys in parameter definition: #{extra_keys}"
16+
end
17+
18+
new(
19+
name: hash[:name],
20+
description: hash[:description],
21+
type: hash[:type],
22+
required: hash[:required],
23+
enum: hash[:enum],
24+
item_type: hash[:item_type],
25+
)
26+
end
27+
28+
def initialize(name:, description:, type:, required: false, enum: nil, item_type: nil)
29+
raise ArgumentError, "name must be a string" if !name.is_a?(String) || name.empty?
30+
31+
if !description.is_a?(String) || description.empty?
32+
raise ArgumentError, "description must be a string"
33+
end
34+
35+
type_sym = type.to_sym
36+
37+
if !ALLOWED_TYPES.include?(type_sym)
38+
raise ArgumentError, "type must be one of: #{ALLOWED_TYPES.join(", ")}"
39+
end
40+
41+
# Validate enum if provided
42+
if enum
43+
raise ArgumentError, "enum must be an array" if !enum.is_a?(Array)
44+
45+
# Validate enum entries match the specified type
46+
enum.each do |value|
47+
case type_sym
48+
when :string
49+
if !value.is_a?(String)
50+
raise ArgumentError, "enum values must be strings for type 'string'"
51+
end
52+
when :boolean
53+
if ![true, false].include?(value)
54+
raise ArgumentError, "enum values must be booleans for type 'boolean'"
55+
end
56+
when :integer
57+
if !value.is_a?(Integer)
58+
raise ArgumentError, "enum values must be integers for type 'integer'"
59+
end
60+
when :number
61+
if !value.is_a?(Numeric)
62+
raise ArgumentError, "enum values must be numbers for type 'number'"
63+
end
64+
when :array
65+
if !value.is_a?(Array)
66+
raise ArgumentError, "enum values must be arrays for type 'array'"
67+
end
68+
end
69+
end
70+
end
71+
72+
if item_type && type_sym != :array
73+
raise ArgumentError, "item_type can only be specified for array type"
74+
end
75+
76+
if item_type
77+
if !ALLOWED_TYPES.include?(item_type.to_sym)
78+
raise ArgumentError, "item type must be one of: #{ALLOWED_TYPES.join(", ")}"
79+
end
80+
end
81+
82+
@name = name
83+
@description = description
84+
@type = type_sym
85+
@required = !!required
86+
@enum = enum
87+
@item_type = item_type ? item_type.to_sym : nil
88+
end
89+
90+
def to_h
91+
result = { name: @name, description: @description, type: @type, required: @required }
92+
result[:enum] = @enum if @enum
93+
result[:item_type] = @item_type if @item_type
94+
result
95+
end
96+
end
97+
98+
attr_reader :name, :description, :parameters
99+
100+
def self.from_hash(hash)
101+
allowed_keys = %i[name description parameters]
102+
extra_keys = hash.keys - allowed_keys
103+
if !extra_keys.empty?
104+
raise ArgumentError, "Unexpected keys in tool definition: #{extra_keys}"
105+
end
106+
107+
params = hash[:parameters] || []
108+
parameter_objects =
109+
params.map do |param|
110+
if param.is_a?(Hash)
111+
ParameterDefinition.from_hash(param)
112+
else
113+
param
114+
end
115+
end
116+
117+
new(name: hash[:name], description: hash[:description], parameters: parameter_objects)
118+
end
119+
120+
def initialize(name:, description:, parameters: [])
121+
raise ArgumentError, "name must be a string" if !name.is_a?(String) || name.empty?
122+
123+
if !description.is_a?(String) || description.empty?
124+
raise ArgumentError, "description must be a string"
125+
end
126+
127+
raise ArgumentError, "parameters must be an array" if !parameters.is_a?(Array)
128+
129+
# Check for duplicated parameter names
130+
param_names = parameters.map { |p| p.name }
131+
duplicates = param_names.select { |param_name| param_names.count(param_name) > 1 }.uniq
132+
if !duplicates.empty?
133+
raise ArgumentError, "Duplicate parameter names found: #{duplicates.join(", ")}"
134+
end
135+
136+
@name = name
137+
@description = description
138+
@parameters = parameters
139+
end
140+
141+
def to_h
142+
{ name: @name, description: @description, parameters: @parameters.map(&:to_h) }
143+
end
144+
145+
def coerce_parameters(params)
146+
result = {}
147+
148+
return result if !params.is_a?(Hash)
149+
150+
@parameters.each do |param_def|
151+
param_name = param_def.name
152+
153+
# Skip if parameter is not provided and not required
154+
next if !params.key?(param_name) && !param_def.required
155+
156+
# Handle required but missing parameters
157+
if !params.key?(param_name) && param_def.required
158+
result[param_name] = nil
159+
next
160+
end
161+
162+
value = params[param_name]
163+
164+
# For array type, handle item coercion
165+
if param_def.type == :array
166+
result[param_name] = coerce_array_value(value, param_def.item_type)
167+
else
168+
result[param_name] = coerce_single_value(value, param_def.type)
169+
end
170+
end
171+
172+
result
173+
end
174+
175+
private
176+
177+
def coerce_array_value(value, item_type)
178+
# Handle non-array input by attempting to parse JSON strings
179+
if !value.is_a?(Array)
180+
if value.is_a?(String)
181+
begin
182+
parsed = JSON.parse(value)
183+
value = parsed.is_a?(Array) ? parsed : nil
184+
rescue JSON::ParserError
185+
return nil
186+
end
187+
else
188+
return nil
189+
end
190+
end
191+
192+
# No item type specified, return the array as is
193+
return value if !item_type
194+
195+
# Coerce each item in the array
196+
value.map { |item| coerce_single_value(item, item_type) }
197+
end
198+
199+
def coerce_single_value(value, type)
200+
result = nil
201+
202+
case type
203+
when :string
204+
result = value.to_s
205+
when :integer
206+
if value.is_a?(Integer)
207+
result = value
208+
elsif value.is_a?(Float)
209+
result = value.to_i
210+
elsif value.is_a?(String) && value.match?(/\A-?\d+\z/)
211+
result = value.to_i
212+
end
213+
when :number
214+
if value.is_a?(Numeric)
215+
result = value.to_f
216+
elsif value.is_a?(String) && value.match?(/\A-?\d+(\.\d+)?\z/)
217+
result = value.to_f
218+
end
219+
when :boolean
220+
if value == true || value == false
221+
result = value
222+
elsif value.is_a?(String)
223+
if value.downcase == "true"
224+
result = true
225+
elsif value.downcase == "false"
226+
result = false
227+
end
228+
end
229+
end
230+
231+
result
232+
end
233+
end
234+
end
235+
end

0 commit comments

Comments
 (0)