|
| 1 | +// Copyright 2023 Google LLC |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +import FirebaseVertexAI |
| 16 | +import Foundation |
| 17 | +import UIKit |
| 18 | + |
| 19 | +@MainActor |
| 20 | +class FunctionCallingViewModel: ObservableObject { |
| 21 | + /// This array holds both the user's and the system's chat messages |
| 22 | + @Published var messages = [ChatMessage]() |
| 23 | + |
| 24 | + /// Indicates we're waiting for the model to finish |
| 25 | + @Published var busy = false |
| 26 | + |
| 27 | + @Published var error: Error? |
| 28 | + var hasError: Bool { |
| 29 | + return error != nil |
| 30 | + } |
| 31 | + |
| 32 | + /// Function calls pending processing |
| 33 | + private var functionCalls = [FunctionCall]() |
| 34 | + |
| 35 | + private var model: GenerativeModel |
| 36 | + private var chat: Chat |
| 37 | + |
| 38 | + private var chatTask: Task<Void, Never>? |
| 39 | + |
| 40 | + init() { |
| 41 | + model = VertexAI.vertexAI().generativeModel( |
| 42 | + modelName: "gemini-1.0-pro", |
| 43 | + tools: [Tool(functionDeclarations: [ |
| 44 | + FunctionDeclaration( |
| 45 | + name: "get_exchange_rate", |
| 46 | + description: "Get the exchange rate for currencies between countries", |
| 47 | + parameters: [ |
| 48 | + "currency_from": Schema( |
| 49 | + type: .string, |
| 50 | + format: "enum", |
| 51 | + description: "The currency to convert from in ISO 4217 format", |
| 52 | + enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] |
| 53 | + ), |
| 54 | + "currency_to": Schema( |
| 55 | + type: .string, |
| 56 | + format: "enum", |
| 57 | + description: "The currency to convert to in ISO 4217 format", |
| 58 | + enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] |
| 59 | + ), |
| 60 | + ], |
| 61 | + requiredParameters: ["currency_from", "currency_to"] |
| 62 | + ), |
| 63 | + ])] |
| 64 | + ) |
| 65 | + chat = model.startChat() |
| 66 | + } |
| 67 | + |
| 68 | + func sendMessage(_ text: String, streaming: Bool = true) async { |
| 69 | + error = nil |
| 70 | + chatTask?.cancel() |
| 71 | + |
| 72 | + chatTask = Task { |
| 73 | + busy = true |
| 74 | + defer { |
| 75 | + busy = false |
| 76 | + } |
| 77 | + |
| 78 | + // first, add the user's message to the chat |
| 79 | + let userMessage = ChatMessage(message: text, participant: .user) |
| 80 | + messages.append(userMessage) |
| 81 | + |
| 82 | + // add a pending message while we're waiting for a response from the backend |
| 83 | + let systemMessage = ChatMessage.pending(participant: .system) |
| 84 | + messages.append(systemMessage) |
| 85 | + |
| 86 | + print(messages) |
| 87 | + do { |
| 88 | + repeat { |
| 89 | + if streaming { |
| 90 | + try await internalSendMessageStreaming(text) |
| 91 | + } else { |
| 92 | + try await internalSendMessage(text) |
| 93 | + } |
| 94 | + } while !functionCalls.isEmpty |
| 95 | + } catch { |
| 96 | + self.error = error |
| 97 | + print(error.localizedDescription) |
| 98 | + messages.removeLast() |
| 99 | + } |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + func startNewChat() { |
| 104 | + stop() |
| 105 | + error = nil |
| 106 | + chat = model.startChat() |
| 107 | + messages.removeAll() |
| 108 | + } |
| 109 | + |
| 110 | + func stop() { |
| 111 | + chatTask?.cancel() |
| 112 | + error = nil |
| 113 | + } |
| 114 | + |
| 115 | + private func internalSendMessageStreaming(_ text: String) async throws { |
| 116 | + let functionResponses = try await processFunctionCalls() |
| 117 | + let responseStream: AsyncThrowingStream<GenerateContentResponse, Error> |
| 118 | + if functionResponses.isEmpty { |
| 119 | + responseStream = chat.sendMessageStream(text) |
| 120 | + } else { |
| 121 | + for functionResponse in functionResponses { |
| 122 | + messages.insert(functionResponse.chatMessage(), at: messages.count - 1) |
| 123 | + } |
| 124 | + responseStream = chat.sendMessageStream(functionResponses.modelContent()) |
| 125 | + } |
| 126 | + for try await chunk in responseStream { |
| 127 | + processResponseContent(content: chunk) |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + private func internalSendMessage(_ text: String) async throws { |
| 132 | + let functionResponses = try await processFunctionCalls() |
| 133 | + let response: GenerateContentResponse |
| 134 | + if functionResponses.isEmpty { |
| 135 | + response = try await chat.sendMessage(text) |
| 136 | + } else { |
| 137 | + for functionResponse in functionResponses { |
| 138 | + messages.insert(functionResponse.chatMessage(), at: messages.count - 1) |
| 139 | + } |
| 140 | + response = try await chat.sendMessage(functionResponses.modelContent()) |
| 141 | + } |
| 142 | + processResponseContent(content: response) |
| 143 | + } |
| 144 | + |
| 145 | + func processResponseContent(content: GenerateContentResponse) { |
| 146 | + guard let candidate = content.candidates.first else { |
| 147 | + fatalError("No candidate.") |
| 148 | + } |
| 149 | + |
| 150 | + for part in candidate.content.parts { |
| 151 | + switch part { |
| 152 | + case let .text(text): |
| 153 | + // replace pending message with backend response |
| 154 | + messages[messages.count - 1].message += text |
| 155 | + messages[messages.count - 1].pending = false |
| 156 | + case let .functionCall(functionCall): |
| 157 | + messages.insert(functionCall.chatMessage(), at: messages.count - 1) |
| 158 | + functionCalls.append(functionCall) |
| 159 | + case .data, .functionResponse: |
| 160 | + fatalError("Unsupported response content.") |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + func processFunctionCalls() async throws -> [FunctionResponse] { |
| 166 | + var functionResponses = [FunctionResponse]() |
| 167 | + for functionCall in functionCalls { |
| 168 | + switch functionCall.name { |
| 169 | + case "get_exchange_rate": |
| 170 | + let exchangeRates = getExchangeRate(args: functionCall.args) |
| 171 | + functionResponses.append(FunctionResponse( |
| 172 | + name: "get_exchange_rate", |
| 173 | + response: exchangeRates |
| 174 | + )) |
| 175 | + default: |
| 176 | + fatalError("Unknown function named \"\(functionCall.name)\".") |
| 177 | + } |
| 178 | + } |
| 179 | + functionCalls = [] |
| 180 | + |
| 181 | + return functionResponses |
| 182 | + } |
| 183 | + |
| 184 | + // MARK: - Callable Functions |
| 185 | + |
| 186 | + func getExchangeRate(args: JSONObject) -> JSONObject { |
| 187 | + // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) |
| 188 | + guard case let .string(from) = args["currency_from"] else { |
| 189 | + fatalError("Missing `currency_from` parameter.") |
| 190 | + } |
| 191 | + guard case let .string(to) = args["currency_to"] else { |
| 192 | + fatalError("Missing `currency_to` parameter.") |
| 193 | + } |
| 194 | + |
| 195 | + // 2. Get the exchange rate |
| 196 | + let allRates: [String: [String: Double]] = [ |
| 197 | + "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379], |
| 198 | + "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362], |
| 199 | + "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932], |
| 200 | + "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836], |
| 201 | + "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679], |
| 202 | + "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26], |
| 203 | + ] |
| 204 | + guard let fromRates = allRates[from] else { |
| 205 | + return ["error": .string("No data for currency \(from).")] |
| 206 | + } |
| 207 | + guard let toRate = fromRates[to] else { |
| 208 | + return ["error": .string("No data for currency \(to).")] |
| 209 | + } |
| 210 | + |
| 211 | + // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) |
| 212 | + return ["rates": .number(toRate)] |
| 213 | + } |
| 214 | +} |
| 215 | + |
| 216 | +private extension FunctionCall { |
| 217 | + func chatMessage() -> ChatMessage { |
| 218 | + let encoder = JSONEncoder() |
| 219 | + encoder.outputFormatting = .prettyPrinted |
| 220 | + |
| 221 | + let jsonData: Data |
| 222 | + do { |
| 223 | + jsonData = try encoder.encode(self) |
| 224 | + } catch { |
| 225 | + fatalError("JSON Encoding Failed: \(error.localizedDescription)") |
| 226 | + } |
| 227 | + guard let json = String(data: jsonData, encoding: .utf8) else { |
| 228 | + fatalError("Failed to convert JSON data to a String.") |
| 229 | + } |
| 230 | + let messageText = "Function call requested by model:\n```\n\(json)\n```" |
| 231 | + |
| 232 | + return ChatMessage(message: messageText, participant: .system) |
| 233 | + } |
| 234 | +} |
| 235 | + |
| 236 | +private extension FunctionResponse { |
| 237 | + func chatMessage() -> ChatMessage { |
| 238 | + let encoder = JSONEncoder() |
| 239 | + encoder.outputFormatting = .prettyPrinted |
| 240 | + |
| 241 | + let jsonData: Data |
| 242 | + do { |
| 243 | + jsonData = try encoder.encode(self) |
| 244 | + } catch { |
| 245 | + fatalError("JSON Encoding Failed: \(error.localizedDescription)") |
| 246 | + } |
| 247 | + guard let json = String(data: jsonData, encoding: .utf8) else { |
| 248 | + fatalError("Failed to convert JSON data to a String.") |
| 249 | + } |
| 250 | + let messageText = "Function response returned by app:\n```\n\(json)\n```" |
| 251 | + |
| 252 | + return ChatMessage(message: messageText, participant: .user) |
| 253 | + } |
| 254 | +} |
| 255 | + |
| 256 | +private extension [FunctionResponse] { |
| 257 | + func modelContent() -> [ModelContent] { |
| 258 | + return self.map { ModelContent( |
| 259 | + role: "function", |
| 260 | + parts: [ModelContent.Part.functionResponse($0)] |
| 261 | + ) |
| 262 | + } |
| 263 | + } |
| 264 | +} |
0 commit comments