
[Firebase AI] Add workaround for invalid SafetyRatings in response #14817

Merged · 5 commits · May 5, 2025

3 changes: 3 additions & 0 deletions FirebaseAI/CHANGELOG.md
@@ -1,6 +1,9 @@
# Unreleased
- [fixed] Fixed `ModalityTokenCount` decoding when the `tokenCount` field is
omitted; this occurs when the count is 0. (#14745)
- [fixed] Fixed `Candidate` decoding when `SafetyRating` values are missing a
category or probability; this may occur when using `gemini-2.0-flash-exp` for
image generation. (#14817)

# 11.12.0
- [added] **Public Preview**: Added support for specifying response modalities
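The second changelog entry is the subject of this PR. From a caller's perspective the change is invisible except that affected requests no longer fail during decoding. A hedged sketch of the user-visible behavior (assumes a configured Firebase app; the entry points follow the FirebaseAI SDK's public API, and the prompt is illustrative):

```swift
import FirebaseAI

// Hypothetical caller-level view of the fix; model and prompt are examples.
func generateImage() async throws {
  let model = FirebaseAI.firebaseAI().generativeModel(modelName: "gemini-2.0-flash-exp")
  let response = try await model.generateContent("Generate an image of a cat")
  // Before this fix, decoding threw when the backend returned empty `{}`
  // safety ratings alongside generated images; now the invalid entries are
  // filtered out and the valid ratings are returned as usual.
  print(response.candidates.first?.safetyRatings ?? [])
}
```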
10 changes: 7 additions & 3 deletions FirebaseAI/Sources/GenerateContentResponse.swift
@@ -381,10 +381,14 @@ extension Candidate: Decodable {
}

if let safetyRatings = try container.decodeIfPresent(
-  [SafetyRating].self,
-  forKey: .safetyRatings
+  [SafetyRating].self, forKey: .safetyRatings
 ) {
-  self.safetyRatings = safetyRatings
+  self.safetyRatings = safetyRatings.filter {
+    // Due to a bug in the backend, the SDK may receive invalid `SafetyRating` values that do
+    // not include a category or probability; these are filtered out of the safety ratings.
+    $0.category != HarmCategory.unspecified
+      && $0.probability != SafetyRating.HarmProbability.unspecified
+  }
} else {
safetyRatings = []
}
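The filter above is the heart of the workaround. As a standalone illustration, here is a minimal sketch of the same predicate using reduced stand-in types (the real `HarmCategory` and `HarmProbability` are proto-enum wrapper structs in `Safety.swift`; the sentinel values mirror the ones added in this PR):

```swift
// Reduced stand-ins for the SDK's proto-enum wrappers; only the pieces
// needed to demonstrate the filter predicate are modeled here.
struct HarmCategory: Equatable {
  let rawValue: String
  static let unspecified = HarmCategory(rawValue: "HARM_CATEGORY_UNSPECIFIED")
  static let harassment = HarmCategory(rawValue: "HARM_CATEGORY_HARASSMENT")
}

struct HarmProbability: Equatable {
  let rawValue: String
  static let unspecified = HarmProbability(rawValue: "HARM_PROBABILITY_UNSPECIFIED")
  static let negligible = HarmProbability(rawValue: "NEGLIGIBLE")
}

struct SafetyRating: Equatable {
  let category: HarmCategory
  let probability: HarmProbability
}

let decoded: [SafetyRating] = [
  SafetyRating(category: .harassment, probability: .negligible),
  // An "empty" rating produced by the backend bug; both fields fell back to
  // the internal `.unspecified` sentinels during decoding.
  SafetyRating(category: .unspecified, probability: .unspecified),
]

// Drop any rating missing a category or probability before exposing the
// array to callers, the same predicate used in the diff above.
let valid = decoded.filter {
  $0.category != HarmCategory.unspecified
    && $0.probability != HarmProbability.unspecified
}

print(valid.count) // 1
```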
23 changes: 18 additions & 5 deletions FirebaseAI/Sources/Safety.swift
@@ -78,12 +78,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct HarmProbability: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case unspecified = "HARM_PROBABILITY_UNSPECIFIED"
case negligible = "NEGLIGIBLE"
case low = "LOW"
case medium = "MEDIUM"
case high = "HIGH"
}

/// Internal-only; harm probability is unknown or unspecified by the backend.
static let unspecified = HarmProbability(kind: .unspecified)

/// The probability is zero or close to zero.
///
/// For benign content, the probability across all categories will be this value.
@@ -114,12 +118,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case unspecified = "HARM_SEVERITY_UNSPECIFIED"
case negligible = "HARM_SEVERITY_NEGLIGIBLE"
case low = "HARM_SEVERITY_LOW"
case medium = "HARM_SEVERITY_MEDIUM"
case high = "HARM_SEVERITY_HIGH"
}

/// Internal-only; harm severity is unknown or unspecified by the backend.
static let unspecified: HarmSeverity = .init(kind: .unspecified)

/// Negligible level of harm severity.
public static let negligible = HarmSeverity(kind: .negligible)

@@ -234,13 +242,17 @@ public struct SafetySetting: Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case unspecified = "HARM_CATEGORY_UNSPECIFIED"
case harassment = "HARM_CATEGORY_HARASSMENT"
case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
}

/// Internal-only; harm category is unknown or unspecified by the backend.
static let unspecified = HarmCategory(kind: .unspecified)

/// Harassment content.
public static let harassment = HarmCategory(kind: .harassment)

@@ -281,13 +293,14 @@ extension SafetyRating: Decodable {

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
-    category = try container.decode(HarmCategory.self, forKey: .category)
-    probability = try container.decode(HarmProbability.self, forKey: .probability)
+    category = try container.decodeIfPresent(HarmCategory.self, forKey: .category) ?? .unspecified
+    probability = try container.decodeIfPresent(
+      HarmProbability.self, forKey: .probability
+    ) ?? .unspecified

-    // The following 3 fields are only omitted in our test data.
+    // The following 3 fields are only provided when using the Vertex AI backend (not Google AI).
     probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
-    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
-      HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
+    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ?? .unspecified
severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0

// The blocked field is only included when true.
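The decoding change is the other half of the fix: `category` and `probability` now fall back to internal sentinels instead of throwing `DecodingError.keyNotFound` on an empty `{}` rating. A minimal, self-contained sketch of that strategy (fields simplified to plain strings; the SDK's actual types conform to its internal `DecodableProtoEnum` protocol):

```swift
import Foundation

// Lenient decoding sketch: a missing key decodes to an "unspecified"
// sentinel rather than throwing.
struct SafetyRating: Decodable {
  let category: String
  let probability: String

  enum CodingKeys: String, CodingKey {
    case category, probability
  }

  init(from decoder: any Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    // With `decode(_:forKey:)` these lines would throw `keyNotFound` for `{}`.
    category = try container.decodeIfPresent(String.self, forKey: .category)
      ?? "HARM_CATEGORY_UNSPECIFIED"
    probability = try container.decodeIfPresent(String.self, forKey: .probability)
      ?? "HARM_PROBABILITY_UNSPECIFIED"
  }
}

let json = Data(#"[{"category": "HARM_CATEGORY_HARASSMENT", "probability": "NEGLIGIBLE"}, {}]"#.utf8)
let ratings = try JSONDecoder().decode([SafetyRating].self, from: json)
print(ratings[1].category) // HARM_CATEGORY_UNSPECIFIED
```

Because `Candidate`'s decoder filters sentinel-valued ratings out (see the previous file), the `unspecified` values never surface through the public API, which is why the new statics can remain internal.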
@@ -115,9 +115,8 @@ struct GenerateContentIntegrationTests {
}

@Test(arguments: [
-    // TODO(andrewheard): Vertex AI configs temporarily disabled to due empty SafetyRatings bug.
-    // InstanceConfig.vertexV1,
-    // InstanceConfig.vertexV1Beta,
+    InstanceConfig.vertexAI_v1,
+    InstanceConfig.vertexAI_v1beta,
InstanceConfig.googleAI_v1beta,
InstanceConfig.googleAI_v1beta_staging,
InstanceConfig.googleAI_v1beta_freeTier_bypassProxy,
82 changes: 81 additions & 1 deletion FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift
@@ -56,6 +56,41 @@ final class GenerativeModelVertexAITests: XCTestCase {
blocked: false
),
].sorted()
let safetyRatingsInvalidIgnored = [
SafetyRating(
category: .hateSpeech,
probability: .negligible,
probabilityScore: 0.00039444832,
severity: .negligible,
severityScore: 0.0,
blocked: false
),
SafetyRating(
category: .dangerousContent,
probability: .negligible,
probabilityScore: 0.0010654529,
severity: .negligible,
severityScore: 0.0049325973,
blocked: false
),
SafetyRating(
category: .harassment,
probability: .negligible,
probabilityScore: 0.00026658305,
severity: .negligible,
severityScore: 0.0,
blocked: false
),
SafetyRating(
category: .sexuallyExplicit,
probability: .negligible,
probabilityScore: 0.0013701695,
severity: .negligible,
severityScore: 0.07626295,
blocked: false
),
// Ignored Invalid Safety Ratings: {},{},{},{}
].sorted()
let testModelName = "test-model"
let testModelResourceName =
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -399,6 +434,26 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertEqual(text, "The sum of [1, 2, 3] is")
}

func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "unary-success-image-invalid-safety-ratings",
withExtension: "json",
subdirectory: vertexSubdirectory
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 1)
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
let inlineDataParts = response.inlineDataParts
XCTAssertEqual(inlineDataParts.count, 1)
let imagePart = try XCTUnwrap(inlineDataParts.first)
XCTAssertEqual(imagePart.mimeType, "image/png")
XCTAssertGreaterThan(imagePart.data.count, 0)
}

func testGenerateContent_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
@@ -1118,7 +1173,7 @@ final class GenerativeModelVertexAITests: XCTestCase {
responses += 1
}

-    XCTAssertEqual(responses, 6)
+    XCTAssertEqual(responses, 4)
}

func testGenerateContentStream_successBasicReplyShort() async throws {
@@ -1220,6 +1275,31 @@ final class GenerativeModelVertexAITests: XCTestCase {
XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
}

func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
forResource: "streaming-success-image-invalid-safety-ratings",
withExtension: "txt",
subdirectory: vertexSubdirectory
)

let stream = try model.generateContentStream(testPrompt)
var responses = [GenerateContentResponse]()
for try await content in stream {
responses.append(content)
}

let response = try XCTUnwrap(responses.first)
XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
XCTAssertEqual(candidate.content.parts.count, 1)
let inlineDataParts = response.inlineDataParts
XCTAssertEqual(inlineDataParts.count, 1)
let imagePart = try XCTUnwrap(inlineDataParts.first)
XCTAssertEqual(imagePart.mimeType, "image/png")
XCTAssertGreaterThan(imagePart.data.count, 0)
}

func testGenerateContentStream_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
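For reference, a plausible reduced shape of the mock payload these tests decode. The actual `unary-success-image-invalid-safety-ratings` and `streaming-success-image-invalid-safety-ratings` fixtures ship with the SDK's test data and contain full image responses; the key detail is the empty `{}` rating objects that the `// Ignored Invalid Safety Ratings` comment in the expected values refers to:

```swift
import Foundation

// Hypothetical reduced fixture; the real JSON files contain complete
// responses with inline image data. Only the structure matters here.
let fixture = #"""
{
  "candidates": [
    {
      "content": {
        "role": "model",
        "parts": [{"inlineData": {"mimeType": "image/png", "data": ""}}]
      },
      "safetyRatings": [
        {"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"},
        {}, {}, {}, {}
      ]
    }
  ]
}
"""#

// Count the empty rating objects: the ones the lenient decoder turns into
// `.unspecified` sentinels and the `Candidate` decoder then filters out.
let root = try JSONSerialization.jsonObject(with: Data(fixture.utf8)) as! [String: Any]
let candidates = root["candidates"] as! [[String: Any]]
let ratings = candidates[0]["safetyRatings"] as! [[String: Any]]
print(ratings.filter(\.isEmpty).count) // 4
```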