Skip to content

Add function calling support #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 26, 2024
Merged
134 changes: 120 additions & 14 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ struct GenerateContent: AsyncParsableCommand {
help: "Enable additional debug logging."
) var debugLogEnabled = false

// Function calls pending processing
var functionCalls = [FunctionCall]()

// Input to the model
var input = [ModelContent]()

mutating func validate() throws {
if textPrompt == nil && imageURL == nil {
throw ValidationError(
Expand All @@ -70,7 +76,32 @@ struct GenerateContent: AsyncParsableCommand {
name: modelNameOrDefault(),
apiKey: apiKey,
generationConfig: config,
safetySettings: safetySettings
safetySettings: safetySettings,
tools: [Tool(functionDeclarations: [
FunctionDeclaration(
name: "get_exchange_rate",
description: "Get the exchange rate for currencies between countries",
parameters: Schema(
type: .object,
properties: [
"currency_from": Schema(
type: .string,
format: "enum",
description: "The currency to convert from in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
"currency_to": Schema(
type: .string,
format: "enum",
description: "The currency to convert to in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
],
required: ["currency_from", "currency_to"]
)
),
])],
requestOptions: RequestOptions(apiVersion: "v1beta")
)

var parts = [ModelContent.Part]()
Expand All @@ -93,27 +124,71 @@ struct GenerateContent: AsyncParsableCommand {
parts.append(.data(mimetype: mimeType, imageData))
}

let input = [ModelContent(parts: parts)]
input = [ModelContent(parts: parts)]

repeat {
try await processFunctionCalls()

if isStreaming {
let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print(text)
if isStreaming {
let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
processResponseContent(content: content)
}
} else {
// Unary generate content
let content = try await model.generateContent(input)
print("Generated Content:")
processResponseContent(content: content)
}
} else {
let content = try await model.generateContent(input)
if let text = content.text {
print("Generated Content:\n\(text)")
}
}
} while !functionCalls.isEmpty
} catch {
print("Generate Content Error: \(error)")
}
}

mutating func processResponseContent(content: GenerateContentResponse) {
guard let candidate = content.candidates.first else {
fatalError("No candidate.")
}

for part in candidate.content.parts {
switch part {
case let .text(text):
print(text)
case .data:
fatalError("Inline data not supported.")
case let .functionCall(functionCall):
functionCalls.append(functionCall)
case let .functionResponse(functionResponse):
print("FunctionResponse: \(functionResponse)")
}
}
}

mutating func processFunctionCalls() async throws {
for functionCall in functionCalls {
input.append(ModelContent(
role: "model",
parts: [ModelContent.Part.functionCall(functionCall)]
))
switch functionCall.name {
case "get_exchange_rate":
let exchangeRates = getExchangeRate(args: functionCall.args)
input.append(ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse(FunctionResponse(
name: "get_exchange_rate",
response: exchangeRates
))]
))
default:
fatalError("Unknown function named \"\(functionCall.name)\".")
}
}
functionCalls = []
}

func modelNameOrDefault() -> String {
if let modelName = modelName {
return modelName
Expand All @@ -123,6 +198,37 @@ struct GenerateContent: AsyncParsableCommand {
return "gemini-1.0-pro"
}
}

// MARK: - Callable Functions

func getExchangeRate(args: JSONObject) -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard case let .string(from) = args["currency_from"] else {
fatalError("Missing `currency_from` parameter.")
}
guard case let .string(to) = args["currency_to"] else {
fatalError("Missing `currency_to` parameter.")
}

// 2. Get the exchange rate
let allRates: [String: [String: Double]] = [
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
]
guard let fromRates = allRates[from] else {
return ["error": .string("No data for currency \(from).")]
}
guard let toRate = fromRates[to] else {
return ["error": .string("No data for currency \(to).")]
}

// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["rates": .number(toRate)]
}
}

enum CLIError: Error {
Expand Down
2 changes: 1 addition & 1 deletion Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public class Chat {
case let .text(str):
combinedText += str

case .data(mimetype: _, _):
case .data, .functionCall, .functionResponse:
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
Expand Down
151 changes: 151 additions & 0 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A predicted function call returned from the model.
///
/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall
public struct FunctionCall: Equatable, Encodable {
/// The name of the function to call.
public let name: String

/// The function parameters and values.
public let args: JSONObject
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema
public class Schema: Encodable {
let type: DataType

let format: String?

let description: String?

let nullable: Bool?

let enumValues: [String]?

let items: Schema?

let properties: [String: Schema]?

let required: [String]?

enum CodingKeys: String, CodingKey {
case type
case format
case description
case nullable
case enumValues = "enum"
case items
case properties
case required
}

public init(type: DataType, format: String? = nil, description: String? = nil,
nullable: Bool? = nil,
enumValues: [String]? = nil, items: Schema? = nil,
properties: [String: Schema]? = nil,
required: [String]? = nil) {
self.type = type
self.format = format
self.description = description
self.nullable = nullable
self.enumValues = enumValues
self.items = items
self.properties = properties
self.required = required
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type
public enum DataType: String, Encodable {
case string = "STRING"
case number = "NUMBER"
case integer = "INTEGER"
case boolean = "BOOLEAN"
case array = "ARRAY"
case object = "OBJECT"
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration
public struct FunctionDeclaration {
let name: String

let description: String

let parameters: Schema?

public init(name: String, description: String, parameters: Schema?) {
self.name = name
self.description = description
self.parameters = parameters
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool
public struct Tool: Encodable {
let functionDeclarations: [FunctionDeclaration]?

public init(functionDeclarations: [FunctionDeclaration]?) {
self.functionDeclarations = functionDeclarations
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse
public struct FunctionResponse: Equatable, Encodable {
let name: String

let response: JSONObject

public init(name: String, response: JSONObject) {
self.name = name
self.response = response
}
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
case args
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
name = try container.decode(String.self, forKey: .name)
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
self.args = args
} else {
args = JSONObject()
}
}
}

extension FunctionDeclaration: Encodable {
enum CodingKeys: String, CodingKey {
case name
case description
case parameters
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(name, forKey: .name)
try container.encode(description, forKey: .description)
try container.encode(parameters, forKey: .parameters)
}
}
2 changes: 2 additions & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct GenerateContentRequest {
let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable {
case contents
case generationConfig
case safetySettings
case tools
}
}

Expand Down
Loading