Skip to content

Commit 9bd9134

Browse files
authored
feat: Add GoogleAITextEmbedder and GoogleAIDocumentEmbedder components (#1783)
* feat: Add GoogleAITextEmbedder and GoogleAIDocumentEmbedder components * fix: Improve error messages for input type validation in GoogleAITextEmbedder and GoogleAIDocumentEmbedder * feat: add Google GenAI embedder components for document and text embeddings * feat: add unit tests for GoogleAIDocumentEmbedder and GoogleAITextEmbedder * refactor: clean up imports and improve list handling in GoogleAIDocumentEmbedder and GoogleAITextEmbedder tests * refactor: Rename classes and update imports for Google GenAI components * feat: Add additional modules for Google GenAI embedders in config * chore: add 'more-itertools' to lint environment dependencies * refactor: update GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder to use private attributes for initialization * refactor: update _prepare_texts_to_embed to return a list instead of a dictionary * refactor: format code for better readability and consistency in document embedder * refactor: improve code formatting for consistency and readability in document embedder and tests * refactor: update _prepare_texts_to_embed to return a list instead of a dictionary * feat: add new author to project metadata in pyproject.toml
1 parent 501ef14 commit 9bd9134

File tree

7 files changed

+720
-2
lines changed

7 files changed

+720
-2
lines changed

integrations/google_genai/pydoc/config.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ loaders:
33
search_path: [../src]
44
modules: [
55
"haystack_integrations.components.generators.google_genai.chat.chat_generator",
6+
"haystack_integrations.components.embedders.google_genai.document_embedder",
7+
"haystack_integrations.components.embedders.google_genai.text_embedder"
68
]
79
ignore_when_discovered: ["__init__"]
810
processors:

integrations/google_genai/pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ readme = "README.md"
1010
requires-python = ">=3.9"
1111
license = "Apache-2.0"
1212
keywords = []
13-
authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }]
13+
authors = [
14+
{ name = "deepset GmbH", email = "info@deepset.ai" },
15+
{ name = "Gary Badwal", email = "gurpreet071999@gmail.com" }
16+
]
1417
classifiers = [
1518
"License :: OSI Approved :: Apache Software License",
1619
"Development Status :: 4 - Beta",
@@ -74,7 +77,7 @@ types = "mypy --install-types --non-interactive --explicit-package-bases {args:s
7477
[tool.hatch.envs.lint]
7578
installer = "uv"
7679
detached = true
77-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
80+
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "more-itertools"]
7881

