|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from typing import Any, Dict, List, Optional, Tuple, Union |
| 6 | + |
| 7 | +from google import genai |
| 8 | +from google.genai import types |
| 9 | +from haystack import Document, component, default_from_dict, default_to_dict, logging |
| 10 | +from haystack.utils import Secret, deserialize_secrets_inplace |
| 11 | +from more_itertools import batched |
| 12 | +from tqdm import tqdm |
| 13 | + |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | + |
@component
class GoogleGenAIDocumentEmbedder:
    """
    Computes document embeddings using Google AI models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack_integrations.components.embedders import GoogleGenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = GoogleGenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"),
        model: str = "text-embedding-004",
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates a GoogleGenAIDocumentEmbedder component.

        :param api_key:
            The Google API key.
            You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-004`.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param config:
            A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
            If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}.
            For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
        """
        self._api_key = api_key
        self._model = model
        self._prefix = prefix
        self._suffix = suffix
        self._batch_size = batch_size
        self._progress_bar = progress_bar
        self._meta_fields_to_embed = meta_fields_to_embed or []
        self._embedding_separator = embedding_separator
        self._client = genai.Client(api_key=api_key.resolve_value())
        self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self._model,
            prefix=self._prefix,
            suffix=self._suffix,
            batch_size=self._batch_size,
            progress_bar=self._progress_bar,
            meta_fields_to_embed=self._meta_fields_to_embed,
            embedding_separator=self._embedding_separator,
            api_key=self._api_key.to_dict(),
            config=self._config,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed: List[str] = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key])
                for key in self._meta_fields_to_embed
                if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix
            )
            texts_to_embed.append(text_to_embed)

        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed:
            The fully prepared texts (prefix/meta/suffix already applied).
        :param batch_size:
            Number of texts sent per API call.
        :returns:
            A tuple of (embeddings in input order, metadata dict containing the model name).
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}
        for batch in tqdm(
            batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
        ):
            # `batched` yields a tuple of texts. Pass the whole batch as contents;
            # the previous `[b[1] for b in batch]` wrongly sent only the second
            # character of each text to the API.
            args: Dict[str, Any] = {"model": self._model, "contents": list(batch)}
            if self._config:
                args["config"] = types.EmbedContentConfig(**self._config)

            response = self._client.models.embed_content(**args)

            all_embeddings.extend(el.values for el in response.embeddings)

        # Record which model produced the embeddings (set once).
        if "model" not in meta:
            meta["model"] = self._model

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            error_message_documents = (
                "GoogleGenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the GoogleGenAITextEmbedder."
            )
            raise TypeError(error_message_documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size)

        # Attach embeddings back onto the input documents in order.
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
0 commit comments