Skip to content

Commit 2063447

Browse files
authored
Add function calling sample (#127)
1 parent 0976d0f commit 2063447

File tree

4 files changed

+434
-0
lines changed

4 files changed

+434
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import GenerativeAIUIComponents
16+
import GoogleGenerativeAI
17+
import SwiftUI
18+
19+
struct FunctionCallingScreen: View {
20+
@EnvironmentObject
21+
var viewModel: FunctionCallingViewModel
22+
23+
@State
24+
private var userPrompt = "What is 100 Euros in U.S. Dollars?"
25+
26+
enum FocusedField: Hashable {
27+
case message
28+
}
29+
30+
@FocusState
31+
var focusedField: FocusedField?
32+
33+
var body: some View {
34+
VStack {
35+
ScrollViewReader { scrollViewProxy in
36+
List {
37+
Text("Interact with a currency conversion API using function calling in Gemini.")
38+
ForEach(viewModel.messages) { message in
39+
MessageView(message: message)
40+
}
41+
if let error = viewModel.error {
42+
ErrorView(error: error)
43+
.tag("errorView")
44+
}
45+
}
46+
.listStyle(.plain)
47+
.onChange(of: viewModel.messages, perform: { newValue in
48+
if viewModel.hasError {
49+
// Wait for a short moment to make sure we can actually scroll to the bottom.
50+
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
51+
withAnimation {
52+
scrollViewProxy.scrollTo("errorView", anchor: .bottom)
53+
}
54+
focusedField = .message
55+
}
56+
} else {
57+
guard let lastMessage = viewModel.messages.last else { return }
58+
59+
// Wait for a short moment to make sure we can actually scroll to the bottom.
60+
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
61+
withAnimation {
62+
scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom)
63+
}
64+
focusedField = .message
65+
}
66+
}
67+
})
68+
}
69+
InputField("Message...", text: $userPrompt) {
70+
Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill")
71+
.font(.title)
72+
}
73+
.focused($focusedField, equals: .message)
74+
.onSubmit { sendOrStop() }
75+
}
76+
.toolbar {
77+
ToolbarItem(placement: .primaryAction) {
78+
Button(action: newChat) {
79+
Image(systemName: "square.and.pencil")
80+
}
81+
}
82+
}
83+
.navigationTitle("Function Calling")
84+
.onAppear {
85+
focusedField = .message
86+
}
87+
}
88+
89+
private func sendMessage() {
90+
Task {
91+
let prompt = userPrompt
92+
userPrompt = ""
93+
await viewModel.sendMessage(prompt, streaming: true)
94+
}
95+
}
96+
97+
private func sendOrStop() {
98+
if viewModel.busy {
99+
viewModel.stop()
100+
} else {
101+
sendMessage()
102+
}
103+
}
104+
105+
private func newChat() {
106+
viewModel.startNewChat()
107+
}
108+
}
109+
110+
struct FunctionCallingScreen_Previews: PreviewProvider {
111+
struct ContainerView: View {
112+
@EnvironmentObject
113+
var viewModel: FunctionCallingViewModel
114+
115+
var body: some View {
116+
FunctionCallingScreen()
117+
.onAppear {
118+
viewModel.messages = ChatMessage.samples
119+
}
120+
}
121+
}
122+
123+
static var previews: some View {
124+
NavigationStack {
125+
FunctionCallingScreen().environmentObject(FunctionCallingViewModel())
126+
}
127+
}
128+
}
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import GoogleGenerativeAI
17+
import UIKit
18+
19+
@MainActor
20+
class FunctionCallingViewModel: ObservableObject {
21+
/// This array holds both the user's and the system's chat messages
22+
@Published var messages = [ChatMessage]()
23+
24+
/// Indicates we're waiting for the model to finish
25+
@Published var busy = false
26+
27+
@Published var error: Error?
28+
var hasError: Bool {
29+
return error != nil
30+
}
31+
32+
/// Function calls pending processing
33+
private var functionCalls = [FunctionCall]()
34+
35+
private var model: GenerativeModel
36+
private var chat: Chat
37+
38+
private var chatTask: Task<Void, Never>?
39+
40+
init() {
41+
model = GenerativeModel(
42+
name: "gemini-1.0-pro",
43+
apiKey: APIKey.default,
44+
tools: [Tool(functionDeclarations: [
45+
FunctionDeclaration(
46+
name: "get_exchange_rate",
47+
description: "Get the exchange rate for currencies between countries",
48+
parameters: [
49+
"currency_from": Schema(
50+
type: .string,
51+
format: "enum",
52+
description: "The currency to convert from in ISO 4217 format",
53+
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
54+
),
55+
"currency_to": Schema(
56+
type: .string,
57+
format: "enum",
58+
description: "The currency to convert to in ISO 4217 format",
59+
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
60+
),
61+
],
62+
requiredParameters: ["currency_from", "currency_to"]
63+
),
64+
])],
65+
requestOptions: RequestOptions(apiVersion: "v1beta")
66+
)
67+
chat = model.startChat()
68+
}
69+
70+
func sendMessage(_ text: String, streaming: Bool = true) async {
71+
error = nil
72+
chatTask?.cancel()
73+
74+
chatTask = Task {
75+
busy = true
76+
defer {
77+
busy = false
78+
}
79+
80+
// first, add the user's message to the chat
81+
let userMessage = ChatMessage(message: text, participant: .user)
82+
messages.append(userMessage)
83+
84+
// add a pending message while we're waiting for a response from the backend
85+
let systemMessage = ChatMessage.pending(participant: .system)
86+
messages.append(systemMessage)
87+
88+
print(messages)
89+
do {
90+
repeat {
91+
if streaming {
92+
try await internalSendMessageStreaming(text)
93+
} else {
94+
try await internalSendMessage(text)
95+
}
96+
} while !functionCalls.isEmpty
97+
} catch {
98+
self.error = error
99+
print(error.localizedDescription)
100+
messages.removeLast()
101+
}
102+
}
103+
}
104+
105+
func startNewChat() {
106+
stop()
107+
error = nil
108+
chat = model.startChat()
109+
messages.removeAll()
110+
}
111+
112+
func stop() {
113+
chatTask?.cancel()
114+
error = nil
115+
}
116+
117+
private func internalSendMessageStreaming(_ text: String) async throws {
118+
let functionResponses = try await processFunctionCalls()
119+
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
120+
if functionResponses.isEmpty {
121+
responseStream = chat.sendMessageStream(text)
122+
} else {
123+
for functionResponse in functionResponses {
124+
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
125+
}
126+
responseStream = chat.sendMessageStream(functionResponses.modelContent())
127+
}
128+
for try await chunk in responseStream {
129+
processResponseContent(content: chunk)
130+
}
131+
}
132+
133+
private func internalSendMessage(_ text: String) async throws {
134+
let functionResponses = try await processFunctionCalls()
135+
let response: GenerateContentResponse
136+
if functionResponses.isEmpty {
137+
response = try await chat.sendMessage(text)
138+
} else {
139+
for functionResponse in functionResponses {
140+
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
141+
}
142+
response = try await chat.sendMessage(functionResponses.modelContent())
143+
}
144+
processResponseContent(content: response)
145+
}
146+
147+
func processResponseContent(content: GenerateContentResponse) {
148+
guard let candidate = content.candidates.first else {
149+
fatalError("No candidate.")
150+
}
151+
152+
for part in candidate.content.parts {
153+
switch part {
154+
case let .text(text):
155+
// replace pending message with backend response
156+
messages[messages.count - 1].message += text
157+
messages[messages.count - 1].pending = false
158+
case let .functionCall(functionCall):
159+
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
160+
functionCalls.append(functionCall)
161+
case .data, .functionResponse:
162+
fatalError("Unsupported response content.")
163+
}
164+
}
165+
}
166+
167+
func processFunctionCalls() async throws -> [FunctionResponse] {
168+
var functionResponses = [FunctionResponse]()
169+
for functionCall in functionCalls {
170+
switch functionCall.name {
171+
case "get_exchange_rate":
172+
let exchangeRates = getExchangeRate(args: functionCall.args)
173+
functionResponses.append(FunctionResponse(
174+
name: "get_exchange_rate",
175+
response: exchangeRates
176+
))
177+
default:
178+
fatalError("Unknown function named \"\(functionCall.name)\".")
179+
}
180+
}
181+
functionCalls = []
182+
183+
return functionResponses
184+
}
185+
186+
// MARK: - Callable Functions
187+
188+
func getExchangeRate(args: JSONObject) -> JSONObject {
189+
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
190+
guard case let .string(from) = args["currency_from"] else {
191+
fatalError("Missing `currency_from` parameter.")
192+
}
193+
guard case let .string(to) = args["currency_to"] else {
194+
fatalError("Missing `currency_to` parameter.")
195+
}
196+
197+
// 2. Get the exchange rate
198+
let allRates: [String: [String: Double]] = [
199+
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
200+
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
201+
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
202+
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
203+
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
204+
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
205+
]
206+
guard let fromRates = allRates[from] else {
207+
return ["error": .string("No data for currency \(from).")]
208+
}
209+
guard let toRate = fromRates[to] else {
210+
return ["error": .string("No data for currency \(to).")]
211+
}
212+
213+
// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
214+
return ["rates": .number(toRate)]
215+
}
216+
}
217+
218+
private extension FunctionCall {
219+
func chatMessage() -> ChatMessage {
220+
let encoder = JSONEncoder()
221+
encoder.outputFormatting = .prettyPrinted
222+
223+
let jsonData: Data
224+
do {
225+
jsonData = try encoder.encode(self)
226+
} catch {
227+
fatalError("JSON Encoding Failed: \(error.localizedDescription)")
228+
}
229+
guard let json = String(data: jsonData, encoding: .utf8) else {
230+
fatalError("Failed to convert JSON data to a String.")
231+
}
232+
let messageText = "Function call requested by model:\n```\n\(json)\n```"
233+
234+
return ChatMessage(message: messageText, participant: .system)
235+
}
236+
}
237+
238+
private extension FunctionResponse {
239+
func chatMessage() -> ChatMessage {
240+
let encoder = JSONEncoder()
241+
encoder.outputFormatting = .prettyPrinted
242+
243+
let jsonData: Data
244+
do {
245+
jsonData = try encoder.encode(self)
246+
} catch {
247+
fatalError("JSON Encoding Failed: \(error.localizedDescription)")
248+
}
249+
guard let json = String(data: jsonData, encoding: .utf8) else {
250+
fatalError("Failed to convert JSON data to a String.")
251+
}
252+
let messageText = "Function response returned by app:\n```\n\(json)\n```"
253+
254+
return ChatMessage(message: messageText, participant: .user)
255+
}
256+
}
257+
258+
private extension [FunctionResponse] {
259+
func modelContent() -> [ModelContent] {
260+
return self.map { ModelContent(
261+
role: "function",
262+
parts: [ModelContent.Part.functionResponse($0)]
263+
)
264+
}
265+
}
266+
}

0 commit comments

Comments
 (0)