diff --git a/integrations/google_genai/pydoc/config.yml b/integrations/google_genai/pydoc/config.yml index e87f53cd0..095a67c07 100644 --- a/integrations/google_genai/pydoc/config.yml +++ b/integrations/google_genai/pydoc/config.yml @@ -3,6 +3,8 @@ loaders: search_path: [../src] modules: [ "haystack_integrations.components.generators.google_genai.chat.chat_generator", + "haystack_integrations.components.embedders.google_genai.document_embedder", + "haystack_integrations.components.embedders.google_genai.text_embedder" ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/google_genai/pyproject.toml b/integrations/google_genai/pyproject.toml index 41021b5ff..65c8797ea 100644 --- a/integrations/google_genai/pyproject.toml +++ b/integrations/google_genai/pyproject.toml @@ -10,7 +10,10 @@ readme = "README.md" requires-python = ">=3.9" license = "Apache-2.0" keywords = [] -authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Gary Badwal", email = "gurpreet071999@gmail.com" } +] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -74,7 +77,7 @@ types = "mypy --install-types --non-interactive --explicit-package-bases {args:s [tool.hatch.envs.lint] installer = "uv" detached = true -dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "more-itertools"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py new file mode 100644 index 000000000..f426cd628 --- /dev/null +++ 
b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_embedder import GoogleGenAIDocumentEmbedder +from .text_embedder import GoogleGenAITextEmbedder + +__all__ = ["GoogleGenAIDocumentEmbedder", "GoogleGenAITextEmbedder"] diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py new file mode 100644 index 000000000..4f143a07e --- /dev/null +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Tuple, Union + +from google import genai +from google.genai import types +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace +from more_itertools import batched +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +@component +class GoogleGenAIDocumentEmbedder: + """ + Computes document embeddings using Google AI models. + + ### Usage example + + ```python + from haystack import Document + from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = GoogleGenAIDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...]
+ ``` + """ + + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), + model: str = "text-embedding-004", + prefix: str = "", + suffix: str = "", + batch_size: int = 32, + progress_bar: bool = True, + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + config: Optional[Dict[str, Any]] = None, + ): + """ + Creates a GoogleGenAIDocumentEmbedder component. + + :param api_key: + The Google API key. + You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter + during initialization. + :param model: + The name of the model to use for calculating embeddings. + The default model is `text-embedding-004`. + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param batch_size: + Number of documents to embed at once. + :param progress_bar: + If `True`, shows a progress bar when running. + :param meta_fields_to_embed: + List of metadata fields to embed along with the document text. + :param embedding_separator: + Separator used to concatenate the metadata fields to the document text. + :param config: + A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}. + For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
+ """ + self._api_key = api_key + self._model = model + self._prefix = prefix + self._suffix = suffix + self._batch_size = batch_size + self._progress_bar = progress_bar + self._meta_fields_to_embed = meta_fields_to_embed or [] + self._embedding_separator = embedding_separator + self._client = genai.Client(api_key=api_key.resolve_value()) + self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self._model, + prefix=self._prefix, + suffix=self._suffix, + batch_size=self._batch_size, + progress_bar=self._progress_bar, + meta_fields_to_embed=self._meta_fields_to_embed, + embedding_separator=self._embedding_separator, + api_key=self._api_key.to_dict(), + config=self._config, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAIDocumentEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. 
+ """ + texts_to_embed: List[str] = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) + for key in self._meta_fields_to_embed + if key in doc.meta and doc.meta[key] is not None + ] + + text_to_embed = ( + self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix + ) + texts_to_embed.append(text_to_embed) + + return texts_to_embed + + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + Embed the texts in batches of `batch_size`; return all embeddings in input order plus a meta dict with the model name. + """ + + all_embeddings = [] + meta: Dict[str, Any] = {} + for batch in tqdm( + batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings" + ): + args: Dict[str, Any] = {"model": self._model, "contents": list(batch)} + if self._config: + args["config"] = types.EmbedContentConfig(**self._config) + + response = self._client.models.embed_content(**args) + + embeddings = [el.values for el in response.embeddings] + all_embeddings.extend(embeddings) + + if "model" not in meta: + meta["model"] = self._model + + return all_embeddings, meta + + @component.output_types(documents=List[Document], meta=Dict[str, Any]) + def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]: + """ + Embeds a list of documents. + + :param documents: + A list of documents to embed. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents with embeddings. + - `meta`: Metadata about the request; contains the name of the model used. + """ + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): + error_message_documents = ( + "GoogleGenAIDocumentEmbedder expects a list of Documents as input. " + "In case you want to embed a string, please use the GoogleGenAITextEmbedder."
+ ) + raise TypeError(error_message_documents) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents, "meta": meta} diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py new file mode 100644 index 000000000..415d5fc21 --- /dev/null +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union + +from google import genai +from google.genai import types +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +@component +class GoogleGenAITextEmbedder: + """ + Embeds strings using Google AI models. + + You can use it to embed user query and send it to an embedding Retriever. + + ### Usage example + + ```python + from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder + + text_to_embed = "I love pizza!" 
+ + text_embedder = GoogleGenAITextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # Output (embedding truncated): + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + # 'meta': {'model': 'text-embedding-004'}} + ``` + """ + + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), + model: str = "text-embedding-004", + prefix: str = "", + suffix: str = "", + config: Optional[Dict[str, Any]] = None, + ): + """ + Creates a GoogleGenAITextEmbedder component. + + :param api_key: + The Google API key. + You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter + during initialization. + :param model: + The name of the model to use for calculating embeddings. + The default model is `text-embedding-004`. + :param prefix: + A string to add at the beginning of each text to embed. + :param suffix: + A string to add at the end of each text to embed. + :param config: + A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}. + For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). + """ + + self._api_key = api_key + self._model_name = model + self._prefix = prefix + self._suffix = suffix + self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} + self._client = genai.Client(api_key=api_key.resolve_value()) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data.
+ """ + return default_to_dict( + self, + model=self._model_name, + api_key=self._api_key.to_dict(), + prefix=self._prefix, + suffix=self._suffix, + config=self._config, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAITextEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_input(self, text: str) -> Dict[str, Any]: + if not isinstance(text, str): + error_message_text = ( + "GoogleGenAITextEmbedder expects a string as an input. " + "In case you want to embed a list of Documents, please use the GoogleGenAIDocumentEmbedder." + ) + + raise TypeError(error_message_text) + + text_to_embed = self._prefix + text + self._suffix + + kwargs: Dict[str, Any] = {"model": self._model_name, "contents": text_to_embed} + if self._config: + kwargs["config"] = types.EmbedContentConfig(**self._config) + + return kwargs + + def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]: + return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}} + + @component.output_types(embedding=List[float], meta=Dict[str, Any]) + def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]: + """ + Embeds a single string. + + :param text: + Text to embed. + + :returns: + A dictionary with the following keys: + - `embedding`: The embedding of the input text. + - `meta`: Metadata about the request; contains the name of the model used.
+ """ + create_kwargs = self._prepare_input(text=text) + response = self._client.models.embed_content(**create_kwargs) + return self._prepare_output(result=response) diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py new file mode 100644 index 000000000..31e55baf4 --- /dev/null +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +from typing import List + +import pytest +from haystack import Document +from haystack.utils.auth import Secret + +from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder + + +def mock_google_response(contents: List[str], model: str = "text-embedding-004", **kwargs) -> dict: + secure_random = random.SystemRandom() + dict_response = { + "embedding": [[secure_random.random() for _ in range(768)] for _ in contents], + "meta": {"model": model}, + } + + return dict_response + + +class TestGoogleGenAIDocumentEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleGenAIDocumentEmbedder() + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model == "text-embedding-004" + assert embedder._prefix == "" + assert embedder._suffix == "" + assert embedder._batch_size == 32 + assert embedder._progress_bar is True + assert embedder._meta_fields_to_embed == [] + assert embedder._embedding_separator == "\n" + assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"} + + def test_init_with_parameters(self, monkeypatch): + embedder = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key-2"), + model="model", + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + 
config={"task_type": "CLASSIFICATION"}, + ) + assert embedder._api_key.resolve_value() == "fake-api-key-2" + assert embedder._model == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._batch_size == 64 + assert embedder._progress_bar is False + assert embedder._meta_fields_to_embed == ["test_field"] + assert embedder._embedding_separator == " | " + assert embedder._config == {"task_type": "CLASSIFICATION"} + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + embedder = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key-2"), + model="model", + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + config={"task_type": "CLASSIFICATION"}, + ) + assert embedder._api_key.resolve_value() == "fake-api-key-2" + assert embedder._model == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._batch_size == 64 + assert embedder._progress_bar is False + assert embedder._meta_fields_to_embed == ["test_field"] + assert embedder._embedding_separator == " | " + assert embedder._config == {"task_type": "CLASSIFICATION"} + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + GoogleGenAIDocumentEmbedder() + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + component = GoogleGenAIDocumentEmbedder() + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder" + ), + "init_parameters": { + "model": "text-embedding-004", + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "api_key": {"type": 
"env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, + "config": {"task_type": "SEMANTIC_SIMILARITY"}, + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "fake-api-key") + component = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="model", + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + config={"task_type": "CLASSIFICATION"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder" + ), + "init_parameters": { + "model": "model", + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 64, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + "api_key": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "config": {"task_type": "CLASSIFICATION"}, + }, + } + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) + for i in range(5) + ] + + embedder = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | " + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + assert prepared_texts == [ + "meta_value 0 | document number 0:\ncontent", + "meta_value 1 | document number 1:\ncontent", + "meta_value 2 | document number 2:\ncontent", + "meta_value 3 | document number 3:\ncontent", + "meta_value 4 | document number 4:\ncontent", + ] + + def test_run_wrong_input_format(self): + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) + + # wrong formats + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, 
match="GoogleGenAIDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=string_input) + + with pytest.raises(TypeError, match="GoogleGenAIDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=list_integers_input) + + def test_run_on_empty_list(self): + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) + + empty_list_input = [] + result = embedder.run(documents=empty_list_input) + + assert result["documents"] is not None + assert not result["documents"] # empty list + + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "text-embedding-004" + + embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") + + result = embedder.run(documents=docs) + documents_with_embeddings = result["documents"] + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 768 + assert all(isinstance(x, float) for x in doc.embedding) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py new file mode 100644 index 000000000..bb700527b --- /dev/null +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: 
Apache-2.0 + +import os + +import pytest +from google.genai.types import ContentEmbedding, EmbedContentConfig, EmbedContentResponse +from haystack.utils.auth import Secret + +from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder + + +class TestGoogleGenAITextEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleGenAITextEmbedder() + + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "text-embedding-004" + assert embedder._prefix == "" + assert embedder._suffix == "" + assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"} + + def test_init_with_parameters(self): + embedder = GoogleGenAITextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="model", + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, + ) + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._config == {"task_type": "CLASSIFICATION"} + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + embedder = GoogleGenAITextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="model", + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, + ) + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._config == {"task_type": "CLASSIFICATION"} + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + component = GoogleGenAITextEmbedder() + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", + "init_parameters": { + "api_key": {"type": "env_var", 
"env_vars": ["GOOGLE_API_KEY"], "strict": True}, + "model": "text-embedding-004", + "prefix": "", + "suffix": "", + "config": {"task_type": "SEMANTIC_SIMILARITY"}, + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "fake-api-key") + component = GoogleGenAITextEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="model", + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", + "init_parameters": { + "model": "model", + "api_key": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "prefix": "prefix", + "suffix": "suffix", + "config": {"task_type": "CLASSIFICATION"}, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", + "init_parameters": { + "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, + "model": "text-embedding-004", + "prefix": "", + "suffix": "", + "config": {"task_type": "CLASSIFICATION"}, + }, + } + component = GoogleGenAITextEmbedder.from_dict(data) + assert component._api_key.resolve_value() == "fake-api-key" + assert component._model_name == "text-embedding-004" + assert component._prefix == "" + assert component._suffix == "" + assert component._config == {"task_type": "CLASSIFICATION"} + + def test_prepare_input(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleGenAITextEmbedder() + + contents = "The food was delicious" + prepared_input = embedder._prepare_input(contents) + assert prepared_input == { + "model": "text-embedding-004", + "contents": "The food was delicious", + "config": EmbedContentConfig( + http_options=None, + 
task_type="SEMANTIC_SIMILARITY", + title=None, + output_dimensionality=None, + mime_type=None, + auto_truncate=None, + ), + } + + def test_prepare_output(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + + response = EmbedContentResponse( + embeddings=[ContentEmbedding(values=[0.1, 0.2, 0.3])], + ) + + embedder = GoogleGenAITextEmbedder() + result = embedder._prepare_output(result=response) + assert result == { + "embedding": [0.1, 0.2, 0.3], + "meta": {"model": "text-embedding-004"}, + } + + def test_run_wrong_input_format(self): + embedder = GoogleGenAITextEmbedder(api_key=Secret.from_token("fake-api-key")) + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="GoogleGenAITextEmbedder expects a string as an input"): + embedder.run(text=list_integers_input) + + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + model = "text-embedding-004" + + embedder = GoogleGenAITextEmbedder(model=model) + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 768 + assert all(isinstance(x, float) for x in result["embedding"]) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + )