Skip to content

Commit f94bdba

Browse files
authored
Add usageMetadata to GenerateContentResponse (#159)
1 parent 16e68be commit f94bdba

File tree

4 files changed

+86
-27
lines changed

4 files changed

+86
-27
lines changed

Sources/GoogleAI/GenerateContentResponse.swift

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,28 @@ import Foundation
1717
/// The model's response to a generate content request.
1818
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1919
public struct GenerateContentResponse {
20+
/// Token usage metadata for processing the generate content request.
21+
public struct UsageMetadata {
22+
/// The number of tokens in the request prompt.
23+
public let promptTokenCount: Int
24+
25+
/// The total number of tokens across the generated response candidates.
26+
public let candidatesTokenCount: Int
27+
28+
/// The total number of tokens in both the request and response.
29+
public let totalTokenCount: Int
30+
}
31+
2032
/// A list of candidate response content, ordered from best to worst.
2133
public let candidates: [CandidateResponse]
2234

2335
/// A value containing the safety ratings for the response, or, if the request was blocked, a
2436
/// reason for blocking the request.
2537
public let promptFeedback: PromptFeedback?
2638

39+
/// Token usage metadata for processing the generate content request.
40+
public let usageMetadata: UsageMetadata?
41+
2742
/// The response's content as text, if it exists.
2843
public var text: String? {
2944
guard let candidate = candidates.first else {
@@ -51,9 +66,11 @@ public struct GenerateContentResponse {
5166
}
5267

5368
/// Initializer for SwiftUI previews or tests.
54-
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) {
69+
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
70+
usageMetadata: UsageMetadata? = nil) {
5571
self.candidates = candidates
5672
self.promptFeedback = promptFeedback
73+
self.usageMetadata = usageMetadata
5774
}
5875
}
5976

@@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable {
170187
enum CodingKeys: CodingKey {
171188
case candidates
172189
case promptFeedback
190+
case usageMetadata
173191
}
174192

175193
public init(from decoder: Decoder) throws {
@@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable {
194212
candidates = []
195213
}
196214
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
215+
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
216+
}
217+
}
218+
219+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
220+
extension GenerateContentResponse.UsageMetadata: Decodable {
221+
enum CodingKeys: CodingKey {
222+
case promptTokenCount
223+
case candidatesTokenCount
224+
case totalTokenCount
225+
}
226+
227+
public init(from decoder: any Decoder) throws {
228+
let container = try decoder.container(keyedBy: CodingKeys.self)
229+
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
230+
candidatesTokenCount = try container
231+
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
232+
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
197233
}
198234
}
199235

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
data: {"candidates": [{"content": {"parts": [{"text": "Cheyenne"}]},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}}
2-
1+
data: {"candidates": [{"content": {"parts": [{"text": "Mountain View, California"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"candidatesTokenCount": 4}}

Tests/GoogleAITests/GenerateContentResponses/unary-success-basic-reply-short.json

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"content": {
55
"parts": [
66
{
7-
"text": "Mountain View, California, United States"
7+
"text": "Mountain View, California"
88
}
99
],
1010
"role": "model"
@@ -31,24 +31,7 @@
3131
]
3232
}
3333
],
34-
"promptFeedback": {
35-
"safetyRatings": [
36-
{
37-
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
38-
"probability": "NEGLIGIBLE"
39-
},
40-
{
41-
"category": "HARM_CATEGORY_HATE_SPEECH",
42-
"probability": "NEGLIGIBLE"
43-
},
44-
{
45-
"category": "HARM_CATEGORY_HARASSMENT",
46-
"probability": "NEGLIGIBLE"
47-
},
48-
{
49-
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
50-
"probability": "NEGLIGIBLE"
51-
}
52-
]
34+
"usageMetadata": {
35+
"candidatesTokenCount": 4
5336
}
5437
}

Tests/GoogleAITests/GenerativeModelTests.swift

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase {
8282
XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible)
8383
XCTAssertEqual(candidate.content.parts.count, 1)
8484
let part = try XCTUnwrap(candidate.content.parts.first)
85-
XCTAssertEqual(part.text, "Mountain View, California, United States")
85+
XCTAssertEqual(part.text, "Mountain View, California")
8686
XCTAssertEqual(response.text, part.text)
87-
let promptFeedback = try XCTUnwrap(response.promptFeedback)
88-
XCTAssertNil(promptFeedback.blockReason)
89-
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
87+
XCTAssertNil(response.promptFeedback)
9088
XCTAssertEqual(response.functionCalls, [])
9189
}
9290

@@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase {
256254
XCTAssertEqual(response.functionCalls, [functionCall])
257255
}
258256

257+
func testGenerateContent_usageMetadata() async throws {
258+
MockURLProtocol
259+
.requestHandler = try httpRequestHandler(
260+
forResource: "unary-success-basic-reply-short",
261+
withExtension: "json"
262+
)
263+
264+
let response = try await model.generateContent(testPrompt)
265+
266+
let usageMetadata = try XCTUnwrap(response.usageMetadata)
267+
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
268+
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
269+
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
270+
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
271+
}
272+
259273
func testGenerateContent_failure_invalidAPIKey() async throws {
260274
let expectedStatusCode = 400
261275
MockURLProtocol
@@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase {
756770
}))
757771
}
758772

773+
func testGenerateContentStream_usageMetadata() async throws {
774+
MockURLProtocol
775+
.requestHandler = try httpRequestHandler(
776+
forResource: "streaming-success-basic-reply-short",
777+
withExtension: "txt"
778+
)
779+
var responses = [GenerateContentResponse]()
780+
781+
let stream = model.generateContentStream(testPrompt)
782+
for try await response in stream {
783+
responses.append(response)
784+
}
785+
786+
for (index, response) in responses.enumerated() {
787+
if index == responses.endIndex - 1 {
788+
let usageMetadata = try XCTUnwrap(response.usageMetadata)
789+
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
790+
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
791+
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
792+
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
793+
} else {
794+
// Only the last streamed response contains usage metadata
795+
XCTAssertNil(response.usageMetadata)
796+
}
797+
}
798+
}
799+
759800
func testGenerateContentStream_errorMidStream() async throws {
760801
MockURLProtocol.requestHandler = try httpRequestHandler(
761802
forResource: "streaming-failure-error-mid-stream",

0 commit comments

Comments
 (0)