Skip to content

Commit ff2e18f

Browse files
authored
FIX: Structured output discrepancies. (#1340)
This change fixes two bugs and adds a safeguard. The first issue was that the schema Gemini expected differed from the one we sent, resulting in 400 errors when performing completions. The second issue was that creating a new persona didn't define a `response_format` method; this method has to be explicitly defined when we wrap it inside the Persona class. There was also a mismatch between the default value and what we stored in the DB: some parts of the code expected symbols as keys and others expected strings. Finally, we add a safeguard for the case where the model, even when asked to, refuses to reply with valid JSON. In that case, we make a best-effort attempt to recover and stream the raw response.
1 parent 1b3fdad commit ff2e18f

File tree

13 files changed

+74
-40
lines changed

13 files changed

+74
-40
lines changed

app/models/ai_persona.rb

+1
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def class_instance
266266
define_method(:top_p) { @ai_persona&.top_p }
267267
define_method(:system_prompt) { @ai_persona&.system_prompt || "You are a helpful bot." }
268268
define_method(:uploads) { @ai_persona&.uploads }
269+
define_method(:response_format) { @ai_persona&.response_format }
269270
define_method(:examples) { @ai_persona&.examples }
270271
end
271272
end

lib/completions/endpoints/gemini.rb

+7-3
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,13 @@ def prepare_payload(prompt, model_params, dialect)
8787
if model_params.present?
8888
payload[:generationConfig].merge!(model_params.except(:response_format))
8989

90-
if model_params[:response_format].present?
91-
# https://ai.google.dev/api/generate-content#generationconfig
92-
payload[:generationConfig][:responseSchema] = model_params[:response_format]
90+
# https://ai.google.dev/api/generate-content#generationconfig
91+
gemini_schema = model_params[:response_format].dig(:json_schema, :schema)
92+
93+
if gemini_schema.present?
94+
payload[:generationConfig][:responseSchema] = gemini_schema.except(
95+
:additionalProperties,
96+
)
9397
payload[:generationConfig][:responseMimeType] = "application/json"
9498
end
9599
end

lib/completions/json_streaming_tracker.rb

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def initialize(stream_consumer)
2424
end
2525
end
2626

27+
def broken?
28+
@broken
29+
end
30+
2731
def <<(json)
2832
# llm could send broken json
2933
# in that case just deal with it later

lib/completions/structured_output.rb

+24-15
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,40 @@ def initialize(json_schema_properties)
1313

1414
@tracked = {}
1515

16+
@raw_response = +""
17+
@raw_cursor = 0
18+
1619
@partial_json_tracker = JsonStreamingTracker.new(self)
1720
end
1821

1922
attr_reader :last_chunk_buffer
2023

2124
def <<(raw)
25+
@raw_response << raw
2226
@partial_json_tracker << raw
2327
end
2428

25-
def read_latest_buffered_chunk
26-
@property_names.reduce({}) do |memo, pn|
27-
if @tracked[pn].present?
28-
# This means this property is a string and we want to return unread chunks.
29-
if @property_cursors[pn].present?
30-
unread = @tracked[pn][@property_cursors[pn]..]
31-
32-
memo[pn] = unread if unread.present?
33-
@property_cursors[pn] = @tracked[pn].length
34-
else
35-
# Ints and bools are always returned as is.
36-
memo[pn] = @tracked[pn]
37-
end
38-
end
29+
def read_buffered_property(prop_name)
30+
# Safeguard: If the model is misbehaving and generating something that's not a JSON,
31+
# treat response as a normal string.
32+
# This is a best-effort to recover from an unexpected scenario.
33+
if @partial_json_tracker.broken?
34+
unread_chunk = @raw_response[@raw_cursor..]
35+
@raw_cursor = @raw_response.length
36+
return unread_chunk
37+
end
3938

40-
memo
39+
# Maybe we haven't read that part of the JSON yet.
40+
return nil if @tracked[prop_name].blank?
41+
42+
# This means this property is a string and we want to return unread chunks.
43+
if @property_cursors[prop_name].present?
44+
unread = @tracked[prop_name][@property_cursors[prop_name]..]
45+
@property_cursors[prop_name] = @tracked[prop_name].length
46+
unread
47+
else
48+
# Ints and bools are always returned as is.
49+
@tracked[prop_name]
4150
end
4251
end
4352

lib/personas/bot.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def build_json_schema(response_format)
316316
response_format
317317
.to_a
318318
.reduce({}) do |memo, format|
319-
memo[format[:key].to_sym] = { type: format[:type] }
319+
memo[format["key"].to_sym] = { type: format["type"] }
320320
memo
321321
end
322322

lib/personas/short_summarizer.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def system_prompt
3333
end
3434

3535
def response_format
36-
[{ key: "summary", type: "string" }]
36+
[{ "key" => "summary", "type" => "string" }]
3737
end
3838
end
3939
end

lib/personas/summarizer.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def system_prompt
3434
end
3535

3636
def response_format
37-
[{ key: "summary", type: "string" }]
37+
[{ "key" => "summary", "type" => "string" }]
3838
end
3939

4040
def examples

lib/summarization/fold_content.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def fold(items, user, &on_partial_blk)
116116
if type == :structured_output
117117
json_summary_schema_key = bot.persona.response_format&.first.to_h
118118
partial_summary =
119-
partial.read_latest_buffered_chunk[json_summary_schema_key[:key].to_sym]
119+
partial.read_buffered_property(json_summary_schema_key["key"]&.to_sym)
120120

121121
if partial_summary.present?
122122
summary << partial_summary

spec/lib/completions/endpoints/anthropic_spec.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@
845845
response_format: schema,
846846
) { |partial, cancel| structured_output = partial }
847847

