Skip to content

Commit 1ed073f

Browse files
committed
Automatic function calling prototype
1 parent 667eccf commit 1ed073f

File tree

6 files changed

+256
-10
lines changed

6 files changed

+256
-10
lines changed

Examples/GenerativeAICLI/Sources/GenerateContent.swift

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,20 @@ struct GenerateContent: AsyncParsableCommand {
7070
name: modelNameOrDefault(),
7171
apiKey: apiKey,
7272
generationConfig: config,
73-
safetySettings: safetySettings
73+
safetySettings: safetySettings,
74+
tools: [Tool(functionDeclarations: [
75+
FunctionDeclaration(
76+
name: "get_exchange_rate",
77+
description: "Get the exchange rate for currencies between countries",
78+
parameters: getExchangeRateSchema(),
79+
function: getExchangeRateWrapper
80+
),
81+
])],
82+
requestOptions: RequestOptions(apiVersion: "v1beta")
7483
)
7584

85+
let chat = model.startChat()
86+
7687
var parts = [ModelContent.Part]()
7788

7889
if let textPrompt = textPrompt {
@@ -96,15 +107,16 @@ struct GenerateContent: AsyncParsableCommand {
96107
let input = [ModelContent(parts: parts)]
97108

98109
if isStreaming {
99-
let contentStream = model.generateContentStream(input)
110+
let contentStream = chat.sendMessageStream(input)
100111
print("Generated Content <streaming>:")
101112
for try await content in contentStream {
102113
if let text = content.text {
103114
print(text)
104115
}
105116
}
106117
} else {
107-
let content = try await model.generateContent(input)
118+
// Unary generate content
119+
let content = try await chat.sendMessage(input)
108120
if let text = content.text {
109121
print("Generated Content:\n\(text)")
110122
}
@@ -123,6 +135,76 @@ struct GenerateContent: AsyncParsableCommand {
123135
return "gemini-1.0-pro"
124136
}
125137
}
138+
139+
// MARK: - Callable Functions
140+
141+
// Returns exchange rates from the Frankfurter API
142+
// This is an example function that a developer might provide.
143+
/// Fetches exchange-rate data from the Frankfurter API and returns the raw response body.
///
/// - Parameters:
///   - amount: Sent as the `amount` query item.
///   - date: Used as the URL path (e.g. a `YYYY-MM-DD` date or `latest`, per the schema
///     description given to the model).
///   - from: Sent as the `from` query item.
///   - to: Sent as the `to` query item.
/// - Returns: The response body decoded as UTF-8 text.
/// - Throws: `URLError(.badURL)` if the request URL cannot be built,
///   `URLError(.cannotDecodeContentData)` if the body is not valid UTF-8, or any error thrown
///   by `URLSession`.
func getExchangeRate(amount: Double, date: String, from: String,
                     to: String) async throws -> String {
  guard var urlComponents = URLComponents(string: "https://api.frankfurter.app") else {
    throw URLError(.badURL)
  }
  urlComponents.path = "/\(date)"
  urlComponents.queryItems = [
    .init(name: "amount", value: String(amount)),
    .init(name: "from", value: from),
    .init(name: "to", value: to),
  ]

  // `date` originates from the model, so never assume the components form a valid URL.
  guard let url = urlComponents.url else {
    throw URLError(.badURL)
  }

  let (data, _) = try await URLSession.shared.data(from: url)

  // Report an undecodable body as an error instead of crashing the CLI.
  guard let body = String(data: data, encoding: .utf8) else {
    throw URLError(.cannotDecodeContentData)
  }
  return body
}
156+
157+
// This is a wrapper for the `getExchangeRate` function.
158+
/// Adapts `getExchangeRate` to the `(JSONObject) async throws -> JSONObject` signature that
/// `FunctionDeclaration` expects.
///
/// - Parameter args: The arguments provided by the model in a `FunctionCall`. Must contain the
///   string values `currency_date`, `currency_from` and `currency_to` and the number `amount`.
/// - Returns: A JSON object whose `content` entry holds the response text.
/// - Throws: An error naming the offending argument when one is missing or has the wrong type,
///   or any error thrown by `getExchangeRate`.
func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject {
  // The arguments are generated by the model, so a bad value is a recoverable runtime
  // condition rather than a programmer error; throw instead of calling `fatalError()`.
  struct InvalidArgumentError: Error, CustomStringConvertible {
    let name: String

    var description: String {
      "Missing or invalid argument '\(name)' in call to get_exchange_rate."
    }
  }

  // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
  guard case let .string(date) = args["currency_date"] else {
    throw InvalidArgumentError(name: "currency_date")
  }
  guard case let .string(from) = args["currency_from"] else {
    throw InvalidArgumentError(name: "currency_from")
  }
  guard case let .string(to) = args["currency_to"] else {
    throw InvalidArgumentError(name: "currency_to")
  }
  guard case let .number(amount) = args["amount"] else {
    throw InvalidArgumentError(name: "amount")
  }

  // 2. Call the wrapped function
  let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to)

  // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
  return ["content": .string(response)]
}
179+
180+
// Returns the schema of the `getExchangeRate` function
181+
/// Builds the parameter schema advertised to the model for the `get_exchange_rate` function.
///
/// - Returns: An object schema with four required properties: `currency_date`,
///   `currency_from`, `currency_to` and `amount`.
func getExchangeRateSchema() -> Schema {
  let dateSchema = Schema(
    type: .string,
    description: """
    A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period
    is not specified
    """
  )
  let sourceCurrencySchema = Schema(
    type: .string,
    description: "The currency to convert from in ISO 4217 format"
  )
  let targetCurrencySchema = Schema(
    type: .string,
    description: "The currency to convert to in ISO 4217 format"
  )
  let amountSchema = Schema(
    type: .number,
    description: "The amount of currency to convert as a double value"
  )

  // Keys must match the argument names read by `getExchangeRateWrapper`.
  let parameterSchemas: [String: Schema] = [
    "currency_date": dateSchema,
    "currency_from": sourceCurrencySchema,
    "currency_to": targetCurrencySchema,
    "amount": amountSchema,
  ]

  return Schema(
    type: .object,
    properties: parameterSchemas,
    required: ["currency_date", "currency_from", "currency_to", "amount"]
  )
}
126208
}
127209

