Skip to content

Commit 3e44a2f

Browse files
authored
[Vertex AI] Add responseMIMEType to GenerationConfig (#12918)
1 parent ae5b57f commit 3e44a2f

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

FirebaseVertexAI/Sources/GenerationConfig.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ public struct GenerationConfig {
6363
/// The stop sequence will not be included as part of the response.
6464
public let stopSequences: [String]?
6565

66+
/// Output response MIME type of the generated candidate text.
67+
///
68+
/// Supported MIME types:
69+
/// - `text/plain`: Text output; the default behavior if unspecified.
70+
/// - `application/json`: JSON response in the candidates.
71+
public let responseMIMEType: String?
72+
6673
/// Creates a new `GenerationConfig` value.
6774
///
6875
/// - Parameter temperature: See ``temperature``
@@ -73,7 +80,7 @@ public struct GenerationConfig {
7380
/// - Parameter stopSequences: See ``stopSequences``
7481
public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil,
7582
candidateCount: Int? = nil, maxOutputTokens: Int? = nil,
76-
stopSequences: [String]? = nil) {
83+
stopSequences: [String]? = nil, responseMIMEType: String? = nil) {
7784
// Explicit init because otherwise if we re-arrange the above variables it changes the API
7885
// surface.
7986
self.temperature = temperature
@@ -82,6 +89,7 @@ public struct GenerationConfig {
8289
self.candidateCount = candidateCount
8390
self.maxOutputTokens = maxOutputTokens
8491
self.stopSequences = stopSequences
92+
self.responseMIMEType = responseMIMEType
8593
}
8694
}
8795

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseVertexAI
16+
import Foundation
17+
import XCTest
18+
19+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
20+
final class GenerationConfigTests: XCTestCase {
21+
let encoder = JSONEncoder()
22+
23+
override func setUp() {
24+
encoder.outputFormatting = .init(
25+
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
26+
)
27+
}
28+
29+
// MARK: GenerationConfig Encoding
30+
31+
func testEncodeGenerationConfig_default() throws {
32+
let generationConfig = GenerationConfig()
33+
34+
let jsonData = try encoder.encode(generationConfig)
35+
36+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
37+
XCTAssertEqual(json, """
38+
{
39+
40+
}
41+
""")
42+
}
43+
44+
func testEncodeGenerationConfig_allOptions() throws {
45+
let temperature: Float = 0.5
46+
let topP: Float = 0.75
47+
let topK = 40
48+
let candidateCount = 2
49+
let maxOutputTokens = 256
50+
let stopSequences = ["END", "DONE"]
51+
let responseMIMEType = "text/plain"
52+
let generationConfig = GenerationConfig(
53+
temperature: temperature,
54+
topP: topP,
55+
topK: topK,
56+
candidateCount: candidateCount,
57+
maxOutputTokens: maxOutputTokens,
58+
stopSequences: stopSequences,
59+
responseMIMEType: responseMIMEType
60+
)
61+
62+
let jsonData = try encoder.encode(generationConfig)
63+
64+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
65+
XCTAssertEqual(json, """
66+
{
67+
"candidateCount" : \(candidateCount),
68+
"maxOutputTokens" : \(maxOutputTokens),
69+
"responseMIMEType" : "\(responseMIMEType)",
70+
"stopSequences" : [
71+
"END",
72+
"DONE"
73+
],
74+
"temperature" : \(temperature),
75+
"topK" : \(topK),
76+
"topP" : \(topP)
77+
}
78+
""")
79+
}
80+
81+
func testEncodeGenerationConfig_responseMIMEType() throws {
82+
let mimeType = "image/jpeg"
83+
let generationConfig = GenerationConfig(responseMIMEType: mimeType)
84+
85+
let jsonData = try encoder.encode(generationConfig)
86+
87+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
88+
XCTAssertEqual(json, """
89+
{
90+
"responseMIMEType" : "\(mimeType)"
91+
}
92+
""")
93+
}
94+
}

FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ final class VertexAIAPITests: XCTestCase {
3030
topK: 16,
3131
candidateCount: 4,
3232
maxOutputTokens: 256,
33-
stopSequences: ["..."])
33+
stopSequences: ["..."],
34+
responseMIMEType: "text/plain")
3435
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
3536
let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")])
3637

0 commit comments

Comments
 (0)