848-
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
848+
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
849849

850850
expected_body = {
851851
model: "claude-3-opus-20240229",

spec/lib/completions/endpoints/aws_bedrock_spec.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def encode_message(message)
607607
}
608608
expect(JSON.parse(request.body)).to eq(expected)
609609

610-
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
610+
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
611611
end
612612
end
613613
end

spec/lib/completions/endpoints/cohere_spec.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,6 @@
366366
)
367367
expect(parsed_body[:message]).to eq("user1: thanks")
368368

369-
expect(structured_output.read_latest_buffered_chunk).to eq({ key: "Hello!" })
369+
expect(structured_output.read_buffered_property(:key)).to eq("Hello!")
370370
end
371371
end

spec/lib/completions/endpoints/gemini_spec.rb

+4-2
Original file line numberDiff line numberDiff line change
@@ -565,12 +565,14 @@ def tool_response
565565
structured_response = partial
566566
end
567567

568-
expect(structured_response.read_latest_buffered_chunk).to eq({ key: "Hello!" })
568+
expect(structured_response.read_buffered_property(:key)).to eq("Hello!")
569569

570570
parsed = JSON.parse(req_body, symbolize_names: true)
571571

572572
# Verify that schema is passed following Gemini API specs.
573-
expect(parsed.dig(:generationConfig, :responseSchema)).to eq(schema)
573+
expect(parsed.dig(:generationConfig, :responseSchema)).to eq(
574+
schema.dig(:json_schema, :schema).except(:additionalProperties),
575+
)
574576
expect(parsed.dig(:generationConfig, :responseMimeType)).to eq("application/json")
575577
end
576578
end

spec/lib/completions/structured_output_spec.rb

+27-13
Original file line numberDiff line numberDiff line change
@@ -34,36 +34,50 @@
3434
]
3535

3636
structured_output << chunks[0]
37-
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 1\n" })
37+
expect(structured_output.read_buffered_property(:message)).to eq("Line 1\n")
3838

3939
structured_output << chunks[1]
40-
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 2\n" })
40+
expect(structured_output.read_buffered_property(:message)).to eq("Line 2\n")
4141

4242
structured_output << chunks[2]
43-
expect(structured_output.read_latest_buffered_chunk).to eq({ message: "Line 3" })
43+
expect(structured_output.read_buffered_property(:message)).to eq("Line 3")
4444

4545
structured_output << chunks[3]
46-
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
46+
expect(structured_output.read_buffered_property(:bool)).to eq(true)
4747

4848
# Waiting for number to be fully buffered.
4949
structured_output << chunks[4]
50-
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true })
50+
expect(structured_output.read_buffered_property(:bool)).to eq(true)
51+
expect(structured_output.read_buffered_property(:number)).to be_nil
5152

5253
structured_output << chunks[5]
53-
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
54+
expect(structured_output.read_buffered_property(:number)).to eq(42)
5455

5556
structured_output << chunks[6]
56-
expect(structured_output.read_latest_buffered_chunk).to eq(
57-
{ bool: true, number: 42, status: "o" },
58-
)
57+
expect(structured_output.read_buffered_property(:number)).to eq(42)
58+
expect(structured_output.read_buffered_property(:bool)).to eq(true)
59+
expect(structured_output.read_buffered_property(:status)).to eq("o")
5960

6061
structured_output << chunks[7]
61-
expect(structured_output.read_latest_buffered_chunk).to eq(
62-
{ bool: true, number: 42, status: "\"k\"" },
63-
)
62+
expect(structured_output.read_buffered_property(:status)).to eq("\"k\"")
6463

6564
# No partial string left to read.
66-
expect(structured_output.read_latest_buffered_chunk).to eq({ bool: true, number: 42 })
65+
expect(structured_output.read_buffered_property(:status)).to eq("")
66+
end
67+
end
68+
69+
describe "dealing with non-JSON responses" do
70+
it "treat it as plain text once we determined it's invalid JSON" do
71+
chunks = [+"I'm not", +"a", +"JSON :)"]
72+
73+
structured_output << chunks[0]
74+
expect(structured_output.read_buffered_property(nil)).to eq("I'm not")
75+
76+
structured_output << chunks[1]
77+
expect(structured_output.read_buffered_property(nil)).to eq("a")
78+
79+
structured_output << chunks[2]
80+
expect(structured_output.read_buffered_property(nil)).to eq("JSON :)")
6781
end
6882
end
6983
end

0 commit comments

Comments
 (0)