128210
enum CLIError: Error {

Sources/GoogleAI/Chat.swift

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,32 @@ public class Chat {
7070
// Make sure we inject the role into the content received.
7171
let toAdd = ModelContent(role: "model", parts: reply.parts)
7272

73+
var functionResponses = [FunctionResponse]()
74+
for part in reply.parts {
75+
if case let .functionCall(functionCall) = part {
76+
try functionResponses.append(await model.executeFunction(functionCall: functionCall))
77+
}
78+
}
79+
80+
// Call the functions requested by the model, if any.
81+
let functionResponseContent = try ModelContent(
82+
role: "function",
83+
functionResponses.map { functionResponse in
84+
ModelContent.Part.functionResponse(functionResponse)
85+
}
86+
)
87+
7388
// Append the request and successful result to history, then return the value.
7489
history.append(contentsOf: newContent)
7590
history.append(toAdd)
76-
return result
91+
92+
// If no function calls requested, return the results.
93+
if functionResponses.isEmpty {
94+
return result
95+
}
96+
97+
// Re-send the message with the function responses.
98+
return try await sendMessage([functionResponseContent])
7799
}
78100

79101
/// See ``sendMessageStream(_:)-4abs3``.
@@ -166,6 +188,10 @@ public class Chat {
166188
case .functionCall:
167189
// TODO(andrewheard): Add function call to the chat history when encoding is implemented.
168190
fatalError("Function calling not yet implemented in chat.")
191+
192+
case .functionResponse:
193+
// TODO(andrewheard): Add function response to chat history when encoding is implemented.
194+
fatalError("Function calling not yet implemented in chat.")
169195
}
170196
}
171197
}

Sources/GoogleAI/FunctionCalling.swift

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,97 @@
1515
import Foundation
1616

1717
/// A predicted function call returned from the model.
18-
public struct FunctionCall: Equatable {
18+
///
19+
/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall
20+
public struct FunctionCall: Equatable, Encodable {
1921
/// The name of the function to call.
20-
let name: String
22+
public let name: String
2123

2224
/// The function parameters and values.
23-
let args: JSONObject
25+
public let args: JSONObject
2426
}
2527

28+
// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema
29+
/// Describes the type of a function parameter sent to the backend.
///
/// Optional properties are omitted from the encoded payload when `nil`.
public class Schema: Encodable {
  /// The data type being described.
  let type: DataType

  /// The format of the data, if any.
  let format: String?

  /// A human-readable description of the value.
  let description: String?

  /// Whether the value may be null.
  let nullable: Bool?

  /// The permitted values; encoded under the key `enum`.
  let enumValues: [String]?

  /// The schema of the elements when `type` is `.array`.
  let items: Schema?

  /// The schemas of the named properties when `type` is `.object`.
  let properties: [String: Schema]?

  /// The names of the properties that must be present when `type` is `.object`.
  let required: [String]?

  /// Keys used in the encoded payload.
  ///
  /// `enum` is a reserved word in Swift, so the property is named `enumValues`; it must still be
  /// written as `enum`, which is the field name in the REST `Schema` resource.
  enum CodingKeys: String, CodingKey {
    case type
    case format
    case description
    case nullable
    case enumValues = "enum"
    case items
    case properties
    case required
  }

  /// Creates a schema; every argument except `type` defaults to `nil`.
  public init(type: DataType, format: String? = nil, description: String? = nil,
              nullable: Bool? = nil,
              enumValues: [String]? = nil, items: Schema? = nil,
              properties: [String: Schema]? = nil,
              required: [String]? = nil) {
    self.type = type
    self.format = format
    self.description = description
    self.nullable = nullable
    self.enumValues = enumValues
    self.items = items
    self.properties = properties
    self.required = required
  }
}
61+
62+
// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type
63+
/// The data types a `Schema` can describe.
///
/// Each case encodes as its upper-case raw string value.
public enum DataType: String, Encodable {
  /// Encoded as `"STRING"`.
  case string = "STRING"

  /// Encoded as `"NUMBER"`.
  case number = "NUMBER"

  /// Encoded as `"INTEGER"`.
  case integer = "INTEGER"

  /// Encoded as `"BOOLEAN"`.
  case boolean = "BOOLEAN"

  /// Encoded as `"ARRAY"`.
  case array = "ARRAY"

  /// Encoded as `"OBJECT"`.
  case object = "OBJECT"
}
71+
72+
// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration
73+
/// Declares a function that the model may ask the client to call.
public struct FunctionDeclaration {
  /// The function name; `GenerativeModel.executeFunction(functionCall:)` compares it with
  /// `FunctionCall.name` to locate the declaration.
  let name: String

  /// A description of the function.
  let description: String

  /// The schema of the function's parameters.
  let parameters: Schema

  /// The closure invoked with the model-provided arguments, or `nil` if none was supplied.
  /// It is not part of the encoded payload.
  let function: ((JSONObject) async throws -> JSONObject)?

  /// Creates a declaration from its name, description, parameter schema and optional closure.
  public init(
    name: String,
    description: String,
    parameters: Schema,
    function: ((JSONObject) async throws -> JSONObject)?
  ) {
    self.function = function
    self.parameters = parameters
    self.description = description
    self.name = name
  }
}
90+
91+
// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool
92+
/// A set of function declarations passed to `GenerativeModel` through its `tools` parameter.
public struct Tool: Encodable {
  /// The functions made available by this tool, if any.
  let functionDeclarations: [FunctionDeclaration]?

  /// Creates a tool wrapping the given function declarations.
  public init(functionDeclarations declarations: [FunctionDeclaration]?) {
    functionDeclarations = declarations
  }
}
99+
100+
// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse
101+
/// The result of running a function requested by the model.
public struct FunctionResponse: Equatable, Encodable {
  /// The name of the function that was called; `GenerativeModel.executeFunction(functionCall:)`
  /// copies it from the originating `FunctionCall`.
  let name: String

  /// The JSON object returned by the function.
  let response: JSONObject
}
106+
107+
// MARK: - Codable Conformance
108+
26109
extension FunctionCall: Decodable {
27110
enum CodingKeys: CodingKey {
28111
case name
@@ -39,3 +122,18 @@ extension FunctionCall: Decodable {
39122
}
40123
}
41124
}
125+
126+
extension FunctionDeclaration: Encodable {
  /// Keys written to the encoded payload. `function` has no key because the closure is never
  /// encoded.
  enum CodingKeys: String, CodingKey {
    case name, description, parameters
  }

