Skip to content

Commit 2d70719

Browse files
authored
Update vertex-preview-0.1.0 for pre-release (#12751)
1 parent 9b288f9 commit 2d70719

File tree

12 files changed

+553
-3
lines changed

12 files changed

+553
-3
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseVertexAI
16+
import GenerativeAIUIComponents
17+
import SwiftUI
18+
19+
struct FunctionCallingScreen: View {
20+
@EnvironmentObject
21+
var viewModel: FunctionCallingViewModel
22+
23+
@State
24+
private var userPrompt = "What is 100 Euros in U.S. Dollars?"
25+
26+
enum FocusedField: Hashable {
27+
case message
28+
}
29+
30+
@FocusState
31+
var focusedField: FocusedField?
32+
33+
var body: some View {
34+
VStack {
35+
ScrollViewReader { scrollViewProxy in
36+
List {
37+
Text("Interact with a currency conversion API using function calling in Gemini.")
38+
ForEach(viewModel.messages) { message in
39+
MessageView(message: message)
40+
}
41+
if let error = viewModel.error {
42+
ErrorView(error: error)
43+
.tag("errorView")
44+
}
45+
}
46+
.listStyle(.plain)
47+
.onChange(of: viewModel.messages, perform: { newValue in
48+
if viewModel.hasError {
49+
// Wait for a short moment to make sure we can actually scroll to the bottom.
50+
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
51+
withAnimation {
52+
scrollViewProxy.scrollTo("errorView", anchor: .bottom)
53+
}
54+
focusedField = .message
55+
}
56+
} else {
57+
guard let lastMessage = viewModel.messages.last else { return }
58+
59+
// Wait for a short moment to make sure we can actually scroll to the bottom.
60+
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
61+
withAnimation {
62+
scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom)
63+
}
64+
focusedField = .message
65+
}
66+
}
67+
})
68+
}
69+
InputField("Message...", text: $userPrompt) {
70+
Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill")
71+
.font(.title)
72+
}
73+
.focused($focusedField, equals: .message)
74+
.onSubmit { sendOrStop() }
75+
}
76+
.toolbar {
77+
ToolbarItem(placement: .primaryAction) {
78+
Button(action: newChat) {
79+
Image(systemName: "square.and.pencil")
80+
}
81+
}
82+
}
83+
.navigationTitle("Function Calling")
84+
.onAppear {
85+
focusedField = .message
86+
}
87+
}
88+
89+
private func sendMessage() {
90+
Task {
91+
let prompt = userPrompt
92+
userPrompt = ""
93+
await viewModel.sendMessage(prompt, streaming: true)
94+
}
95+
}
96+
97+
private func sendOrStop() {
98+
if viewModel.busy {
99+
viewModel.stop()
100+
} else {
101+
sendMessage()
102+
}
103+
}
104+
105+
private func newChat() {
106+
viewModel.startNewChat()
107+
}
108+
}
109+
110+
struct FunctionCallingScreen_Previews: PreviewProvider {
111+
struct ContainerView: View {
112+
@EnvironmentObject
113+
var viewModel: FunctionCallingViewModel
114+
115+
var body: some View {
116+
FunctionCallingScreen()
117+
.onAppear {
118+
viewModel.messages = ChatMessage.samples
119+
}
120+
}
121+
}
122+
123+
static var previews: some View {
124+
NavigationStack {
125+
FunctionCallingScreen().environmentObject(FunctionCallingViewModel())
126+
}
127+
}
128+
}
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseVertexAI
16+
import Foundation
17+
import UIKit
18+
19+
@MainActor
20+
class FunctionCallingViewModel: ObservableObject {
21+
/// This array holds both the user's and the system's chat messages
22+
@Published var messages = [ChatMessage]()
23+
24+
/// Indicates we're waiting for the model to finish
25+
@Published var busy = false
26+
27+
@Published var error: Error?
28+
var hasError: Bool {
29+
return error != nil
30+
}
31+
32+
/// Function calls pending processing
33+
private var functionCalls = [FunctionCall]()
34+
35+
private var model: GenerativeModel
36+
private var chat: Chat
37+
38+
private var chatTask: Task<Void, Never>?
39+
40+
init() {
41+
model = VertexAI.vertexAI().generativeModel(
42+
modelName: "gemini-1.0-pro",
43+
tools: [Tool(functionDeclarations: [
44+
FunctionDeclaration(
45+
name: "get_exchange_rate",
46+
description: "Get the exchange rate for currencies between countries",
47+
parameters: [
48+
"currency_from": Schema(
49+
type: .string,
50+
format: "enum",
51+
description: "The currency to convert from in ISO 4217 format",
52+
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
53+
),
54+
"currency_to": Schema(
55+
type: .string,
56+
format: "enum",
57+
description: "The currency to convert to in ISO 4217 format",
58+
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
59+
),
60+
],
61+
requiredParameters: ["currency_from", "currency_to"]
62+
),
63+
])]
64+
)
65+
chat = model.startChat()
66+
}
67+
68+
func sendMessage(_ text: String, streaming: Bool = true) async {
69+
error = nil
70+
chatTask?.cancel()
71+
72+
chatTask = Task {
73+
busy = true
74+
defer {
75+
busy = false
76+
}
77+
78+
// first, add the user's message to the chat
79+
let userMessage = ChatMessage(message: text, participant: .user)
80+
messages.append(userMessage)
81+
82+
// add a pending message while we're waiting for a response from the backend
83+
let systemMessage = ChatMessage.pending(participant: .system)
84+
messages.append(systemMessage)
85+
86+
print(messages)
87+
do {
88+
repeat {
89+
if streaming {
90+
try await internalSendMessageStreaming(text)
91+
} else {
92+
try await internalSendMessage(text)
93+
}
94+
} while !functionCalls.isEmpty
95+
} catch {
96+
self.error = error
97+
print(error.localizedDescription)
98+
messages.removeLast()
99+
}
100+
}
101+
}
102+
103+
func startNewChat() {
104+
stop()
105+
error = nil
106+
chat = model.startChat()
107+
messages.removeAll()
108+
}
109+
110+
func stop() {
111+
chatTask?.cancel()
112+
error = nil
113+
}
114+
115+
private func internalSendMessageStreaming(_ text: String) async throws {
116+
let functionResponses = try await processFunctionCalls()
117+
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
118+
if functionResponses.isEmpty {
119+
responseStream = chat.sendMessageStream(text)
120+
} else {
121+
for functionResponse in functionResponses {
122+
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
123+
}
124+
responseStream = chat.sendMessageStream(functionResponses.modelContent())
125+
}
126+
for try await chunk in responseStream {
127+
processResponseContent(content: chunk)
128+
}
129+
}
130+
131+
private func internalSendMessage(_ text: String) async throws {
132+
let functionResponses = try await processFunctionCalls()
133+
let response: GenerateContentResponse
134+
if functionResponses.isEmpty {
135+
response = try await chat.sendMessage(text)
136+
} else {
137+
for functionResponse in functionResponses {
138+
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
139+
}
140+
response = try await chat.sendMessage(functionResponses.modelContent())
141+
}
142+
processResponseContent(content: response)
143+
}
144+
145+
func processResponseContent(content: GenerateContentResponse) {
146+
guard let candidate = content.candidates.first else {
147+
fatalError("No candidate.")
148+
}
149+
150+
for part in candidate.content.parts {
151+
switch part {
152+
case let .text(text):
153+
// replace pending message with backend response
154+
messages[messages.count - 1].message += text
155+
messages[messages.count - 1].pending = false
156+
case let .functionCall(functionCall):
157+
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
158+
functionCalls.append(functionCall)
159+
case .data, .functionResponse:
160+
fatalError("Unsupported response content.")
161+
}
162+
}
163+
}
164+
165+
func processFunctionCalls() async throws -> [FunctionResponse] {
166+
var functionResponses = [FunctionResponse]()
167+
for functionCall in functionCalls {
168+
switch functionCall.name {
169+
case "get_exchange_rate":
170+
let exchangeRates = getExchangeRate(args: functionCall.args)
171+
functionResponses.append(FunctionResponse(
172+
name: "get_exchange_rate",
173+
response: exchangeRates
174+
))
175+
default:
176+
fatalError("Unknown function named \"\(functionCall.name)\".")
177+
}
178+
}
179+
functionCalls = []
180+
181+
return functionResponses
182+
}
183+
184+
// MARK: - Callable Functions
185+
186+
func getExchangeRate(args: JSONObject) -> JSONObject {
187+
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
188+
guard case let .string(from) = args["currency_from"] else {
189+
fatalError("Missing `currency_from` parameter.")
190+
}
191+
guard case let .string(to) = args["currency_to"] else {
192+
fatalError("Missing `currency_to` parameter.")
193+
}
194+
195+
// 2. Get the exchange rate
196+
let allRates: [String: [String: Double]] = [
197+
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
198+
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
199+
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
200+
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
201+
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
202+
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
203+
]
204+
guard let fromRates = allRates[from] else {
205+
return ["error": .string("No data for currency \(from).")]
206+
}
207+
guard let toRate = fromRates[to] else {
208+
return ["error": .string("No data for currency \(to).")]
209+
}
210+
211+
// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
212+
return ["rates": .number(toRate)]
213+
}
214+
}
215+
216+
private extension FunctionCall {
217+
func chatMessage() -> ChatMessage {
218+
let encoder = JSONEncoder()
219+
encoder.outputFormatting = .prettyPrinted
220+
221+
let jsonData: Data
222+
do {
223+
jsonData = try encoder.encode(self)
224+
} catch {
225+
fatalError("JSON Encoding Failed: \(error.localizedDescription)")
226+
}
227+
guard let json = String(data: jsonData, encoding: .utf8) else {
228+
fatalError("Failed to convert JSON data to a String.")
229+
}
230+
let messageText = "Function call requested by model:\n```\n\(json)\n```"
231+
232+
return ChatMessage(message: messageText, participant: .system)
233+
}
234+
}
235+
236+
private extension FunctionResponse {
237+
func chatMessage() -> ChatMessage {
238+
let encoder = JSONEncoder()
239+
encoder.outputFormatting = .prettyPrinted
240+
241+
let jsonData: Data
242+
do {
243+
jsonData = try encoder.encode(self)
244+
} catch {
245+
fatalError("JSON Encoding Failed: \(error.localizedDescription)")
246+
}
247+
guard let json = String(data: jsonData, encoding: .utf8) else {
248+
fatalError("Failed to convert JSON data to a String.")
249+
}
250+
let messageText = "Function response returned by app:\n```\n\(json)\n```"
251+
252+
return ChatMessage(message: messageText, participant: .user)
253+
}
254+
}
255+
256+
private extension [FunctionResponse] {
257+
func modelContent() -> [ModelContent] {
258+
return self.map { ModelContent(
259+
role: "function",
260+
parts: [ModelContent.Part.functionResponse($0)]
261+
)
262+
}
263+
}
264+
}

0 commit comments

Comments
 (0)