Skip to content

Commit ebce009

Browse files
authored
feat: Chroma - fix typing + ship types by adding py.typed files (#1910)
* chroma: fix typing + ship types
* stricter type checking
* rm unused configs
1 parent 9831cb5 commit ebce009

File tree

6 files changed, with 23 additions (+23) and 18 deletions (-18).

integrations/chroma/pyproject.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,13 @@ integration = 'pytest -m "integration" {args:tests}'
6666
all = 'pytest {args:tests}'
6767
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
6868

69-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
69+
types = "mypy -p haystack_integrations.components.retrievers.chroma -p haystack_integrations.document_stores.chroma {args}"
70+
71+
[tool.mypy]
72+
install_types = true
73+
non_interactive = true
74+
check_untyped_defs = true
75+
disallow_incomplete_defs = true
7076

7177
[tool.hatch.metadata]
7278
allow-direct-references = true
@@ -160,7 +166,3 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
160166
[tool.pytest.ini_options]
161167
minversion = "6.0"
162168
markers = ["integration: integration tests"]
163-
164-
[[tool.mypy.overrides]]
165-
module = ["haystack_integrations.*"]
166-
ignore_missing_imports = true

integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def run(
7171
query: str,
7272
filters: Optional[Dict[str, Any]] = None,
7373
top_k: Optional[int] = None,
74-
):
74+
) -> Dict[str, Any]:
7575
"""
7676
Run the retriever on the given input data.
7777
@@ -96,7 +96,7 @@ async def run_async(
9696
query: str,
9797
filters: Optional[Dict[str, Any]] = None,
9898
top_k: Optional[int] = None,
99-
):
99+
) -> Dict[str, Any]:
100100
"""
101101
Asynchronously run the retriever on the given input data.
102102
@@ -115,7 +115,7 @@ async def run_async(
115115
"""
116116
filters = apply_filter_policy(self.filter_policy, self.filters, filters)
117117
top_k = top_k or self.top_k
118-
return {"documents": await self.document_store.search_async([query], top_k, filters)[0]}
118+
return {"documents": (await self.document_store.search_async([query], top_k, filters))[0]}
119119

120120
@classmethod
121121
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
@@ -184,7 +184,7 @@ def run(
184184
query_embedding: List[float],
185185
filters: Optional[Dict[str, Any]] = None,
186186
top_k: Optional[int] = None,
187-
):
187+
) -> Dict[str, Any]:
188188
"""
189189
Run the retriever on the given input data.
190190
@@ -211,7 +211,7 @@ async def run_async(
211211
query_embedding: List[float],
212212
filters: Optional[Dict[str, Any]] = None,
213213
top_k: Optional[int] = None,
214-
):
214+
) -> Dict[str, Any]:
215215
"""
216216
Asynchronously run the retriever on the given input data.
217217
@@ -232,7 +232,7 @@ async def run_async(
232232
top_k = top_k or self.top_k
233233

234234
query_embeddings = [query_embedding]
235-
return {"documents": await self.document_store.search_embeddings_async(query_embeddings, top_k, filters)[0]}
235+
return {"documents": (await self.document_store.search_embeddings_async(query_embeddings, top_k, filters))[0]}
236236

237237
@classmethod
238238
def from_dict(cls, data: Dict[str, Any]) -> "ChromaEmbeddingRetriever":

integrations/chroma/src/haystack_integrations/components/retrievers/py.typed

Whitespace-only changes.

integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Literal, Optional
5+
from typing import Any, Dict, List, Literal, Optional, Sequence, cast
66

77
import chromadb
8+
from chromadb.api.models.AsyncCollection import AsyncCollection
89
from chromadb.api.types import GetResult, QueryResult
910
from haystack import default_from_dict, default_to_dict, logging
1011
from haystack.dataclasses import Document
@@ -37,7 +38,7 @@ def __init__(
3738
port: Optional[int] = None,
3839
distance_function: Literal["l2", "cosine", "ip"] = "l2",
3940
metadata: Optional[dict] = None,
40-
**embedding_function_params,
41+
**embedding_function_params: Any,
4142
):
4243
"""
4344
Creates a new ChromaDocumentStore instance.
@@ -86,8 +87,8 @@ def __init__(
8687
self._host = host
8788
self._port = port
8889

89-
self._collection = None
90-
self._async_collection = None
90+
self._collection: Optional[chromadb.Collection] = None
91+
self._async_collection: Optional[AsyncCollection] = None
9192

9293
def _ensure_initialized(self):
9394
if not self._collection:
@@ -482,7 +483,7 @@ def search_embeddings(
482483

483484
kwargs = self._prepare_query_kwargs(filters)
484485
results = self._collection.query(
485-
query_embeddings=query_embeddings,
486+
query_embeddings=cast(List[Sequence[float]], query_embeddings),
486487
n_results=top_k,
487488
**kwargs,
488489
)
@@ -513,7 +514,7 @@ async def search_embeddings_async(
513514

514515
kwargs = self._prepare_query_kwargs(filters)
515516
results = await self._async_collection.query(
516-
query_embeddings=query_embeddings,
517+
query_embeddings=cast(List[Sequence[float]], query_embeddings),
517518
n_results=top_k,
518519
**kwargs,
519520
)

integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from typing import Any
6+
57
from chromadb.api.types import EmbeddingFunction
68
from chromadb.utils.embedding_functions import (
79
CohereEmbeddingFunction,
@@ -34,7 +36,7 @@
3436
}
3537

3638

37-
def get_embedding_function(function_name: str, **kwargs) -> EmbeddingFunction:
39+
def get_embedding_function(function_name: str, **kwargs: Any) -> EmbeddingFunction:
3840
"""Load an embedding function by name.
3941
4042
:param function_name: the name of the embedding function.

integrations/chroma/src/haystack_integrations/document_stores/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)