From 8104964185e49ffa484377de0fff17322b2989d3 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 2 Apr 2025 15:28:43 -0400 Subject: [PATCH 1/3] [Vertex AI] Add `countTokens` support for Developer API via VinF --- .../Sources/GenerativeModel.swift | 32 ++++++++++++++++--- FirebaseVertexAI/Sources/VertexAI.swift | 1 + .../CountTokensIntegrationTests.swift | 2 +- .../Tests/Utilities/InstanceConfig.swift | 10 ++++-- FirebaseVertexAI/Tests/Unit/ChatTests.swift | 6 +++- .../Tests/Unit/GenerativeModelTests.swift | 20 ++++++++++-- 6 files changed, 58 insertions(+), 13 deletions(-) diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index 8ec905cc436..3f2c8273a41 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -23,7 +23,10 @@ public final class GenerativeModel: Sendable { /// Model name prefix to identify Gemini models. static let geminiModelNamePrefix = "gemini-" - /// The resource name of the model in the backend; has the format "models/model-name". + /// The name of the model, for example "gemini-2.0-flash". + let modelName: String + + /// The model resource name corresponding with `modelName` in the backend. let modelResourceName: String /// Configuration for the backend API used by this model. @@ -53,8 +56,13 @@ public final class GenerativeModel: Sendable { /// Initializes a new remote model with the given parameters. /// /// - Parameters: - /// - modelResourceName: The resource name of the model to use, for example - /// `"projects/{project-id}/locations/{location-id}/publishers/google/models/{model-name}"`. + /// - modelName: The name of the model, for example "gemini-2.0-flash". + /// - modelResourceName: The model resource name corresponding with `modelName` in the backend. + /// The form depends on the backend and will be one of: + /// - Vertex AI via Vertex AI in Firebase: + /// `"projects/{projectID}/locations/{locationID}/publishers/google/models/{modelName}"` + /// - Developer API via Vertex AI in Firebase: `"projects/{projectID}/models/{modelName}"` + /// - Developer API via Generative Language: `"models/{modelName}"` /// - firebaseInfo: Firebase data used by the SDK, including project ID and API key. /// - apiConfig: Configuration for the backend API used by this model. /// - generationConfig: The content generation parameters your model should use. @@ -65,7 +73,8 @@ public final class GenerativeModel: Sendable { /// only text content is supported. /// - requestOptions: Configuration parameters for sending requests to the backend. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`. - init(modelResourceName: String, + init(modelName: String, + modelResourceName: String, firebaseInfo: FirebaseInfo, apiConfig: APIConfig, generationConfig: GenerationConfig? = nil, @@ -75,6 +84,7 @@ public final class GenerativeModel: Sendable { systemInstruction: ModelContent? = nil, requestOptions: RequestOptions, urlSession: URLSession = .shared) { + self.modelName = modelName self.modelResourceName = modelResourceName self.apiConfig = apiConfig generativeAIService = GenerativeAIService( @@ -275,8 +285,20 @@ public final class GenerativeModel: Sendable { content.map { ModelContent(role: nil, parts: $0.parts) } } + // When using the Developer API via the Firebase backend, the model name of the + // `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form + // "models/model-name". This field is unaltered by the Firebase backend before forwarding the + // request to the Generative Language backend, which expects the form "models/model-name". + let generateContentRequestModelResourceName = switch apiConfig.service { + case .vertexAI, .developer(endpoint: .generativeLanguage): + modelResourceName + case .developer(endpoint: .firebaseVertexAIProd), + .developer(endpoint: .firebaseVertexAIStaging): + "models/\(modelName)" + } + let generateContentRequest = GenerateContentRequest( - model: modelResourceName, + model: generateContentRequestModelResourceName, contents: requestContent, generationConfig: generationConfig, safetySettings: safetySettings, diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 8b91f5e54c2..0d837e3b041 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -80,6 +80,7 @@ public class VertexAI { } return GenerativeModel( + modelName: modelName, modelResourceName: modelResourceName(modelName: modelName), firebaseInfo: firebaseInfo, apiConfig: apiConfig, diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/CountTokensIntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/CountTokensIntegrationTests.swift index ab6e42c9f28..8d7c0a202c8 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/CountTokensIntegrationTests.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/CountTokensIntegrationTests.swift @@ -102,7 +102,7 @@ struct CountTokensIntegrationTests { @Test(arguments: [ /* System instructions are not supported on the v1 Developer API. */ - InstanceConfig.developerV1, + InstanceConfig.developerV1Spark, ]) func countTokens_text_systemInstruction_unsupported(_ config: InstanceConfig) async throws { let model = VertexAI.componentInstance(config).generativeModel( diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift index a3c9334976b..587a298d8dc 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift @@ -32,11 +32,14 @@ struct InstanceConfig { static let vertexV1BetaStaging = InstanceConfig( apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseVertexAIStaging), version: .v1beta) ) - static let developerV1 = InstanceConfig( + static let developerV1Beta = InstanceConfig( + apiConfig: APIConfig(service: .developer(endpoint: .firebaseVertexAIProd), version: .v1beta) + ) + static let developerV1Spark = InstanceConfig( appName: FirebaseAppNames.spark, apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1) ) - static let developerV1Beta = InstanceConfig( + static let developerV1BetaSpark = InstanceConfig( appName: FirebaseAppNames.spark, apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta) ) @@ -45,8 +48,9 @@ struct InstanceConfig { vertexV1Staging, vertexV1Beta, vertexV1BetaStaging, - developerV1, developerV1Beta, + developerV1Spark, + developerV1BetaSpark, ] static let vertexV1AppCheckNotConfigured = InstanceConfig( diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift index 774af8d3c44..acb98d1e0e1 100644 --- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift +++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift @@ -20,6 +20,9 @@ import FirebaseCore @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) final class ChatTests: XCTestCase { + let modelName = "test-model-name" + let modelResourceName = "projects/my-project/locations/us-central1/models/test-model-name" + var urlSession: URLSession! override func setUp() { @@ -59,7 +62,8 @@ final class ChatTests: XCTestCase { options: FirebaseOptions(googleAppID: "ignore", gcmSenderID: "ignore")) let model = GenerativeModel( - modelResourceName: "my-model", + modelName: modelName, + modelResourceName: modelResourceName, firebaseInfo: FirebaseInfo( projectID: "my-project-id", apiKey: "API_KEY", diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index 51295f81bee..290238445c5 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -56,6 +56,7 @@ final class GenerativeModelTests: XCTestCase { blocked: false ), ].sorted() + let testModelName = "test-model" let testModelResourceName = "projects/test-project-id/locations/test-location/publishers/google/models/test-model" let apiConfig = VertexAI.defaultVertexAIAPIConfig @@ -70,6 +71,7 @@ final class GenerativeModelTests: XCTestCase { configuration.protocolClasses = [MockURLProtocol.self] urlSession = try XCTUnwrap(URLSession(configuration: configuration)) model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, @@ -275,8 +277,8 @@ final class GenerativeModelTests: XCTestCase { subdirectory: vertexSubdirectory ) let model = GenerativeModel( - // Model name is prefixed with "models/". - modelResourceName: "models/test-model", + modelName: testModelName, + modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, tools: nil, @@ -399,6 +401,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_appCheck_validToken() async throws { let appCheckToken = "test-valid-token" model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)), apiConfig: apiConfig, @@ -420,6 +423,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_dataCollectionOff() async throws { let appCheckToken = "test-valid-token" model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken), privateAppID: true), @@ -442,6 +446,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_appCheck_tokenRefreshError() async throws { model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())), apiConfig: apiConfig, @@ -463,6 +468,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_auth_validAuthToken() async throws { let authToken = "test-valid-token" model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: authToken)), apiConfig: apiConfig, @@ -483,6 +489,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_auth_nilAuthToken() async throws { model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: nil)), apiConfig: apiConfig, @@ -503,7 +510,8 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContent_auth_authTokenRefreshError() async throws { model = GenerativeModel( - modelResourceName: "my-model", + modelName: testModelName, + modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())), apiConfig: apiConfig, tools: nil, @@ -900,6 +908,7 @@ final class GenerativeModelTests: XCTestCase { ) let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, @@ -1204,6 +1213,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContentStream_appCheck_validToken() async throws { let appCheckToken = "test-valid-token" model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)), apiConfig: apiConfig, @@ -1225,6 +1235,7 @@ final class GenerativeModelTests: XCTestCase { func testGenerateContentStream_appCheck_tokenRefreshError() async throws { model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())), apiConfig: apiConfig, @@ -1375,6 +1386,7 @@ final class GenerativeModelTests: XCTestCase { ) let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, @@ -1451,6 +1463,7 @@ final class GenerativeModelTests: XCTestCase { parts: "You are a calculator. Use the provided tools." ) model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, @@ -1511,6 +1524,7 @@ final class GenerativeModelTests: XCTestCase { ) let requestOptions = RequestOptions(timeout: expectedTimeout) model = GenerativeModel( + modelName: testModelName, modelResourceName: testModelResourceName, firebaseInfo: testFirebaseInfo(), apiConfig: apiConfig, From 0d0b78abd48417871ff08b725dde86ffb4eb00c1 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 2 Apr 2025 16:13:13 -0400 Subject: [PATCH 2/3] Fix `CountTokensRequest` URL mistake, uncovered by integration test --- FirebaseVertexAI/Sources/GenerativeModel.swift | 4 +++- .../Sources/Types/Internal/Requests/CountTokensRequest.swift | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index 3f2c8273a41..4714e78f1ef 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -309,7 +309,9 @@ public final class GenerativeModel: Sendable { apiMethod: .countTokens, options: requestOptions ) - let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest) + let countTokensRequest = CountTokensRequest( + modelResourceName: modelResourceName, generateContentRequest: generateContentRequest + ) return try await generativeAIService.loadRequest(request: countTokensRequest) } diff --git a/FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift b/FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift index 8a49adcab3f..2f48ca9a32b 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Requests/CountTokensRequest.swift @@ -16,6 +16,8 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) struct CountTokensRequest { + let modelResourceName: String + let generateContentRequest: GenerateContentRequest } @@ -30,7 +32,7 @@ extension CountTokensRequest: GenerativeAIRequest { var url: URL { let version = apiConfig.version.rawValue let endpoint = apiConfig.service.endpoint.rawValue - return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")! + return URL(string: "\(endpoint)/\(version)/\(modelResourceName):countTokens")! } } From 00472b00c8344ece01ed4f12fdb4637919437fbe Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 2 Apr 2025 16:32:40 -0400 Subject: [PATCH 3/3] Fix unit tests --- FirebaseVertexAI/Sources/VertexAI.swift | 4 +--- .../Types/Internal/Requests/CountTokensRequestTests.swift | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 0d837e3b041..112d117047d 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -241,13 +241,11 @@ public class VertexAI { private func developerModelResourceName(modelName: String) -> String { switch apiConfig.service.endpoint { - case .firebaseVertexAIStaging: + case .firebaseVertexAIStaging, .firebaseVertexAIProd: let projectID = firebaseInfo.projectID return "projects/\(projectID)/models/\(modelName)" case .generativeLanguage: return "models/\(modelName)" - default: - fatalError("The Developer API is not supported on '\(apiConfig.service.endpoint)'.") } } diff --git a/FirebaseVertexAI/Tests/Unit/Types/Internal/Requests/CountTokensRequestTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Internal/Requests/CountTokensRequestTests.swift index 13972835fa2..338671287ab 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Internal/Requests/CountTokensRequestTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Internal/Requests/CountTokensRequestTests.swift @@ -52,7 +52,9 @@ final class CountTokensRequestTests: XCTestCase { apiMethod: .countTokens, options: requestOptions ) - let request = CountTokensRequest(generateContentRequest: generateContentRequest) + let request = CountTokensRequest( + modelResourceName: modelResourceName, generateContentRequest: generateContentRequest + ) let jsonData = try encoder.encode(request) @@ -86,7 +88,9 @@ final class CountTokensRequestTests: XCTestCase { apiMethod: .countTokens, options: requestOptions ) - let request = CountTokensRequest(generateContentRequest: generateContentRequest) + let request = CountTokensRequest( + modelResourceName: modelResourceName, generateContentRequest: generateContentRequest + ) let jsonData = try encoder.encode(request)