  /// Encodes the name, description and parameter schema, in that order.
  public func encode(to encoder: Encoder) throws {
    var payload = encoder.container(keyedBy: CodingKeys.self)

    try payload.encode(name, forKey: .name)
    try payload.encode(description, forKey: .description)
    try payload.encode(parameters, forKey: .parameters)
  }
}

Sources/GoogleAI/GenerateContentRequest.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct GenerateContentRequest {
2121
let contents: [ModelContent]
2222
let generationConfig: GenerationConfig?
2323
let safetySettings: [SafetySetting]?
24+
let tools: [Tool]?
2425
let isStreaming: Bool
2526
let options: RequestOptions
2627
}
@@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable {
3132
case contents
3233
case generationConfig
3334
case safetySettings
35+
case tools
3436
}
3537
}
3638

Sources/GoogleAI/GenerativeModel.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ public final class GenerativeModel {
3636
/// The safety settings to be used for prompts.
3737
let safetySettings: [SafetySetting]?
3838

39+
let tools: [Tool]?
40+
3941
/// Configuration parameters for sending requests to the backend.
4042
let requestOptions: RequestOptions
4143

@@ -52,12 +54,14 @@ public final class GenerativeModel {
5254
apiKey: String,
5355
generationConfig: GenerationConfig? = nil,
5456
safetySettings: [SafetySetting]? = nil,
57+
tools: [Tool]? = nil,
5558
requestOptions: RequestOptions = RequestOptions()) {
5659
self.init(
5760
name: name,
5861
apiKey: apiKey,
5962
generationConfig: generationConfig,
6063
safetySettings: safetySettings,
64+
tools: tools,
6165
requestOptions: requestOptions,
6266
urlSession: .shared
6367
)
@@ -68,12 +72,14 @@ public final class GenerativeModel {
6872
apiKey: String,
6973
generationConfig: GenerationConfig? = nil,
7074
safetySettings: [SafetySetting]? = nil,
75+
tools: [Tool]? = nil,
7176
requestOptions: RequestOptions = RequestOptions(),
7277
urlSession: URLSession) {
7378
modelResourceName = GenerativeModel.modelResourceName(name: name)
7479
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
7580
self.generationConfig = generationConfig
7681
self.safetySettings = safetySettings
82+
self.tools = tools
7783
self.requestOptions = requestOptions
7884

7985
Logging.default.info("""
@@ -119,6 +125,7 @@ public final class GenerativeModel {
119125
contents: content(),
120126
generationConfig: generationConfig,
121127
safetySettings: safetySettings,
128+
tools: tools,
122129
isStreaming: false,
123130
options: requestOptions)
124131
response = try await generativeAIService.loadRequest(request: generateContentRequest)
@@ -190,6 +197,7 @@ public final class GenerativeModel {
190197
contents: evaluatedContent,
191198
generationConfig: generationConfig,
192199
safetySettings: safetySettings,
200+
tools: tools,
193201
isStreaming: true,
194202
options: requestOptions)
195203

@@ -270,6 +278,30 @@ public final class GenerativeModel {
270278
}
271279
}
272280

281+
/// Runs the locally registered function that the model requested.
///
/// The declarations of every configured tool are searched, in order, for the first one whose
/// name equals `functionCall.name`.
///
/// - Parameter functionCall: The function call returned by the model.
/// - Returns: A `FunctionResponse` carrying the function's name and its JSON result.
/// - Throws: `GenerateContentError.internalError` wrapping a `FunctionCallError` when no tools
///   are configured, no declaration matches the name, or the matching declaration has no
///   closure; otherwise any error thrown by the function itself.
func executeFunction(functionCall: FunctionCall) async throws -> FunctionResponse {
  // Flatten the declarations of all tools so a function declared in any tool can be found,
  // not only those in the first tool that happens to declare functions.
  let declarations = (tools ?? []).flatMap { $0.functionDeclarations ?? [] }

  guard let declaration = declarations.first(where: { $0.name == functionCall.name }),
        let function = declaration.function else {
    throw GenerateContentError.internalError(underlying: FunctionCallError())
  }

  return try FunctionResponse(
    name: functionCall.name,
    response: await function(functionCall.args)
  )
}
304+
273305
/// Returns a model resource name of the form "models/model-name" based on `name`.
274306
private static func modelResourceName(name: String) -> String {
275307
if name.contains("/") {
@@ -299,3 +331,5 @@ public final class GenerativeModel {
299331
public enum CountTokensError: Error {
300332
case internalError(underlying: Error)
301333
}
334+
335+
/// Error passed as the `underlying` value of `GenerateContentError.internalError` when
/// `GenerativeModel.executeFunction(functionCall:)` cannot resolve a callable function.
struct FunctionCallError: Error {}

0 commit comments

Comments
 (0)