7982
[tool.hatch.envs.lint.scripts]
8083
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from .document_embedder import GoogleGenAIDocumentEmbedder
5+
from .text_embedder import GoogleGenAITextEmbedder
6+
7+
__all__ = ["GoogleGenAIDocumentEmbedder", "GoogleGenAITextEmbedder"]
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Any, Dict, List, Optional, Tuple, Union
6+
7+
from google import genai
8+
from google.genai import types
9+
from haystack import Document, component, default_from_dict, default_to_dict, logging
10+
from haystack.utils import Secret, deserialize_secrets_inplace
11+
from more_itertools import batched
12+
from tqdm import tqdm
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@component
class GoogleGenAIDocumentEmbedder:
    """
    Computes document embeddings using Google AI models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = GoogleGenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"),
        model: str = "text-embedding-004",
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates a GoogleGenAIDocumentEmbedder component.

        :param api_key:
            The Google API key.
            You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-004`.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param config:
            A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
            If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}.
            For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
        """
        self._api_key = api_key
        self._model = model
        self._prefix = prefix
        self._suffix = suffix
        self._batch_size = batch_size
        self._progress_bar = progress_bar
        self._meta_fields_to_embed = meta_fields_to_embed or []
        self._embedding_separator = embedding_separator
        self._client = genai.Client(api_key=api_key.resolve_value())
        # Only fall back to the default when no config was given at all; an explicitly
        # passed empty dict is kept and results in no config being sent with the request.
        self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self._model,
            prefix=self._prefix,
            suffix=self._suffix,
            batch_size=self._batch_size,
            progress_bar=self._progress_bar,
            meta_fields_to_embed=self._meta_fields_to_embed,
            embedding_separator=self._embedding_separator,
            api_key=self._api_key.to_dict(),
            config=self._config,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :param documents:
            Documents whose content and selected metadata are turned into texts.
        :returns:
            One text per document, in the same order as `documents`.
        """
        texts_to_embed: List[str] = []
        for doc in documents:
            # Metadata fields that are missing or set to None are skipped.
            meta_values_to_embed = [
                str(doc.meta[key])
                for key in self._meta_fields_to_embed
                if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix
            )
            texts_to_embed.append(text_to_embed)

        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed:
            The texts to embed.
        :param batch_size:
            Maximum number of texts sent in a single embedding request.
        :returns:
            A tuple of the embeddings (one per text, in input order) and a metadata dictionary.
            The metadata holds the model name and stays empty when there was nothing to embed.
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}
        for batch in tqdm(
            batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
        ):
            # Each batch is a tuple of complete texts. Indexing into the individual strings
            # (as in `b[1]`) would send a single character per document instead of the text.
            args: Dict[str, Any] = {"model": self._model, "contents": list(batch)}
            if self._config:
                args["config"] = types.EmbedContentConfig(**self._config)

            response = self._client.models.embed_content(**args)
            all_embeddings.extend(el.values for el in response.embeddings)

            meta.setdefault("model", self._model)

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
        """
        Embeds a list of documents.

        The embeddings are stored on the passed Document objects themselves.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Metadata about the request, containing the name of the model used.

        :raises TypeError:
            If the input is not a list of Documents.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            error_message_documents = (
                "GoogleGenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the GoogleGenAITextEmbedder."
            )
            raise TypeError(error_message_documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size)

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Any, Dict, List, Optional, Union
6+
7+
from google import genai
8+
from google.genai import types
9+
from haystack import component, default_from_dict, default_to_dict, logging
10+
from haystack.utils import Secret, deserialize_secrets_inplace
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
@component
class GoogleGenAITextEmbedder:
    """
    Embeds strings using Google AI models.

    You can use it to embed a user query and send it to an embedding Retriever.

    ### Usage example

    ```python
    from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = GoogleGenAITextEmbedder()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    # 'meta': {'model': 'text-embedding-004'}}
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"),
        model: str = "text-embedding-004",
        prefix: str = "",
        suffix: str = "",
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates a GoogleGenAITextEmbedder component.

        :param api_key:
            The Google API key.
            You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-004`.
        :param prefix:
            A string to add at the beginning of each text to embed.
        :param suffix:
            A string to add at the end of each text to embed.
        :param config:
            A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
            If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}.
            For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
        """

        self._api_key = api_key
        self._model_name = model
        self._prefix = prefix
        self._suffix = suffix
        # Only fall back to the default when no config was given at all; an explicitly
        # passed empty dict is kept and results in no config being sent with the request.
        self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
        self._client = genai.Client(api_key=api_key.resolve_value())

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self._model_name,
            api_key=self._api_key.to_dict(),
            prefix=self._prefix,
            suffix=self._suffix,
            config=self._config,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_input(self, text: str) -> Dict[str, Any]:
        """
        Validate the input text and build the keyword arguments for the embedding request.

        :param text:
            Text to embed.
        :returns:
            Keyword arguments for `embed_content`: the model, the text wrapped in prefix and suffix,
            and the embedding config when one is set.
        :raises TypeError:
            If `text` is not a string.
        """
        if not isinstance(text, str):
            error_message_text = (
                "GoogleGenAITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the GoogleGenAIDocumentEmbedder."
            )

            raise TypeError(error_message_text)

        text_to_embed = self._prefix + text + self._suffix

        kwargs: Dict[str, Any] = {"model": self._model_name, "contents": text_to_embed}
        if self._config:
            kwargs["config"] = types.EmbedContentConfig(**self._config)

        return kwargs

    def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]:
        """
        Convert the embedding response into the component output.

        :param result:
            Response returned by the embedding request.
        :returns:
            A dictionary with the first embedding of the response and the model name as metadata.
        """
        return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}}

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]:
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
            - `meta`: Metadata about the request, containing the name of the model used.

        :raises TypeError:
            If the input is not a string.
        """
        create_kwargs = self._prepare_input(text=text)
        response = self._client.models.embed_content(**create_kwargs)
        return self._prepare_output(result=response)

0 commit comments

Comments
 (0)