Skip to content

Commit 2bf9fe5

Browse files
authored
Add code execution support (#196)
1 parent 672fdb7 commit 2bf9fe5

11 files changed

+644
-12
lines changed

Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class FunctionCallingViewModel: ObservableObject {
157157
case let .functionCall(functionCall):
158158
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
159159
functionCalls.append(functionCall)
160-
case .data, .fileData, .functionResponse:
160+
case .data, .fileData, .functionResponse, .executableCode, .codeExecutionResult:
161161
fatalError("Unsupported response content.")
162162
}
163163
}

Sources/GoogleAI/Chat.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ public class Chat {
160160
case let .text(str):
161161
combinedText += str
162162

163-
case .data, .fileData, .functionCall, .functionResponse:
163+
case .data, .fileData, .functionCall, .functionResponse, .executableCode,
164+
.codeExecutionResult:
164165
// Don't combine it, just add to the content. If there's any text pending, add that as
165166
// a part.
166167
if !combinedText.isEmpty {

Sources/GoogleAI/FunctionCalling.swift

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ public struct Tool {
161161
/// A list of `FunctionDeclarations` available to the model.
162162
let functionDeclarations: [FunctionDeclaration]?
163163

164+
/// Enables the model to execute code as part of generation.
165+
let codeExecution: CodeExecution?
166+
164167
/// Constructs a new `Tool`.
165168
///
166169
/// - Parameters:
@@ -172,8 +175,11 @@ public struct Tool {
172175
/// populating ``FunctionCall`` in the response. The next conversation turn may contain a
173176
/// ``FunctionResponse`` in ``ModelContent/Part/functionResponse(_:)`` with the
174177
/// ``ModelContent/role`` "function", providing generation context for the next model turn.
175-
public init(functionDeclarations: [FunctionDeclaration]?) {
178+
/// - codeExecution: Enables the model to execute code as part of generation, if provided.
179+
public init(functionDeclarations: [FunctionDeclaration]? = nil,
180+
codeExecution: CodeExecution? = nil) {
176181
self.functionDeclarations = functionDeclarations
182+
self.codeExecution = codeExecution
177183
}
178184
}
179185

@@ -244,6 +250,55 @@ public struct FunctionResponse: Equatable {
244250
}
245251
}
246252

253+
/// Tool that executes code generated by the model, automatically returning the result to the model.
254+
///
255+
/// This type has no fields. See ``ExecutableCode`` and ``CodeExecutionResult``, which are only
256+
/// generated when using this tool.
257+
public struct CodeExecution {
258+
/// Constructs a new `CodeExecution` tool.
259+
public init() {}
260+
}
261+
262+
/// Code generated by the model that is meant to be executed, and the result returned to the model.
263+
///
264+
/// Only generated when using the ``CodeExecution`` tool, in which case the code will automatically
265+
/// be executed, and a corresponding ``CodeExecutionResult`` will also be generated.
266+
public struct ExecutableCode: Equatable {
267+
/// The programming language of the ``code``.
268+
public let language: String
269+
270+
/// The code to be executed.
271+
public let code: String
272+
}
273+
274+
/// Result of executing the ``ExecutableCode``.
275+
///
276+
/// Only generated when using the ``CodeExecution`` tool, and always follows a part containing the
277+
/// ``ExecutableCode``.
278+
public struct CodeExecutionResult: Equatable {
279+
/// Possible outcomes of the code execution.
280+
public enum Outcome: String {
281+
/// An unrecognized code execution outcome was provided.
282+
case unknown = "OUTCOME_UNKNOWN"
283+
/// Unspecified status; this value should not be used.
284+
case unspecified = "OUTCOME_UNSPECIFIED"
285+
/// Code execution completed successfully.
286+
case ok = "OUTCOME_OK"
287+
/// Code execution finished but with a failure; ``CodeExecutionResult/output`` should contain
288+
/// the failure details from `stderr`.
289+
case failed = "OUTCOME_FAILED"
290+
/// Code execution ran for too long, and was cancelled. There may or may not be a partial
291+
/// ``CodeExecutionResult/output`` present.
292+
case deadlineExceeded = "OUTCOME_DEADLINE_EXCEEDED"
293+
}
294+
295+
/// Outcome of the code execution.
296+
public let outcome: Outcome
297+
298+
/// Contains `stdout` when code execution is successful, `stderr` or other description otherwise.
299+
public let output: String
300+
}
301+
247302
// MARK: - Codable Conformance
248303

249304
extension FunctionCall: Decodable {
@@ -293,3 +348,31 @@ extension FunctionCallingConfig.Mode: Encodable {}
293348
extension ToolConfig: Encodable {}
294349

295350
extension FunctionResponse: Encodable {}
351+
352+
extension CodeExecution: Encodable {}
353+
354+
extension ExecutableCode: Codable {}
355+
356+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
357+
extension CodeExecutionResult.Outcome: Codable {
358+
public init(from decoder: any Decoder) throws {
359+
let value = try decoder.singleValueContainer().decode(String.self)
360+
guard let decodedOutcome = CodeExecutionResult.Outcome(rawValue: value) else {
361+
Logging.default
362+
.error("[GoogleGenerativeAI] Unrecognized Outcome with value \"\(value)\".")
363+
self = .unknown
364+
return
365+
}
366+
367+
self = decodedOutcome
368+
}
369+
}
370+
371+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
372+
extension CodeExecutionResult: Codable {
373+
public init(from decoder: any Decoder) throws {
374+
let container = try decoder.container(keyedBy: CodingKeys.self)
375+
outcome = try container.decode(Outcome.self, forKey: .outcome)
376+
output = try container.decodeIfPresent(String.self, forKey: .output) ?? ""
377+
}
378+
}

Sources/GoogleAI/GenerateContentResponse.swift

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,31 @@ public struct GenerateContentResponse {
4646
return nil
4747
}
4848
let textValues: [String] = candidate.content.parts.compactMap { part in
49-
guard case let .text(text) = part else {
49+
switch part {
50+
case let .text(text):
51+
return text
52+
case let .executableCode(executableCode):
53+
let codeBlockLanguage: String
54+
if executableCode.language == "LANGUAGE_UNSPECIFIED" {
55+
codeBlockLanguage = ""
56+
} else {
57+
codeBlockLanguage = executableCode.language.lowercased()
58+
}
59+
return "```\(codeBlockLanguage)\n\(executableCode.code)\n```"
60+
case let .codeExecutionResult(codeExecutionResult):
61+
if codeExecutionResult.output.isEmpty {
62+
return nil
63+
}
64+
return "```\n\(codeExecutionResult.output)\n```"
65+
case .data, .fileData, .functionCall, .functionResponse:
5066
return nil
5167
}
52-
return text
5368
}
5469
guard textValues.count > 0 else {
5570
Logging.default.error("Could not get a text part from the first candidate.")
5671
return nil
5772
}
58-
return textValues.joined(separator: " ")
73+
return textValues.joined(separator: "\n")
5974
}
6075

6176
/// Returns function calls found in any `Part`s of the first candidate of the response, if any.

Sources/GoogleAI/ModelContent.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ public struct ModelContent: Equatable {
5151
/// A response to a function call.
5252
case functionResponse(FunctionResponse)
5353

54+
/// Code generated by the model that is meant to be executed.
55+
case executableCode(ExecutableCode)
56+
57+
/// Result of executing the ``ExecutableCode``.
58+
case codeExecutionResult(CodeExecutionResult)
59+
5460
// MARK: Convenience Initializers
5561

5662
/// Convenience function for populating a Part with JPEG data.
@@ -129,6 +135,8 @@ extension ModelContent.Part: Codable {
129135
case fileData
130136
case functionCall
131137
case functionResponse
138+
case executableCode
139+
case codeExecutionResult
132140
}
133141

134142
enum InlineDataKeys: String, CodingKey {
@@ -164,6 +172,10 @@ extension ModelContent.Part: Codable {
164172
try container.encode(functionCall, forKey: .functionCall)
165173
case let .functionResponse(functionResponse):
166174
try container.encode(functionResponse, forKey: .functionResponse)
175+
case let .executableCode(executableCode):
176+
try container.encode(executableCode, forKey: .executableCode)
177+
case let .codeExecutionResult(codeExecutionResult):
178+
try container.encode(codeExecutionResult, forKey: .codeExecutionResult)
167179
}
168180
}
169181

@@ -181,6 +193,13 @@ extension ModelContent.Part: Codable {
181193
self = .data(mimetype: mimetype, bytes)
182194
} else if values.contains(.functionCall) {
183195
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
196+
} else if values.contains(.executableCode) {
197+
self = try .executableCode(values.decode(ExecutableCode.self, forKey: .executableCode))
198+
} else if values.contains(.codeExecutionResult) {
199+
self = try .codeExecutionResult(values.decode(
200+
CodeExecutionResult.self,
201+
forKey: .codeExecutionResult
202+
))
184203
} else {
185204
throw DecodingError.dataCorrupted(.init(
186205
codingPath: [CodingKeys.text, CodingKeys.inlineData],
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import GoogleGenerativeAI
18+
19+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
20+
final class CodeExecutionTests: XCTestCase {
21+
let decoder = JSONDecoder()
22+
let encoder = JSONEncoder()
23+
24+
let languageKey = "language"
25+
let languageValue = "PYTHON"
26+
let codeKey = "code"
27+
let codeValue = "print('Hello, world!')"
28+
let outcomeKey = "outcome"
29+
let outcomeValue = "OUTCOME_OK"
30+
let outputKey = "output"
31+
let outputValue = "Hello, world!"
32+
33+
override func setUp() {
34+
encoder.outputFormatting = .init(
35+
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
36+
)
37+
}
38+
39+
func testEncodeCodeExecution() throws {
40+
let jsonData = try encoder.encode(CodeExecution())
41+
42+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
43+
XCTAssertEqual(json, """
44+
{
45+
46+
}
47+
""")
48+
}
49+
50+
func testDecodeExecutableCode() throws {
51+
let expectedExecutableCode = ExecutableCode(language: languageValue, code: codeValue)
52+
let json = """
53+
{
54+
"\(languageKey)": "\(languageValue)",
55+
"\(codeKey)": "\(codeValue)"
56+
}
57+
"""
58+
let jsonData = try XCTUnwrap(json.data(using: .utf8))
59+
60+
let executableCode = try XCTUnwrap(decoder.decode(ExecutableCode.self, from: jsonData))
61+
62+
XCTAssertEqual(executableCode, expectedExecutableCode)
63+
}
64+
65+
func testEncodeExecutableCode() throws {
66+
let executableCode = ExecutableCode(language: languageValue, code: codeValue)
67+
68+
let jsonData = try encoder.encode(executableCode)
69+
70+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
71+
XCTAssertEqual(json, """
72+
{
73+
"\(codeKey)" : "\(codeValue)",
74+
"\(languageKey)" : "\(languageValue)"
75+
}
76+
""")
77+
}
78+
79+
func testDecodeCodeExecutionResultOutcome_ok() throws {
80+
let expectedOutcome = CodeExecutionResult.Outcome.ok
81+
let json = "\"\(outcomeValue)\""
82+
let jsonData = try XCTUnwrap(json.data(using: .utf8))
83+
84+
let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData))
85+
86+
XCTAssertEqual(outcome, expectedOutcome)
87+
}
88+
89+
func testDecodeCodeExecutionResultOutcome_unknown() throws {
90+
let expectedOutcome = CodeExecutionResult.Outcome.unknown
91+
let json = "\"OUTCOME_NEW_VALUE\""
92+
let jsonData = try XCTUnwrap(json.data(using: .utf8))
93+
94+
let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData))
95+
96+
XCTAssertEqual(outcome, expectedOutcome)
97+
}
98+
99+
func testEncodeCodeExecutionResultOutcome() throws {
100+
let jsonData = try encoder.encode(CodeExecutionResult.Outcome.ok)
101+
102+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
103+
XCTAssertEqual(json, "\"\(outcomeValue)\"")
104+
}
105+
106+
func testDecodeCodeExecutionResult() throws {
107+
let expectedCodeExecutionResult = CodeExecutionResult(outcome: .ok, output: "Hello, world!")
108+
let json = """
109+
{
110+
"\(outcomeKey)": "\(outcomeValue)",
111+
"\(outputKey)": "\(outputValue)"
112+
}
113+
"""
114+
let jsonData = try XCTUnwrap(json.data(using: .utf8))
115+
116+
let codeExecutionResult = try XCTUnwrap(decoder.decode(
117+
CodeExecutionResult.self,
118+
from: jsonData
119+
))
120+
121+
XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult)
122+
}
123+
124+
func testDecodeCodeExecutionResult_missingOutput() throws {
125+
let expectedCodeExecutionResult = CodeExecutionResult(outcome: .deadlineExceeded, output: "")
126+
let json = """
127+
{
128+
"\(outcomeKey)": "OUTCOME_DEADLINE_EXCEEDED"
129+
}
130+
"""
131+
let jsonData = try XCTUnwrap(json.data(using: .utf8))
132+
133+
let codeExecutionResult = try XCTUnwrap(decoder.decode(
134+
CodeExecutionResult.self,
135+
from: jsonData
136+
))
137+
138+
XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult)
139+
}
140+
141+
func testEncodeCodeExecutionResult() throws {
142+
let codeExecutionResult = CodeExecutionResult(outcome: .ok, output: outputValue)
143+
144+
let jsonData = try encoder.encode(codeExecutionResult)
145+
146+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
147+
XCTAssertEqual(json, """
148+
{
149+
"\(outcomeKey)" : "\(outcomeValue)",
150+
"\(outputKey)" : "\(outputValue)"
151+
}
152+
""")
153+
}
154+
}

0 commit comments

Comments
 (0)