Skip to content

Commit 667eccf

Browse files
committed
Add FunctionCall decoding (google-gemini#114)
1 parent 45bc200 commit 667eccf

File tree

7 files changed

+182
-1
lines changed

7 files changed

+182
-1
lines changed

Sources/GoogleAI/Chat.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ public class Chat {
162162
}
163163

164164
parts.append(part)
165+
166+
case .functionCall:
167+
// TODO(andrewheard): Add function call to the chat history when encoding is implemented.
168+
fatalError("Function calling not yet implemented in chat.")
165169
}
166170
}
167171
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// A predicted function call returned from the model.
18+
public struct FunctionCall: Equatable {
19+
/// The name of the function to call.
20+
let name: String
21+
22+
/// The function parameters and values.
23+
let args: JSONObject
24+
}
25+
26+
extension FunctionCall: Decodable {
27+
enum CodingKeys: CodingKey {
28+
case name
29+
case args
30+
}
31+
32+
public init(from decoder: Decoder) throws {
33+
let container = try decoder.container(keyedBy: CodingKeys.self)
34+
name = try container.decode(String.self, forKey: .name)
35+
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
36+
self.args = args
37+
} else {
38+
args = JSONObject()
39+
}
40+
}
41+
}

Sources/GoogleAI/ModelContent.swift

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public struct ModelContent: Codable, Equatable {
2525
enum CodingKeys: String, CodingKey {
2626
case text
2727
case inlineData
28+
case functionCall
2829
}
2930

3031
enum InlineDataKeys: String, CodingKey {
@@ -38,6 +39,9 @@ public struct ModelContent: Codable, Equatable {
3839
/// Data with a specified media type. Not all media types may be supported by the AI model.
3940
case data(mimetype: String, Data)
4041

42+
/// A predicted function call returned from the model.
43+
case functionCall(FunctionCall)
44+
4145
// MARK: Convenience Initializers
4246

4347
/// Convenience function for populating a Part with JPEG data.
@@ -64,6 +68,9 @@ public struct ModelContent: Codable, Equatable {
6468
)
6569
try inlineDataContainer.encode(mimetype, forKey: .mimeType)
6670
try inlineDataContainer.encode(bytes, forKey: .bytes)
71+
case .functionCall:
72+
// TODO(andrewheard): Encode FunctionCalls when when encoding is implemented.
73+
fatalError("FunctionCall encoding not implemented.")
6774
}
6875
}
6976

@@ -79,10 +86,12 @@ public struct ModelContent: Codable, Equatable {
7986
let mimetype = try dataContainer.decode(String.self, forKey: .mimeType)
8087
let bytes = try dataContainer.decode(Data.self, forKey: .bytes)
8188
self = .data(mimetype: mimetype, bytes)
89+
} else if values.contains(.functionCall) {
90+
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
8291
} else {
8392
throw DecodingError.dataCorrupted(.init(
8493
codingPath: [CodingKeys.text, CodingKeys.inlineData],
85-
debugDescription: "Neither text or inline data was found."
94+
debugDescription: "No text, inline data or function call was found."
8695
))
8796
}
8897
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"candidates": [
3+
{
4+
"content": {
5+
"parts": [
6+
{
7+
"functionCall": {
8+
"name": "current_time"
9+
}
10+
}
11+
],
12+
"role": "model"
13+
},
14+
"finishReason": "STOP",
15+
"index": 0
16+
}
17+
]
18+
}
19+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"candidates": [
3+
{
4+
"content": {
5+
"parts": [
6+
{
7+
"functionCall": {
8+
"name": "current_time",
9+
"args": {}
10+
}
11+
}
12+
],
13+
"role": "model"
14+
},
15+
"finishReason": "STOP",
16+
"index": 0
17+
}
18+
]
19+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"candidates": [
3+
{
4+
"content": {
5+
"parts": [
6+
{
7+
"functionCall": {
8+
"name": "sum",
9+
"args": {
10+
"y": 5,
11+
"x": 4
12+
}
13+
}
14+
}
15+
],
16+
"role": "model"
17+
},
18+
"finishReason": "STOP",
19+
"index": 0
20+
}
21+
]
22+
}

Tests/GoogleAITests/GenerativeModelTests.swift

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,73 @@ final class GenerativeModelTests: XCTestCase {
169169
_ = try await model.generateContent(testPrompt)
170170
}
171171

172+
func testGenerateContent_success_functionCall_emptyArguments() async throws {
173+
MockURLProtocol
174+
.requestHandler = try httpRequestHandler(
175+
forResource: "unary-success-function-call-empty-arguments",
176+
withExtension: "json"
177+
)
178+
179+
let response = try await model.generateContent(testPrompt)
180+
181+
XCTAssertEqual(response.candidates.count, 1)
182+
let candidate = try XCTUnwrap(response.candidates.first)
183+
XCTAssertEqual(candidate.content.parts.count, 1)
184+
let part = try XCTUnwrap(candidate.content.parts.first)
185+
guard case let .functionCall(functionCall) = part else {
186+
XCTFail("Part is not a FunctionCall.")
187+
return
188+
}
189+
XCTAssertEqual(functionCall.name, "current_time")
190+
XCTAssertTrue(functionCall.args.isEmpty)
191+
}
192+
193+
func testGenerateContent_success_functionCall_noArguments() async throws {
194+
MockURLProtocol
195+
.requestHandler = try httpRequestHandler(
196+
forResource: "unary-success-function-call-no-arguments",
197+
withExtension: "json"
198+
)
199+
200+
let response = try await model.generateContent(testPrompt)
201+
202+
XCTAssertEqual(response.candidates.count, 1)
203+
let candidate = try XCTUnwrap(response.candidates.first)
204+
XCTAssertEqual(candidate.content.parts.count, 1)
205+
let part = try XCTUnwrap(candidate.content.parts.first)
206+
guard case let .functionCall(functionCall) = part else {
207+
XCTFail("Part is not a FunctionCall.")
208+
return
209+
}
210+
XCTAssertEqual(functionCall.name, "current_time")
211+
XCTAssertTrue(functionCall.args.isEmpty)
212+
}
213+
214+
func testGenerateContent_success_functionCall_withArguments() async throws {
215+
MockURLProtocol
216+
.requestHandler = try httpRequestHandler(
217+
forResource: "unary-success-function-call-with-arguments",
218+
withExtension: "json"
219+
)
220+
221+
let response = try await model.generateContent(testPrompt)
222+
223+
XCTAssertEqual(response.candidates.count, 1)
224+
let candidate = try XCTUnwrap(response.candidates.first)
225+
XCTAssertEqual(candidate.content.parts.count, 1)
226+
let part = try XCTUnwrap(candidate.content.parts.first)
227+
guard case let .functionCall(functionCall) = part else {
228+
XCTFail("Part is not a FunctionCall.")
229+
return
230+
}
231+
XCTAssertEqual(functionCall.name, "sum")
232+
XCTAssertEqual(functionCall.args.count, 2)
233+
let argX = try XCTUnwrap(functionCall.args["x"])
234+
XCTAssertEqual(argX, .number(4))
235+
let argY = try XCTUnwrap(functionCall.args["y"])
236+
XCTAssertEqual(argY, .number(5))
237+
}
238+
172239
func testGenerateContent_failure_invalidAPIKey() async throws {
173240
let expectedStatusCode = 400
174241
MockURLProtocol

0 commit comments

Comments
 (0)