Skip to content

Commit 8bb8243

Browse files
committed
Add usageMetadata to GenerateContentResponse in Vertex AI (#12777)
1 parent 6504b3d commit 8bb8243

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ struct ErrorDetailsView: View {
176176
],
177177
finishReason: FinishReason.maxTokens,
178178
citationMetadata: nil),
179-
],
180-
promptFeedback: nil)
179+
])
181180
)
182181

183182
return ErrorDetailsView(error: error)
@@ -200,8 +199,7 @@ struct ErrorDetailsView: View {
200199
],
201200
finishReason: FinishReason.other,
202201
citationMetadata: nil),
203-
],
204-
promptFeedback: nil)
202+
])
205203
)
206204

207205
return ErrorDetailsView(error: error)

FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ struct ErrorView: View {
5151
],
5252
finishReason: FinishReason.other,
5353
citationMetadata: nil),
54-
],
55-
promptFeedback: nil)
54+
])
5655
)
5756
List {
5857
MessageView(message: ChatMessage.samples[0])

FirebaseVertexAI/Sources/GenerateContentResponse.swift

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,28 @@ import Foundation
1717
/// The model's response to a generate content request.
1818
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1919
public struct GenerateContentResponse {
20+
/// Token usage metadata for processing the generate content request.
21+
public struct UsageMetadata {
22+
/// The number of tokens in the request prompt.
23+
public let promptTokenCount: Int
24+
25+
/// The total number of tokens across the generated response candidates.
26+
public let candidatesTokenCount: Int
27+
28+
/// The total number of tokens in both the request and response.
29+
public let totalTokenCount: Int
30+
}
31+
2032
/// A list of candidate response content, ordered from best to worst.
2133
public let candidates: [CandidateResponse]
2234

2335
/// A value containing the safety ratings for the response, or, if the request was blocked, a
2436
/// reason for blocking the request.
2537
public let promptFeedback: PromptFeedback?
2638

39+
/// Token usage metadata for processing the generate content request.
40+
public let usageMetadata: UsageMetadata?
41+
2742
/// The response's content as text, if it exists.
2843
public var text: String? {
2944
guard let candidate = candidates.first else {
@@ -51,9 +66,11 @@ public struct GenerateContentResponse {
5166
}
5267

5368
/// Initializer for SwiftUI previews or tests.
54-
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) {
69+
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
70+
usageMetadata: UsageMetadata? = nil) {
5571
self.candidates = candidates
5672
self.promptFeedback = promptFeedback
73+
self.usageMetadata = usageMetadata
5774
}
5875
}
5976

@@ -62,6 +79,7 @@ extension GenerateContentResponse: Decodable {
6279
enum CodingKeys: CodingKey {
6380
case candidates
6481
case promptFeedback
82+
case usageMetadata
6583
}
6684

6785
public init(from decoder: Decoder) throws {
@@ -86,6 +104,7 @@ extension GenerateContentResponse: Decodable {
86104
candidates = []
87105
}
88106
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
107+
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
89108
}
90109
}
91110

@@ -301,3 +320,20 @@ extension PromptFeedback: Decodable {
301320
}
302321
}
303322
}
323+
324+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
325+
extension GenerateContentResponse.UsageMetadata: Decodable {
326+
enum CodingKeys: CodingKey {
327+
case promptTokenCount
328+
case candidatesTokenCount
329+
case totalTokenCount
330+
}
331+
332+
public init(from decoder: any Decoder) throws {
333+
let container = try decoder.container(keyedBy: CodingKeys.self)
334+
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
335+
candidatesTokenCount = try container
336+
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
337+
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
338+
}
339+
}

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,21 @@ final class GenerativeModelTests: XCTestCase {
287287
_ = try await model.generateContent(testPrompt)
288288
}
289289

290+
func testGenerateContent_usageMetadata() async throws {
291+
MockURLProtocol
292+
.requestHandler = try httpRequestHandler(
293+
forResource: "unary-success-basic-reply-short",
294+
withExtension: "json"
295+
)
296+
297+
let response = try await model.generateContent(testPrompt)
298+
299+
let usageMetadata = try XCTUnwrap(response.usageMetadata)
300+
XCTAssertEqual(usageMetadata.promptTokenCount, 6)
301+
XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
302+
XCTAssertEqual(usageMetadata.totalTokenCount, 13)
303+
}
304+
290305
func testGenerateContent_failure_invalidAPIKey() async throws {
291306
let expectedStatusCode = 400
292307
MockURLProtocol
@@ -814,6 +829,32 @@ final class GenerativeModelTests: XCTestCase {
814829
for try await _ in stream {}
815830
}
816831

832+
func testGenerateContentStream_usageMetadata() async throws {
833+
MockURLProtocol
834+
.requestHandler = try httpRequestHandler(
835+
forResource: "streaming-success-basic-reply-short",
836+
withExtension: "txt"
837+
)
838+
var responses = [GenerateContentResponse]()
839+
840+
let stream = model.generateContentStream(testPrompt)
841+
for try await response in stream {
842+
responses.append(response)
843+
}
844+
845+
for (index, response) in responses.enumerated() {
846+
if index == responses.endIndex - 1 {
847+
let usageMetadata = try XCTUnwrap(response.usageMetadata)
848+
XCTAssertEqual(usageMetadata.promptTokenCount, 6)
849+
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
850+
XCTAssertEqual(usageMetadata.totalTokenCount, 10)
851+
} else {
852+
// Only the last streamed response contains usage metadata
853+
XCTAssertNil(response.usageMetadata)
854+
}
855+
}
856+
}
857+
817858
func testGenerateContentStream_errorMidStream() async throws {
818859
MockURLProtocol.requestHandler = try httpRequestHandler(
819860
forResource: "streaming-failure-error-mid-stream",

0 commit comments

Comments
 (0)