From 9a7a0185fcba3ea4129f50f82a3ff066dcd63377 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 15 Sep 2023 11:43:11 +0200 Subject: [PATCH 01/17] Fix pytest option parsing --- libs/experimental/tests/unit_tests/conftest.py | 4 ++-- libs/langchain/tests/unit_tests/conftest.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/experimental/tests/unit_tests/conftest.py b/libs/experimental/tests/unit_tests/conftest.py index da45a330f50af..afb609f3ce31f 100644 --- a/libs/experimental/tests/unit_tests/conftest.py +++ b/libs/experimental/tests/unit_tests/conftest.py @@ -40,8 +40,8 @@ def test_something(): # Used to avoid repeated calls to `util.find_spec` required_pkgs_info: Dict[str, bool] = {} - only_extended = config.getoption("--only-extended") or False - only_core = config.getoption("--only-core") or False + only_extended = config.getoption("--only-extended", False) + only_core = config.getoption("--only-core", False) if only_extended and only_core: raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py index da45a330f50af..afb609f3ce31f 100644 --- a/libs/langchain/tests/unit_tests/conftest.py +++ b/libs/langchain/tests/unit_tests/conftest.py @@ -40,8 +40,8 @@ def test_something(): # Used to avoid repeated calls to `util.find_spec` required_pkgs_info: Dict[str, bool] = {} - only_extended = config.getoption("--only-extended") or False - only_core = config.getoption("--only-core") or False + only_extended = config.getoption("--only-extended", False) + only_core = config.getoption("--only-core", False) if only_extended and only_core: raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") From 299086a2a4cec00e53b274d8ce0709b0cde13c8f Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 15 Sep 2023 11:44:12 +0200 Subject: [PATCH 02/17] pgvector: Slight refactoring to make code a bit more reusable --- .../langchain/vectorstores/pgvector.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py index 41deb51f10a0f..0b2d97d6277e6 100644 --- a/libs/langchain/langchain/vectorstores/pgvector.py +++ b/libs/langchain/langchain/vectorstores/pgvector.py @@ -23,12 +23,12 @@ import sqlalchemy from sqlalchemy import delete from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Session try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker from langchain.docstore.document import Document from langchain.schema.embeddings import Embeddings @@ -129,6 +129,8 @@ def __init__( self.override_relevance_score_fn = relevance_score_fn self.engine_args = engine_args or {} # Create a connection if not provided, otherwise use the provided connection + self._engine = self.create_engine() + self.Session = sessionmaker(self._engine) self._conn = connection if connection else self.connect() self.__post_init__() @@ -155,14 +157,15 @@ def __del__(self) -> None: def embeddings(self) -> Embeddings: return self.embedding_function + def create_engine(self) -> sqlalchemy.Engine: + return sqlalchemy.create_engine(self.connection_string, echo=False) + def connect(self) -> sqlalchemy.engine.Connection: - engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args) - conn = 
engine.connect() - return conn + return self._engine.connect() def create_vector_extension(self) -> None: try: - with Session(self._conn) as session: + with self.Session() as session: # The advisor lock fixes issue arising from concurrent # creation of the vector extension. # https://github.com/langchain-ai/langchain/issues/12933 @@ -180,24 +183,22 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e def create_tables_if_not_exists(self) -> None: - with self._conn.begin(): - Base.metadata.create_all(self._conn) + Base.metadata.create_all(self._engine) def drop_tables(self) -> None: - with self._conn.begin(): - Base.metadata.drop_all(self._conn) + Base.metadata.drop_all(self._engine) def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with Session(self._conn) as session: + with self.Session() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -208,7 +209,7 @@ def delete_collection(self) -> None: @contextlib.contextmanager def _make_session(self) -> Generator[Session, None, None]: """Create a context manager for the session, bind to _conn string.""" - yield Session(self._conn) + yield self.Session() def delete( self, @@ -220,7 +221,7 @@ def delete( Args: ids: List of ids to delete. """ - with Session(self._conn) as session: + with self.Session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -236,7 +237,7 @@ def get_collection(self, session: Session) -> Optional["CollectionStore"]: return self.CollectionStore.get_by_name(session, self.collection_name) @classmethod - def __from( + def _from( cls, texts: List[str], embeddings: List[List[float]], @@ -294,7 +295,7 @@ def add_embeddings( if not metadatas: metadatas = [{} for _ in texts] - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -399,7 +400,7 @@ def similarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - results = self.__query_collection(embedding=embedding, k=k, filter=filter) + results = self._query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -417,14 +418,14 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa ] return docs - def __query_collection( + def _query_collection( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" - with Session(self._conn) as session: + with self.Session() as session: collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -511,7 +512,7 @@ def from_texts( """ embeddings = embedding.embed_documents(list(texts)) - return cls.__from( + return cls._from( texts, embeddings, embedding, @@ -556,7 +557,7 @@ def from_embeddings( texts = [t[0] for t in text_embeddings] embeddings = [t[1] for t in text_embeddings] - return cls.__from( + return cls._from( texts, embeddings, embedding, @@ -717,7 +718,7 @@ def 
max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + results = self._query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] From 5abea55f5d7914523d82b671b35180082ab6947c Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 15 Sep 2023 11:49:40 +0200 Subject: [PATCH 03/17] CrateDB vector: Add vector store support The implementation is based on the `pgvector` adapter, as both PostgreSQL and CrateDB share similar attributes, and can be wrapped well by using the same SQLAlchemy layer on top. --- .../langchain/vectorstores/__init__.py | 9 + .../vectorstores/cratedb/__init__.py | 6 + .../langchain/vectorstores/cratedb/base.py | 396 ++++++++++++++++ .../langchain/vectorstores/cratedb/model.py | 84 ++++ .../vectorstores/cratedb/sqlalchemy_type.py | 84 ++++ libs/langchain/pyproject.toml | 7 + .../vectorstores/docker-compose/cratedb.yml | 20 + .../vectorstores/test_cratedb.py | 445 ++++++++++++++++++ 8 files changed, 1051 insertions(+) create mode 100644 libs/langchain/langchain/vectorstores/cratedb/__init__.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/base.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/model.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py create mode 100644 libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml create mode 100644 libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 4a1c6dd9696a9..1bba69f186815 100644 --- a/libs/langchain/langchain/vectorstores/__init__.py +++ b/libs/langchain/langchain/vectorstores/__init__.py @@ -134,6 +134,12 @@ def _import_clickhouse_settings() -> Any: return ClickhouseSettings +def _import_cratedb() -> Any: + from langchain.vectorstores.cratedb import CrateDBVectorSearch + + return CrateDBVectorSearch + + def _import_dashvector() -> Any: from langchain.vectorstores.dashvector import DashVector @@ -459,6 +465,8 @@ def __getattr__(name: str) -> Any: return _import_clickhouse_settings() elif name == "Clickhouse": return _import_clickhouse() + elif name == "CrateDBVectorSearch": + return _import_cratedb() elif name == "DashVector": return _import_dashvector() elif name == "DeepLake": @@ -574,6 +582,7 @@ def __getattr__(name: str) -> Any: "Clarifai", "Clickhouse", "ClickhouseSettings", + "CrateDBVectorSearch", "DashVector", "DeepLake", "Dingo", diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py new file mode 100644 index 0000000000000..303a52babeaea --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -0,0 +1,6 @@ +from .base import BaseModel, CrateDBVectorSearch + +__all__ = [ + "BaseModel", + "CrateDBVectorSearch", +] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py new file mode 100644 index 0000000000000..ca1c21fad68a7 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import enum +import math +import uuid +from typing import ( + Any, + Callable, + Dict, + 
Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+)
+
+import sqlalchemy
+from cratedb_toolkit.sqlalchemy.patch import patch_inspector
+from cratedb_toolkit.sqlalchemy.polyfill import (
+    polyfill_refresh_after_dml,
+    refresh_table,
+)
+from sqlalchemy.orm import declarative_base, sessionmaker
+
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from langchain.vectorstores.pgvector import PGVector
+
+
+class DistanceStrategy(str, enum.Enum):
+    """Enumerator of the Distance strategies."""
+
+    EUCLIDEAN = "euclidean"
+    COSINE = "cosine"
+    MAX_INNER_PRODUCT = "inner"
+
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
+
+Base = declarative_base()  # type: Any
+# Base = declarative_base(metadata=MetaData(schema="langchain"))  # type: Any
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+def generate_uuid() -> str:
+    return str(uuid.uuid4())
+
+
+class BaseModel(Base):
+    """Base model for the SQL stores."""
+
+    __abstract__ = True
+    uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid)
+
+
+def _results_to_docs(docs_and_scores: Any) -> List[Document]:
+    """Return docs from docs and scores."""
+    return [doc for doc, _ in docs_and_scores]
+
+
+class CrateDBVectorSearch(PGVector):
+    """`CrateDB` vector store.
+
+    To use it, you should have the ``crate[sqlalchemy]`` Python package installed.
+
+    Args:
+        connection_string: Database connection string.
+        embedding_function: Any embedding function implementing
+            `langchain.embeddings.base.Embeddings` interface.
+        collection_name: The name of the collection to use. (default: langchain)
+            NOTE: This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store (if not exists).
+            So, make sure the user has the right permissions to create tables.
+        distance_strategy: The distance strategy to use. (default: EUCLIDEAN)
+        pre_delete_collection: If True, will delete the collection if it exists.
+            (default: False). Useful for testing.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.vectorstores import CrateDBVectorSearch
+            from langchain.embeddings.openai import OpenAIEmbeddings
+
+            CONNECTION_STRING = "crate://crate@localhost:4200/test3"
+            COLLECTION_NAME = "state_of_the_union_test"
+            embeddings = OpenAIEmbeddings()
+            vectorstore = CrateDBVectorSearch.from_documents(
+                embedding=embeddings,
+                documents=docs,
+                collection_name=COLLECTION_NAME,
+                connection_string=CONNECTION_STRING,
+            )
+
+
+    """
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """
+        Initialize the store.
+        """
+
+        # FIXME: Could be a bug in CrateDB SQLAlchemy dialect.
+        patch_inspector()
+
+        self._engine = self.create_engine()
+        self.Session = sessionmaker(self._engine)
+
+        # TODO: See what can be improved here.
+        polyfill_refresh_after_dml(self.Session)
+
+        self.CollectionStore = None
+        self.EmbeddingStore = None
+
+    def get_collection(
+        self, session: sqlalchemy.orm.Session
+    ) -> Optional["CollectionStore"]:
+        if self.CollectionStore is None:
+            raise RuntimeError(
+                "Collection can't be accessed without specifying dimension size of embedding vectors"
+            )
+        return self.CollectionStore.get_by_name(session, self.collection_name)
+
+    def add_embeddings(
+        self,
+        texts: Iterable[str],
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Add embeddings to the vectorstore.
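+
+        In this adapter, this method also creates the SQLAlchemy models,
+        tables, and collection record on the fly, because the vector
+        dimension only becomes known once the first embeddings arrive.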
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            embeddings: List of list of embedding vectors.
+            metadatas: List of metadatas associated with the texts.
+            kwargs: vectorstore specific parameters
+        """
+        from langchain.vectorstores.cratedb.model import model_factory
+
+        dimensions = len(embeddings[0])
+        self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=dimensions)
+        if self.pre_delete_collection:
+            self.delete_collection()
+        self.create_tables_if_not_exists()
+        self.create_collection()
+        return super().add_embeddings(
+            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+        )
+
+    def create_tables_if_not_exists(self) -> None:
+        """
+        Needs to be overridden because `Base` is different from upstream.
+        """
+        Base.metadata.create_all(self._engine)
+
+    def drop_tables(self) -> None:
+        """
+        Needs to be overridden because `Base` is different from upstream.
+        """
+        Base.metadata.drop_all(self._engine)
+
+    def delete(
+        self,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Delete vectors by ids or uuids.
+
+        Remark: Specialized for CrateDB to synchronize data.
+
+        Args:
+            ids: List of ids to delete.
+
+        Remark: The patch for CrateDB needs to override this method in order
+                to add a "REFRESH TABLE" statement afterwards. The other
+                patch, which listens to `after_delete` events, seems unable
+                to catch it.
+        """
+        super().delete(ids=ids, **kwargs)
+
+        # CrateDB: Synchronize data because `on_flush` does not catch it.
+        with self.Session() as session:
+            refresh_table(session, self.EmbeddingStore)
+
+    @property
+    def distance_strategy(self) -> Any:
+        if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+            return self.EmbeddingStore.embedding.euclidean_distance
+        elif self._distance_strategy == DistanceStrategy.COSINE:
+            raise NotImplementedError("Cosine similarity not implemented yet")
+        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+            raise NotImplementedError("Dot-product similarity not implemented yet")
+        else:
+            raise ValueError(
+                f"Got unexpected value for distance: {self._distance_strategy}. "
+                f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
+            )
+
+    def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
+        """Return docs and scores from results."""
+        docs = [
+            (
+                Document(
+                    page_content=result.EmbeddingStore.document,
+                    metadata=result.EmbeddingStore.cmetadata,
+                ),
+                result._score if self.embedding_function is not None else None,
+            )
+            for result in results
+        ]
+        return docs
+
+    def _query_collection(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+    ) -> List[Any]:
+        """Query the collection."""
+        with self.Session() as session:
+            collection = self.get_collection(session)
+            if not collection:
+                raise ValueError("Collection not found")
+
+            filter_by = self.EmbeddingStore.collection_id == collection.uuid
+
+            if filter is not None:
+                filter_clauses = []
+                for key, value in filter.items():
+                    IN = "in"
+                    if isinstance(value, dict) and IN in map(str.lower, value):
+                        value_case_insensitive = {
+                            k.lower(): v for k, v in value.items()
+                        }
+                        filter_by_metadata = self.EmbeddingStore.cmetadata[key].in_(
+                            value_case_insensitive[IN]
+                        )
+                        filter_clauses.append(filter_by_metadata)
+                    else:
+                        filter_by_metadata = self.EmbeddingStore.cmetadata[key] == str(
+                            value
+                        )  # type: ignore[assignment]
+                        filter_clauses.append(filter_by_metadata)
+
+                filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+
+            _type = self.EmbeddingStore
+
+            results: List[Any] = (
+                session.query(  # type: ignore[attr-defined]
+                    self.EmbeddingStore,
+                    # TODO: Original pgvector code uses `self.distance_strategy`.
+                    # CrateDB currently only supports EUCLIDEAN.
+                    # self.distance_strategy(embedding).label("distance")
+                    sqlalchemy.literal_column(
+                        f"{self.EmbeddingStore.__tablename__}._score"
+                    ).label("_score"),
+                )
+                .filter(filter_by)
+                # CrateDB applies `KNN_MATCH` within the `WHERE` clause.
+                .filter(
+                    sqlalchemy.func.knn_match(
+                        self.EmbeddingStore.embedding, embedding, k
+                    )
+                )
+                .order_by(sqlalchemy.desc("_score"))
+                .join(
+                    self.CollectionStore,
+                    self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
+                )
+                .limit(k)
+            )
+        return results
+
+    @classmethod
+    def from_texts(  # type: ignore[override]
+        cls: Type[CrateDBVectorSearch],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> CrateDBVectorSearch:
+        """
+        Return VectorStore initialized from texts and embeddings.
+        Database connection string is required.
+
+        Either pass it as a parameter, or set the CRATEDB_CONNECTION_STRING
+        environment variable.
+
+        Remark: Needs to be overridden, because CrateDB uses a different
+                DEFAULT_DISTANCE_STRATEGY.
+        """
+        return super().from_texts(  # type: ignore[return-value]
+            texts,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            distance_strategy=distance_strategy,  # type: ignore[arg-type]
+            pre_delete_collection=pre_delete_collection,
+            **kwargs,
+        )
+
+    @classmethod
+    def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
+        connection_string: str = get_from_dict_or_env(
+            data=kwargs,
+            key="connection_string",
+            env_key="CRATEDB_CONNECTION_STRING",
+        )
+
+        if not connection_string:
+            raise ValueError(
+                "Database connection string is required. "
+                "Either pass it as a parameter, or set the "
+                "CRATEDB_CONNECTION_STRING environment variable."
+            )
+
+        return connection_string
+
+    @classmethod
+    def connection_string_from_db_params(
+        cls,
+        driver: str,
+        host: str,
+        port: int,
+        database: str,
+        user: str,
+        password: str,
+    ) -> str:
+        """Return connection string from database parameters."""
+        return str(
+            sqlalchemy.URL.create(
+                drivername=driver,
+                host=host,
+                port=port,
+                username=user,
+                password=password,
+                query={"schema": database},
+            )
+        )
+
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        """
+        The 'correct' relevance function
+        may differ depending on a few things, including:
+        - the distance / similarity metric used by the VectorStore
+        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+        - embedding dimensionality
+        - etc.
+        """
+        if self.override_relevance_score_fn is not None:
+            return self.override_relevance_score_fn
+
+        # Default strategy is to rely on distance strategy provided
+        # in vectorstore constructor
+        if self._distance_strategy == DistanceStrategy.COSINE:
+            return self._cosine_relevance_score_fn
+        elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+            return self._euclidean_relevance_score_fn
+        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+            return self._max_inner_product_relevance_score_fn
+        else:
+            raise ValueError(
+                "No supported normalization function for distance_strategy of "
+                f"{self._distance_strategy}. Consider providing relevance_score_fn to "
+                "CrateDBVectorSearch constructor."
+            )
+
+    @staticmethod
+    def _euclidean_relevance_score_fn(score: float) -> float:
+        """Return a similarity score on a scale [0, 1]."""
+        # The 'correct' relevance function
+        # may differ depending on a few things, including:
+        # - the distance / similarity metric used by the VectorStore
+        # - the scale of your embeddings (OpenAI's are unit normed. Many
+        #   others are not!)
+        # - embedding dimensionality
+        # - etc.
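+        #
+        # Note: in this adapter, `score` is the `_score` returned by
+        # CrateDB's `KNN_MATCH` (see `_query_collection`, which orders by
+        # `_score` descending), i.e. a similarity where higher means more
+        # similar, so the value is scaled rather than inverted like a
+        # distance.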
+ # This function converts the euclidean norm of normalized embeddings + # (0 is most similar, sqrt(2) most dissimilar) + # to a similarity function (0 to 1) + + # Original: + # return 1.0 - distance / math.sqrt(2) + return score / math.sqrt(2) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py new file mode 100644 index 0000000000000..ee42e7269dc9d --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -0,0 +1,84 @@ +from functools import lru_cache +from typing import Optional, Tuple + +import sqlalchemy +from crate.client.sqlalchemy.types import ObjectType +from sqlalchemy.orm import Session, relationship + +from langchain.vectorstores.cratedb.base import BaseModel +from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector + + +@lru_cache +def model_factory(dimensions: int): + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = "collection" + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + try: + return ( + session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True + return collection, created + + class EmbeddingStore(BaseModel): + """Embedding store.""" + + __tablename__ = "embedding" + + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") + + embedding = sqlalchemy.Column(FloatVector(dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) + + return CollectionStore, EmbeddingStore diff --git a/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py new file mode 100644 index 0000000000000..e784c3013a3d9 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py @@ -0,0 +1,84 @@ +# TODO: Refactor to CrateDB SQLAlchemy dialect. 
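+#
+# This module provides `FloatVector`, a SQLAlchemy `UserDefinedType` that
+# maps Python sequences and NumPy arrays to CrateDB's `FLOAT_VECTOR(n)`
+# column type, converting values with `to_db` on the way into the database
+# and `from_db` on the way out.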
+import typing as t
+
+import numpy as np
+import numpy.typing as npt
+import sqlalchemy as sa
+from sqlalchemy.types import UserDefinedType
+
+__all__ = ["FloatVector"]
+
+
+def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
+    # from `pgvector.utils`
+    # could be ndarray if already cast by lower-level driver
+    if value is None or isinstance(value, np.ndarray):
+        return value
+
+    return np.array(value, dtype=np.float32)
+
+
+def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
+    # from `pgvector.utils`
+    if value is None:
+        return value
+
+    if isinstance(value, np.ndarray):
+        if value.ndim != 1:
+            raise ValueError("expected ndim to be 1")
+
+        if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(
+            value.dtype, np.floating
+        ):
+            raise ValueError("dtype must be numeric")
+
+        value = value.tolist()
+
+    if dim is not None and len(value) != dim:
+        raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))
+
+    return value
+
+
+class FloatVector(UserDefinedType):
+    """
+    https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+    https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
+    """
+
+    cache_ok = True
+
+    def __init__(self, dim: t.Optional[int] = None) -> None:
+        super(UserDefinedType, self).__init__()
+        self.dim = dim
+
+    def get_col_spec(self, **kw: t.Any) -> str:
+        if self.dim is None:
+            return "FLOAT_VECTOR"
+        return "FLOAT_VECTOR(%d)" % self.dim
+
+    def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
+        def process(value: t.Iterable) -> t.Optional[t.List]:
+            return to_db(value, self.dim)
+
+        return process
+
+    def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
+        def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
+            return from_db(value)
+
+        return process
+
+    """
+    CrateDB currently only supports similarity function `VectorSimilarityFunction.EUCLIDEAN`.
+    -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55
+
+    On the other hand, pgvector uses a comparator to apply different similarity functions as operators,
+    see `pgvector.sqlalchemy.Vector.comparator_factory`.
+
+    <->: l2/euclidean_distance
+    <#>: max_inner_product
+    <=>: cosine_distance
+
+    TODO: Discuss.
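+
+    A handwritten query in the spirit of `CrateDBVectorSearch._query_collection`
+    might look like this (a sketch; table and column names follow the models
+    in `model.py`, and the vector literal is illustrative):
+
+        SELECT document, _score
+        FROM embedding
+        WHERE knn_match(embedding, [1.1, 2.2, 3.3], 10)
+        ORDER BY _score DESC;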
+ """ # noqa: E501 diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index e2a1e27de40dc..a3fd3bd1ea13a 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -141,6 +141,8 @@ upstash-redis = {version = "^0.15.0", optional = true} google-cloud-documentai = {version = "^2.20.1", optional = true} fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"} javelin-sdk = {version = "^0.1.8", optional = true} +crate = {version = "^0.34.0", extras=["sqlalchemy"], optional = true} +cratedb-toolkit = {version = ">=0.0.1", optional = true} [tool.poetry.group.test.dependencies] @@ -224,6 +226,7 @@ cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] javascript = ["esprima"] +cratedb = ["crate", "cratedb-toolkit"] azure = [ "azure-identity", "azure-cosmos", @@ -307,6 +310,8 @@ all = [ "amadeus", "librosa", "python-arango", + "crate", + "cratedb-toolkit", ] cli = [ @@ -375,6 +380,8 @@ extended_testing = [ "rspace_client", "fireworks-ai", "javelin-sdk", + "crate", + "cratedb-toolkit", ] [tool.ruff] diff --git a/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py new file mode 100644 index 0000000000000..d62f0a125f661 --- /dev/null +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -0,0 +1,445 @@ +""" +Test CrateDB `FLOAT_VECTOR` / `KNN_MATCH` functionality. + +cd tests/integration_tests/vectorstores/docker-compose +docker-compose -f cratedb.yml up +""" +import os +from typing import List, Tuple + +import pytest +import sqlalchemy as sa +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.orm import Session + +from langchain.docstore.document import Document +from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings + +CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( + driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), + host=os.environ.get("TEST_CRATEDB_HOST", "localhost"), + port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")), + database=os.environ.get("TEST_CRATEDB_DATABASE", "testdrive"), + user=os.environ.get("TEST_CRATEDB_USER", "crate"), + password=os.environ.get("TEST_CRATEDB_PASSWORD", ""), +) + + +# TODO: Try 1536 after https://github.com/crate/crate/pull/14699. +# ADA_TOKEN_COUNT = 14 +ADA_TOKEN_COUNT = 1024 +# ADA_TOKEN_COUNT = 1536 + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture(autouse=True) +def drop_tables(engine: sa.Engine) -> None: + """ + Drop database tables. 
+ """ + try: + BaseModel.metadata.drop_all(engine, checkfirst=False) + except Exception as ex: + if "RelationUnknown" not in str(ex): + raise + + +@pytest.fixture +def prune_tables(engine: sa.Engine) -> None: + """ + Delete data from database tables. + """ + with engine.connect() as conn: + with Session(conn) as session: + from langchain.vectorstores.cratedb.model import model_factory + + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for deleting records from tables. + CollectionStore, EmbeddingStore = model_factory(dimensions=1024) + + try: + session.query(CollectionStore).delete() + except ProgrammingError: + pass + try: + session.query(EmbeddingStore).delete() + except ProgrammingError: + pass + + +def decode_output( + output: List[Tuple[Document, float]] +) -> Tuple[List[Document], List[float]]: + """ + Decode a typical API result into separate `documents` and `scores`. + It is needed as utility function in some test cases to compensate + for different and/or flaky score values, when compared to the + original implementation. + """ + documents = [item[0] for item in output] + scores = [round(item[1], 1) for item in output] + return documents, scores + + +class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [ + [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts)) + ] + + def embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] + + +def test_cratedb_texts() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = CrateDBVectorSearch.from_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + +def test_cratedb_with_metadatas_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + 
texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1) + # TODO: Original: + # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 + assert output in [ + [(Document(page_content="foo", metadata={"page": "0"}), 1.0828735)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.1307646)], + ] + + +def test_cratedb_with_filter_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) + # TODO: Original: + # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 + assert output in [ + [(Document(page_content="foo", metadata={"page": "0"}), 1.2615292)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.3979403)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.5065275)], + ] + + +def test_cratedb_with_filter_distant_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + # TODO: Original: + # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) + output = docsearch.similarity_search_with_score("foo", k=3, filter={"page": "2"}) + # TODO: Original: + # assert output == [ + # (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501 + # ] + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores in [ + [0.5], + [0.6], + [0.7], + ] + + +def test_cratedb_with_filter_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) + assert output == [] + + +def test_cratedb_collection_with_metadata() -> None: + """Test end to end collection construction""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + cratedb_vector = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + collection = cratedb_vector.get_collection(cratedb_vector.Session()) + if collection is None: + assert False, "Expected a CollectionStore object but 
received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} + + +def test_cratedb_collection_no_embedding_dimension() -> None: + """Test end to end collection construction""" + cratedb_vector = CrateDBVectorSearch( + embedding_function=None, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + session = Session(cratedb_vector.connect()) + with pytest.raises(RuntimeError) as ex: + cratedb_vector.get_collection(session) + assert ex.match( + "Collection can't be accessed without specifying dimension size of embedding vectors" + ) + + +def test_cratedb_with_filter_in_set() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score( + "foo", k=2, filter={"page": {"IN": ["0", "2"]}} + ) + # TODO: Original: + """ + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 0.0), + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406), + ] + """ + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores == [2.1, 1.3] + + +def test_cratedb_delete_docs() -> None: + """Add and delete documents.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + docsearch.delete(["1", "2"]) + with docsearch._make_session() as session: + records = list(session.query(docsearch.EmbeddingStore).all()) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.custom_id for record in records) == ["3"] # type: ignore + + docsearch.delete(["2", "3"]) # Should not raise on missing ids + with docsearch._make_session() as session: + records = list(session.query(docsearch.EmbeddingStore).all()) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.custom_id for record in records) == [] # type: ignore + + +def test_cratedb_relevance_score() -> None: + """Test to make sure the relevance score is scaled to 0-1.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + output = docsearch.similarity_search_with_relevance_scores("foo", k=3) + """ + # TODO: Original code, where the `distance` is stable. 
+ assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065), + (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621), + ] + """ + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores == [0.8, 0.4, 0.2] + + +def test_cratedb_retriever_search_threshold() -> None: + """Test using retriever for searching with threshold.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + # TODO: Original: + # search_kwargs={"k": 3, "score_threshold": 0.999}, + search_kwargs={"k": 3, "score_threshold": 0.333}, + ) + output = retriever.get_relevant_documents("summer") + assert output == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + ] + + +def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None: + """Test searching with threshold and custom normalization function""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 3, "score_threshold": 0.5}, + ) + output = retriever.get_relevant_documents("foo") + assert output == [] + + +def test_cratedb_max_marginal_relevance_search() -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3) + assert output == [Document(page_content="foo")] + + +def test_cratedb_max_marginal_relevance_search_with_score() -> None: + """Test max marginal relevance search with relevance scores.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) + # TODO: Original: + # assert output == [(Document(page_content="foo"), 0.0)] + assert output in [ + [(Document(page_content="foo"), 1.0606961)], + [(Document(page_content="foo"), 1.0828735)], + [(Document(page_content="foo"), 1.1307646)], + ] From 983ca37f178001e2a4ae90b519c99eca8608d345 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 16 Sep 2023 00:15:12 +0200 Subject: [PATCH 04/17] CrateDB vector: Add documentation --- docs/docs/integrations/providers/cratedb.mdx | 
136 +++++ .../integrations/vectorstores/cratedb.ipynb | 479 ++++++++++++++++++ docs/vercel.json | 12 + 3 files changed, 627 insertions(+) create mode 100644 docs/docs/integrations/providers/cratedb.mdx create mode 100644 docs/docs/integrations/vectorstores/cratedb.ipynb diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx new file mode 100644 index 0000000000000..d3e60bf53ee21 --- /dev/null +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -0,0 +1,136 @@ +# CrateDB + +This documentation section shows how to use the CrateDB vector store +functionality around [`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn +how to use it for similarity search and other purposes. + + +## What is CrateDB? + +[CrateDB] is an open-source, distributed, and scalable SQL analytics database +for storing and analyzing massive amounts of data in near real-time, even with +complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits +the shared-nothing distribution layer of [Elasticsearch]. + +It provides a distributed, multi-tenant-capable relational database and search +engine with HTTP and PostgreSQL interfaces, and schema-free objects. It supports +sharding, partitioning, and replication out of the box. + +CrateDB enables you to efficiently store billions of records, and terabytes of +data, and query it using SQL. + +- Provides a standards-based SQL interface for querying relational data, nested + documents, geospatial constraints, and vector embeddings at the same time. +- Improves your operations by storing time-series data, relational metadata, + and vector embeddings within a single database. +- Builds upon approved technologies from Lucene and Elasticsearch. + + +## CrateDB Cloud + +- Offers on-demand CrateDB clusters without operational overhead, + with enterprise-grade features and [ISO 27001] certification. +- The entrypoint to [CrateDB Cloud] is the [CrateDB Cloud Console]. +- Crate.io offers a free tier via [CrateDB Cloud CRFREE]. +- To get started, [sign up] to CrateDB Cloud, deploy a database cluster, + and follow the upcoming instructions. + + +## Features + +The CrateDB adapter supports the Vector Store subsystem of LangChain. + +### Vector Store + +`CrateDBVectorSearch` is an API wrapper around CrateDB's `FLOAT_VECTOR` type +and the corresponding `KNN_MATCH` function, based on SQLAlchemy and CrateDB's +SQLAlchemy dialect. It provides an interface to store and retrieve floating +point vectors, and to conduct similarity searches. + +Supports: +- Approximate nearest neighbor search. +- Euclidean distance. + + +## Installation and Setup + +There are multiple ways to get started with CrateDB. + +### Install CrateDB on your local machine + +You can [download CrateDB], or use the [OCI image] to run CrateDB on Docker or Podman. +Note that this is not recommended for production use. + +```shell +docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \ + --env=CRATE_HEAP_SIZE=4g crate/crate:nightly \ + -Cdiscovery.type=single-node +``` + +### Deploy a cluster on CrateDB Cloud + +[CrateDB Cloud] is a managed CrateDB service. Sign up for a [free trial]. + +### Install Client + +```bash +pip install 'crate[sqlalchemy]' 'langchain[openai]' +``` + + +## Usage + +For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, there is also +a corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html). + +### Acquire text file +The example uses the canonical `state_of_the_union.txt`. 
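+
+As a sketch, the file can also be fetched with a few lines of Python, assuming
+plain network access to GitHub:
+
+```python
+# Download the example corpus used throughout this walkthrough.
+import urllib.request
+
+URL = (
+    "https://raw.githubusercontent.com/langchain-ai/langchain"
+    "/v0.0.291/docs/extras/modules/state_of_the_union.txt"
+)
+urllib.request.urlretrieve(URL, "state_of_the_union.txt")
+```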
+Alternatively, use `wget`:
+```shell
+wget https://raw.githubusercontent.com/langchain-ai/langchain/v0.0.291/docs/extras/modules/state_of_the_union.txt
+```
+
+### Set environment variables
+Use a valid OpenAI API key and an SQL connection string. The one below fits a local instance of CrateDB.
+```shell
+export OPENAI_API_KEY=foobar  # FIXME
+export CRATEDB_CONNECTION_STRING=crate://crate@localhost
+```
+
+```python
+from langchain.document_loaders import TextLoader
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import CrateDBVectorSearch
+
+
+def main():
+    # Load the document, split it into chunks, embed each chunk and load it into the vector store.
+    raw_documents = TextLoader("state_of_the_union.txt").load()
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    documents = text_splitter.split_documents(raw_documents)
+    db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())
+
+    query = "What did the president say about Ketanji Brown Jackson"
+    docs = db.similarity_search(query)
+    print(docs[0].page_content)
+
+
+if __name__ == "__main__":
+    main()
+```
+
+
+[CrateDB]: https://github.com/crate/crate
+[CrateDB Cloud]: https://crate.io/product
+[CrateDB Cloud Console]: https://console.cratedb.cloud/
+[CrateDB Cloud CRFREE]: https://community.crate.io/t/new-cratedb-cloud-edge-feature-cratedb-cloud-free-tier/1402
+[CrateDB SQLAlchemy dialect]: https://crate.io/docs/python/en/latest/sqlalchemy.html
+[download CrateDB]: https://crate.io/download
+[Elasticsearch]: https://github.com/elastic/elasticsearch
+[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
+[free trial]: https://crate.io/lp-crfree?utm_source=langchain
+[ISO 27001]: https://crate.io/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification
+[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
+[Lucene]: https://github.com/apache/lucene
+[OCI image]: https://hub.docker.com/_/crate
+[sign up]: https://console.cratedb.cloud/
diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb
new file mode 100644
index 0000000000000..462e721bfff40
--- /dev/null
+++ b/docs/docs/integrations/vectorstores/cratedb.ipynb
@@ -0,0 +1,479 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CrateDB\n",
+    "\n",
+    "This notebook shows how to use the CrateDB vector store functionality around\n",
+    "[`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn how to use it for similarity\n",
+    "search and other purposes.\n",
+    "\n",
+    "It supports:\n",
+    "- Similarity Search with Euclidean Distance\n",
+    "- Maximal Marginal Relevance Search (MMR)\n",
+    "\n",
+    "## What is CrateDB?\n",
+    "\n",
+    "[CrateDB] is an open-source, distributed, and scalable SQL analytics database\n",
+    "for storing and analyzing massive amounts of data in near real-time, even with\n",
+    "complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits\n",
+    "the shared-nothing distribution layer of [Elasticsearch].\n",
+    "\n",
+    "This example uses the [Python client driver for CrateDB].
For more documentation,\n", + "see also [LangChain with CrateDB].\n", + "\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[Elasticsearch]: https://github.com/elastic/elasticsearch\n", + "[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector\n", + "[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match\n", + "[LangChain with CrateDB]: /docs/extras/integrations/providers/cratedb.html\n", + "[Lucene]: https://github.com/apache/lucene\n", + "[Python client driver for CrateDB]: https://crate.io/docs/python/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Getting Started" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [], + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "# Install required packages: LangChain, OpenAI SDK, and the CrateDB Python driver.\n", + "!pip install 'langchain[cratedb,openai]'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to provide an OpenAI API key, optionally using the environment\n", + "variable `OPENAI_API_KEY`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:16.802456Z", + "start_time": "2023-09-09T08:02:07.065604Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "# Run `export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY`.\n", + "# Get OpenAI api key from `.env` file.\n", + "# Otherwise, prompt for it.\n", + "_ = load_dotenv(find_dotenv())\n", + "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', getpass.getpass(\"OpenAI API key:\"))\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You also need to provide a connection string to your CrateDB database cluster,\n", + "optionally using the environment variable `CRATEDB_CONNECTION_STRING`.\n", + "\n", + "This example uses a CrateDB instance on your workstation, which you can start by\n", + "running [CrateDB using Docker]. 
Alternatively, you can also connect to a cluster\n", + "running on [CrateDB Cloud].\n", + "\n", + "[CrateDB Cloud]: https://console.cratedb.cloud/\n", + "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import os\n", + "\n", + "CONNECTION_STRING = os.environ.get(\n", + " \"CRATEDB_CONNECTION_STRING\",\n", + " \"crate://crate@localhost:4200/?schema=langchain\",\n", + ")\n", + "\n", + "# For CrateDB Cloud, use:\n", + "# CONNECTION_STRING = os.environ.get(\n", + "# \"CRATEDB_CONNECTION_STRING\",\n", + "# \"crate://username:password@hostname:4200/?ssl=true&schema=langchain\",\n", + "# )" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:28.174088Z", + "start_time": "2023-09-09T08:02:28.162698Z" + } + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "# Alternatively, the connection string can be assembled from individual\n", + "# environment variables.\n", + "import os\n", + "\n", + "CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(\n", + " driver=os.environ.get(\"CRATEDB_DRIVER\", \"crate\"),\n", + " host=os.environ.get(\"CRATEDB_HOST\", \"localhost\"),\n", + " port=int(os.environ.get(\"CRATEDB_PORT\", \"4200\")),\n", + " database=os.environ.get(\"CRATEDB_DATABASE\", \"langchain\"),\n", + " user=os.environ.get(\"CRATEDB_USER\", \"crate\"),\n", + " password=os.environ.get(\"CRATEDB_PASSWORD\", \"\"),\n", + ")\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You will start by importing all required modules." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import CrateDBVectorSearch\n", + "from langchain.document_loaders import UnstructuredURLLoader\n", + "from langchain.docstore.document import Document" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Next, you will read input data, and tokenize it." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n", + "documents = loader.load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "is_executing": true + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Similarity Search with Euclidean Distance (Default)\n", + "\n", + "The module will create a table with the name of the collection. Make sure\n", + "the collection name is unique and that you have the permission to create\n", + "a table." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:04:16.696625Z", + "start_time": "2023-09-09T08:02:31.817790Z" + } + }, + "outputs": [], + "source": [ + "COLLECTION_NAME = \"state_of_the_union_test\"\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "db = CrateDBVectorSearch.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:05:11.104135Z", + "start_time": "2023-09-09T08:05:10.548998Z" + } + }, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs_with_score = db.similarity_search_with_score(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:05:13.532334Z", + "start_time": "2023-09-09T08:05:13.523191Z" + } + }, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Maximal Marginal Relevance Search (MMR)\n", + "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "docs_with_score = db.max_marginal_relevance_search_with_score(query)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-09T08:05:23.276819Z", + "start_time": "2023-09-09T08:05:21.972256Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-09T08:05:27.478580Z", + "start_time": "2023-09-09T08:05:27.470138Z" + } + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Working with the vector store\n", + "\n", + "In the example above, you created a vector store from scratch. When\n", + "aiming to work with an existing vector store, you can initialize it directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "store = CrateDBVectorSearch(\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + " embedding_function=embeddings,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Add documents\n", + "\n", + "You can also add documents to an existing vector store." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "store.add_documents([Document(page_content=\"foo\")])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score = db.similarity_search_with_score(\"foo\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Overwriting a vector store\n",
+ "\n",
+ "If you have an existing collection, you can overwrite it by using `from_documents`,\n",
+ "and setting `pre_delete_collection = True`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "db = CrateDBVectorSearch.from_documents(\n",
+ " documents=docs,\n",
+ " embedding=embeddings,\n",
+ " collection_name=COLLECTION_NAME,\n",
+ " connection_string=CONNECTION_STRING,\n",
+ " pre_delete_collection=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score = db.similarity_search_with_score(\"foo\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "docs_with_score[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Using a vector store as a retriever"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "retriever = store.as_retriever()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(retriever)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/vercel.json b/docs/vercel.json
index 87566d64cc57a..ce41ee8f0ce08 100644
--- a/docs/vercel.json
+++ b/docs/vercel.json
@@ -548,6 +548,10 @@
 "source": "/docs/integrations/chaindesk",
 "destination": "/docs/integrations/providers/chaindesk"
 },
+ {
+ "source": "/docs/integrations/cratedb",
+ "destination": "/docs/integrations/providers/cratedb"
+ },
 {
 "source": "/docs/integrations/databricks",
 "destination": "/docs/integrations/providers/databricks"
 },
@@ -2676,6 +2680,14 @@
 "source": "/docs/modules/data_connection/vectorstores/integrations/chroma",
 "destination": "/docs/integrations/vectorstores/chroma"
 },
+ {
+ "source": "/en/latest/modules/indexes/vectorstores/examples/cratedb.html",
+ "destination": "/docs/integrations/vectorstores/cratedb"
+ },
+ {
+ "source": "/docs/modules/data_connection/vectorstores/integrations/cratedb",
+ "destination": "/docs/integrations/vectorstores/cratedb"
+ },
 {
 "source": "/en/latest/modules/indexes/vectorstores/examples/deeplake.html",
 "destination": "/docs/integrations/vectorstores/activeloop_deeplake"

From f6a697ad14d194a64e31e5a5cfd4de9ba36c0ca9 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Sat, 16 Sep 2023 20:00:09 +0200
Subject: [PATCH 05/17] Add SQLAlchemy document loader --- .../document_loaders/sqlalchemy.ipynb | 237 ++++++++++++++++++ .../document_loaders/sqlalchemy.mdx | 155 ++++++++++++ .../langchain/document_loaders/__init__.py | 4 + .../langchain/document_loaders/sqlalchemy.py | 112 +++++++++ libs/langchain/pyproject.toml | 2 +- libs/langchain/tests/data.py | 4 + .../docker-compose/postgresql.yml | 19 ++ .../test_sqlalchemy_postgresql.py | 177 +++++++++++++ .../test_sqlalchemy_sqlite.py | 181 +++++++++++++ .../examples/mlb_teams_2012.csv | 32 +++ .../examples/mlb_teams_2012.sql | 38 +++ 11 files changed, 960 insertions(+), 1 deletion(-) create mode 100644 docs/docs/integrations/document_loaders/sqlalchemy.ipynb create mode 100644 docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx create mode 100644 libs/langchain/langchain/document_loaders/sqlalchemy.py create mode 100644 libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml create mode 100644 libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py create mode 100644 libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py create mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv create mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb new file mode 100644 index 0000000000000..7bb141592296a --- /dev/null +++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQLAlchemy\n", + "\n", + "This notebook demonstrates how to load documents from an [SQLite] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/\n", + "[SQLite]: https://sqlite.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install langchain termsql" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Provide input data as SQLite database." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting example.csv\n" + ] + } + ], + "source": [ + "%%file example.csv\n", + "Team,Payroll\n", + "Nationals,81.34\n", + "Reds,82.20" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nationals|81.34\r\n", + "Reds|82.2\r\n" + ] + } + ], + "source": [ + "!termsql --infile=example.csv --head --delimiter=\",\" --outfile=example.sqlite --table=payroll" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader\n", + "from pprint import pprint\n", + "\n", + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals', metadata={'Payroll': 81.34}),\n", + " Document(page_content='Team: Reds', metadata={'Payroll': 82.2})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={'source': 'Nationals'}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={'source': 'Reds'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + 
"pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx new file mode 100644 index 0000000000000..9f7e663db075e --- /dev/null +++ b/docs/docs/modules/data_connection/document_loaders/sqlalchemy.mdx @@ -0,0 +1,155 @@ +# SQLAlchemy + + +## About + +The [SQLAlchemy] document loader loads records from any supported database, +see [SQLAlchemy dialects] for all supported SQL databases and dialects. + +You can either use plain SQL for querying, or use an SQLAlchemy `Select` +statement object, if you are using SQLAlchemy-Core or -ORM. + +You can select which columns to place into the document, which columns +to place into its metadata, which columns to use as a `source` attribute +in metadata, and whether to include the result row number and/or the SQL +query expression into the metadata. + + +## Example + +This example uses PostgreSQL, and the `psycopg2` driver. + + +### Prerequisites + +```shell +psql postgresql://postgres@localhost/ --command "CREATE DATABASE testdrive;" +psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +``` + + +### Basic loading + +```python +from langchain.document_loaders import SQLAlchemyLoader +from pprint import pprint + + +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={})] +``` + + + + +## Enriching metadata + +Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to +optionally populate the `metadata` dictionary with corresponding information. + +Having the `query` within metadata is useful when using documents loaded from +database tables for chains that answer questions using their origin queries. + +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + include_rownum_into_metadata=True, + include_query_into_metadata=True, +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'row': 2, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'})] +``` + + + + +## Customizing metadata + +Use the `page_content_columns`, and `metadata_columns` options to optionally populate +the `metadata` dictionary with corresponding information. When `page_content_columns` +is empty, all columns will be used. 
+ +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + page_content_columns=["Payroll (millions)", "Wins"], + metadata_columns=["Team"], +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Payroll (millions): 81.34\nWins: 98', metadata={'Team': 'Nationals'}), + Document(page_content='Payroll (millions): 82.2\nWins: 97', metadata={'Team': 'Reds'}), + Document(page_content='Payroll (millions): 197.96\nWins: 95', metadata={'Team': 'Yankees'})] +``` + + + + +## Specify column(s) to identify the document source + +Use the `source_columns` option to specify the columns to use as a "source" for the +document created from each row. This is useful for identifying documents through +their metadata. Typically, you may use the primary key column(s) for that purpose. + +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + source_columns="Team", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'source': 'Nationals'}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'source': 'Reds'}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'source': 'Yankees'})] +``` + + + + +[SQLAlchemy]: https://www.sqlalchemy.org/ +[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/20/dialects/ diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 96cfd9e1b5486..a454a1baa9eb4 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -59,6 +59,7 @@ from langchain.document_loaders.concurrent import ConcurrentLoader from langchain.document_loaders.confluence import ConfluenceLoader from langchain.document_loaders.conllu import CoNLLULoader +from langchain.document_loaders.cratedb import CrateDBLoader from langchain.document_loaders.csv_loader import CSVLoader, UnstructuredCSVLoader from langchain.document_loaders.cube_semantic import CubeSemanticLoader from langchain.document_loaders.datadog_logs import DatadogLogsLoader @@ -158,6 +159,7 @@ from langchain.document_loaders.slack_directory import SlackDirectoryLoader from langchain.document_loaders.snowflake_loader import SnowflakeLoader from langchain.document_loaders.spreedly import SpreedlyLoader +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader from langchain.document_loaders.srt import SRTLoader from langchain.document_loaders.stripe import StripeLoader from langchain.document_loaders.telegram import ( @@ -244,6 +246,7 @@ "CollegeConfidentialLoader", "ConcurrentLoader", "ConfluenceLoader", + "CrateDBLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -333,6 +336,7 @@ "SlackDirectoryLoader", "SnowflakeLoader", "SpreedlyLoader", + "SQLAlchemyLoader", "StripeLoader", "TelegramChatApiLoader", "TelegramChatFileLoader", diff --git a/libs/langchain/langchain/document_loaders/sqlalchemy.py b/libs/langchain/langchain/document_loaders/sqlalchemy.py new file mode 100644 index 0000000000000..787c9f339b686 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/sqlalchemy.py @@ -0,0 +1,112 @@ +from typing import Dict, List, Optional, Union + +import sqlalchemy as 
sa + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + + +class SQLAlchemyLoader(BaseLoader): + """ + Load documents by querying database tables supported by SQLAlchemy. + Each document represents one row of the result. + """ + + def __init__( + self, + query: Union[str, sa.Select], + url: str, + page_content_columns: Optional[List[str]] = None, + metadata_columns: Optional[List[str]] = None, + source_columns: Optional[List[str]] = None, + include_rownum_into_metadata: bool = False, + include_query_into_metadata: bool = False, + sqlalchemy_kwargs: Optional[Dict] = None, + ): + """ + + Args: + query: The query to execute. + url: The SQLAlchemy connection string of the database to connect to. + page_content_columns: The columns to write into the `page_content` + of the document. Optional. + metadata_columns: The columns to write into the `metadata` of the document. + Optional. + source_columns: The names of the columns to use as the `source` within the + metadata dictionary. Optional. + include_rownum_into_metadata: Whether to include the row number into the + metadata dictionary. Optional. Default: False. + include_query_into_metadata: Whether to include the query expression into + the metadata dictionary. Optional. Default: False. + sqlalchemy_kwargs: More keyword arguments for SQLAlchemy's `create_engine`. + """ + self.query = query + self.url = url + self.page_content_columns = page_content_columns + self.metadata_columns = metadata_columns + self.source_columns = source_columns + self.include_rownum_into_metadata = include_rownum_into_metadata + self.include_query_into_metadata = include_query_into_metadata + self.sqlalchemy_kwargs = sqlalchemy_kwargs or {} + + def load(self) -> List[Document]: + try: + import sqlalchemy as sa + except ImportError: + raise ImportError( + "Could not import sqlalchemy python package. " + "Please install it with `pip install sqlalchemy`." 
+ ) + + engine = sa.create_engine(self.url, **self.sqlalchemy_kwargs) + + docs = [] + with engine.connect() as conn: + if isinstance(self.query, sa.Select): + result = conn.execute(self.query) + query_sql = str(self.query.compile(bind=engine)) + elif isinstance(self.query, str): + result = conn.execute(sa.text(self.query)) + query_sql = self.query + else: + raise TypeError( + f"Unable to process query of unknown type: {self.query}" + ) + field_names = list(result.mappings().keys()) + + if self.page_content_columns is None: + page_content_columns = field_names + else: + page_content_columns = self.page_content_columns + + if self.metadata_columns is None: + metadata_columns = [] + else: + metadata_columns = self.metadata_columns + + for i, row in enumerate(result.mappings()): + page_content = "\n".join( + f"{column}: {value}" + for column, value in row.items() + if column in page_content_columns + ) + + metadata: Dict[str, Union[str, int]] = {} + if self.include_rownum_into_metadata: + metadata["row"] = i + if self.include_query_into_metadata: + metadata["query"] = query_sql + + source_values = [] + for column, value in row.items(): + if column in metadata_columns: + metadata[column] = value + if self.source_columns and column in self.source_columns: + source_values.append(value) + if source_values: + metadata["source"] = ",".join(source_values) + + doc = Document(page_content=page_content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index a3fd3bd1ea13a..0fd542ea62697 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -163,7 +163,7 @@ pytest-mock = "^3.10.0" pytest-socket = "^0.6.0" syrupy = "^4.0.2" requests-mock = "^1.11.0" - +sqlparse = "^0.4.4" [tool.poetry.group.codespell.dependencies] codespell = "^2.2.0" diff --git a/libs/langchain/tests/data.py b/libs/langchain/tests/data.py index c3b240bbc57cd..be48867430641 100644 --- a/libs/langchain/tests/data.py +++ b/libs/langchain/tests/data.py @@ -9,3 +9,7 @@ HELLO_PDF = _EXAMPLES_DIR / "hello.pdf" LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf" DUPLICATE_CHARS = _EXAMPLES_DIR / "duplicate-chars.pdf" + +# Paths to data files +MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv" +MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql" diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml new file mode 100644 index 0000000000000..f8ab2cfdeb418 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/postgresql.yml @@ -0,0 +1,19 @@ +version: "3" + +services: + postgresql: + image: postgres:16 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + ports: + - "5432:5432" + command: | + postgres -c log_statement=all + healthcheck: + test: + [ + "CMD-SHELL", + "psql postgresql://postgres@localhost --command 'SELECT 1;' || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py new file mode 100644 index 0000000000000..29f52cb9f7a33 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py @@ -0,0 +1,177 @@ +""" +Test SQLAlchemy/PostgreSQL document loader functionality. 
+ +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f postgresql.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import psycopg2 # noqa: F401 + + psycopg2_installed = True +except ImportError: + psycopg2_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_POSTGRESQL_CONNECTION_STRING", + "postgresql+psycopg2://postgres@localhost:5432/", +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_no_options() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_rownum_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_query_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def 
test_postgresql_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_real_data_with_selectable(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. + select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py new file mode 100644 index 0000000000000..f1fac2cecbc00 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py @@ -0,0 +1,181 @@ +""" +Test SQLAlchemy/SQLite document loader functionality. +""" +import logging +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse +from _pytest.tmpdir import TempPathFactory + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import sqlite3 # noqa: F401 + + sqlite3_installed = True +except ImportError: + sqlite3_installed = False + + +@pytest.fixture(scope="module") +def db_uri(tmp_path_factory: TempPathFactory) -> str: + """ + Return an SQLAlchemy URI for a temporary SQLite database. + """ + db_path = tmp_path_factory.getbasetemp().joinpath("testdrive.sqlite") + return f"sqlite:///{db_path}" + + +@pytest.fixture(scope="module") +def engine(db_uri: str) -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(db_uri, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. 
+ """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_no_options(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=db_uri) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_rownum_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_query_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=db_uri, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_page_content_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=db_uri, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_metadata_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_sql( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=db_uri, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_selectable( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. 
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=db_uri, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv new file mode 100644 index 0000000000000..b22ae961a1331 --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv @@ -0,0 +1,32 @@ +"Team", "Payroll (millions)", "Wins" +"Nationals", 81.34, 98 +"Reds", 82.20, 97 +"Yankees", 197.96, 95 +"Giants", 117.62, 94 +"Braves", 83.31, 94 +"Athletics", 55.37, 94 +"Rangers", 120.51, 93 +"Orioles", 81.43, 93 +"Rays", 64.17, 90 +"Angels", 154.49, 89 +"Tigers", 132.30, 88 +"Cardinals", 110.30, 88 +"Dodgers", 95.14, 86 +"White Sox", 96.92, 85 +"Brewers", 97.65, 83 +"Phillies", 174.54, 81 +"Diamondbacks", 74.28, 81 +"Pirates", 63.43, 79 +"Padres", 55.24, 76 +"Mariners", 81.97, 75 +"Mets", 93.35, 74 +"Blue Jays", 75.48, 73 +"Royals", 60.91, 72 +"Marlins", 118.07, 69 +"Red Sox", 173.18, 69 +"Indians", 78.43, 68 +"Twins", 94.08, 66 +"Rockies", 78.06, 64 +"Cubs", 88.19, 61 +"Astros", 60.65, 55 + diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql new file mode 100644 index 0000000000000..91029ddcd3563 --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql @@ -0,0 +1,38 @@ +-- psql postgresql://postgres@localhost < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; From 5f1ca0da5199320db33430ff1d94a8f8e1dedd75 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 16 Sep 2023 20:01:15 +0200 Subject: [PATCH 06/17] CrateDB loader: Add document loader support The implementation is based on the generic SQLAlchemy document loader. 
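
A minimal usage sketch, mirroring the generic SQLAlchemy loader (the query,
table, and connection string are illustrative, assuming a local CrateDB
instance with the example data loaded):

    from langchain.document_loaders import CrateDBLoader

    # Load one document per result row.
    loader = CrateDBLoader(
        query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
        url="crate://crate@localhost/",
    )
    documents = loader.load()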
--- .../document_loaders/cratedb.ipynb | 232 ++++++++++++++++++ .../example_data/mlb_teams_2012.sql | 41 ++++ .../document_loaders/sqlalchemy.ipynb | 2 +- docs/docs/integrations/providers/cratedb.mdx | 51 +++- docs/vercel.json | 4 + .../langchain/document_loaders/cratedb.py | 5 + .../docker-compose/cratedb.yml | 20 ++ .../test_sqlalchemy_cratedb.py | 146 +++++++++++ .../examples/mlb_teams_2012.sql | 5 +- 9 files changed, 497 insertions(+), 9 deletions(-) create mode 100644 docs/docs/integrations/document_loaders/cratedb.ipynb create mode 100644 docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql create mode 100644 libs/langchain/langchain/document_loaders/cratedb.py create mode 100644 libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml create mode 100644 libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py diff --git a/docs/docs/integrations/document_loaders/cratedb.ipynb b/docs/docs/integrations/document_loaders/cratedb.ipynb new file mode 100644 index 0000000000000..78a0e19138703 --- /dev/null +++ b/docs/docs/integrations/document_loaders/cratedb.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CrateDB\n", + "\n", + "This notebook demonstrates how to load documents from a [CrateDB] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install crash 'langchain[cratedb]'" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Populate database." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mPSQL OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[32mDELETE OK, 30 rows affected (0.008 sec)\r\n", + "\u001B[0m\u001B[32mINSERT OK, 30 rows affected (0.011 sec)\r\n", + "\u001B[0m\u001B[0m\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mREFRESH OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[0m" + ] + } + ], + "source": [ + "!crash < ./example_data/mlb_teams_2012.sql\n", + "!crash --command \"REFRESH TABLE mlb_teams_2012;\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import CrateDBLoader\n", + "from pprint import pprint\n", + "\n", + "CONNECTION_STRING = \"crate://crate@localhost/\"\n", + "\n", + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll (millions)\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels', metadata={'Payroll (millions)': 154.49}),\n", + " Document(page_content='Team: Astros', metadata={'Payroll (millions)': 60.65}),\n", + " Document(page_content='Team: Athletics', metadata={'Payroll (millions)': 55.37}),\n", + " Document(page_content='Team: Blue Jays', metadata={'Payroll (millions)': 75.48}),\n", + " Document(page_content='Team: Braves', metadata={'Payroll (millions)': 83.31})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = 
loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={'source': 'Angels'}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={'source': 'Astros'}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={'source': 'Athletics'}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={'source': 'Blue Jays'}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={'source': 'Braves'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql new file mode 100644 index 0000000000000..6d94aeaa773b8 --- /dev/null +++ b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql @@ -0,0 +1,41 @@ +-- Provisioning table "mlb_teams_2012". +-- +-- crash < mlb_teams_2012.sql +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb index 7bb141592296a..5d603d7263c53 100644 --- a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb +++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb @@ -103,7 +103,7 @@ }, "outputs": [], "source": [ - "from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader\n", + "from langchain.document_loaders import SQLAlchemyLoader\n", "from pprint import pprint\n", "\n", "loader = SQLAlchemyLoader(\n", diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index d3e60bf53ee21..94a16b8935537 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -38,7 +38,8 @@ data, and query it 
using SQL.

 ## Features

-The CrateDB adapter supports the Vector Store subsystem of LangChain.
+The CrateDB adapter supports the _Vector Store_ and _Document Loader_
+subsystems of LangChain.

 ### Vector Store

@@ -51,6 +52,11 @@
 Supports:
 - Approximate nearest neighbor search.
 - Euclidean distance.

+### Document Loader
+
+`CrateDBLoader` loads documents from a database table, using an SQL
+query expression or an SQLAlchemy selectable instance.
+

 ## Installation and Setup

@@ -74,19 +80,19 @@ docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \
 ### Install Client

```bash
-pip install 'crate[sqlalchemy]' 'langchain[openai]'
+pip install 'crate[sqlalchemy]' 'langchain[openai]' 'crash'
```

-## Usage
+## Usage » Vector Store
 For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, there is also a corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html).

-### Acquire text file
+### Provide input data
 The example uses the canonical `state_of_the_union.txt`.
```shell
-wget https://raw.githubusercontent.com/langchain-ai/langchain/v0.0.291/docs/extras/modules/state_of_the_union.txt
+wget https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt
```

 ### Set environment variables
@@ -97,7 +103,7 @@
export CRATEDB_CONNECTION_STRING=crate://crate@localhost
```

```python
-from langchain.document_loaders import TextLoader
+from langchain.document_loaders import UnstructuredURLLoader
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import CrateDBVectorSearch
@@ -105,7 +111,7 @@ from langchain.vectorstores import CrateDBVectorSearch

 def main():
 # Load the document, split it into chunks, embed each chunk and load it into the vector store.
- raw_documents = TextLoader("state_of_the_union.txt").load()
+ raw_documents = UnstructuredURLLoader(urls=["https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"]).load()
 text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
 documents = text_splitter.split_documents(raw_documents)
 db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())
@@ -120,6 +126,37 @@
 if __name__ == "__main__":
```
+
+## Usage » Document Loader
+
+For a more detailed walkthrough of the `CrateDBLoader`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/document_loaders/cratedb.html).
+ + +### Provide input data +```shell +wget https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql +crash < ./example_data/mlb_teams_2012.sql +crash --command "REFRESH TABLE mlb_teams_2012;" +``` + +### Load documents by SQL query +```python +from langchain.document_loaders import CrateDBLoader +from pprint import pprint + +def main(): + loader = CrateDBLoader( + 'SELECT * FROM mlb_teams_2012 ORDER BY "Team" LIMIT 5;', + url="crate://crate@localhost/", + ) + documents = loader.load() + pprint(documents) + +if __name__ == "__main__": + main() +``` + + [CrateDB]: https://github.com/crate/crate [CrateDB Cloud]: https://crate.io/product [CrateDB Cloud Console]: https://console.cratedb.cloud/ diff --git a/docs/vercel.json b/docs/vercel.json index ce41ee8f0ce08..22130dfa09a85 100644 --- a/docs/vercel.json +++ b/docs/vercel.json @@ -1732,6 +1732,10 @@ "source": "/docs/modules/data_connection/document_loaders/integrations/copypaste", "destination": "/docs/integrations/document_loaders/copypaste" }, + { + "source": "/docs/modules/data_connection/document_loaders/integrations/cratedb", + "destination": "/docs/integrations/document_loaders/cratedb" + }, { "source": "/en/latest/modules/indexes/document_loaders/examples/csv.html", "destination": "/docs/integrations/document_loaders/csv" diff --git a/libs/langchain/langchain/document_loaders/cratedb.py b/libs/langchain/langchain/document_loaders/cratedb.py new file mode 100644 index 0000000000000..9e34b4d0cb9ec --- /dev/null +++ b/libs/langchain/langchain/document_loaders/cratedb.py @@ -0,0 +1,5 @@ +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader + + +class CrateDBLoader(SQLAlchemyLoader): + pass diff --git a/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py new file mode 100644 index 0000000000000..eec3a428a74e8 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py @@ -0,0 +1,146 @@ +""" +Test SQLAlchemy/CrateDB document loader functionality. 
+ +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f cratedb.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders import CrateDBLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + +try: + import crate.client.sqlalchemy # noqa: F401 + + crate_client_installed = True +except ImportError: + crate_client_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;")) + connection.commit() + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_no_options() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_selectable(provision_database: None) -> None: + """Test SQLAlchemy loader with CrateDB.""" + + # Define an SQLAlchemy table. 
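+ # The column types mirror the DDL in mlb_teams_2012.sql; the quoted,
+ # mixed-case column names are handled by SQLAlchemy's identifier quoting
+ # when the selectable is compiled.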
+ mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. + select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = CrateDBLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql index 91029ddcd3563..9df72ef19954a 100644 --- a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql @@ -1,4 +1,7 @@ --- psql postgresql://postgres@localhost < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +-- Provisioning table "mlb_teams_2012". +-- +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql +-- crash < mlb_teams_2012.sql DROP TABLE IF EXISTS mlb_teams_2012; CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); From a6b51ef803a7cde6ed25a55d46dd2b43537c310c Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 17 Sep 2023 19:40:34 +0200 Subject: [PATCH 07/17] Generalize `SQLChatMessageHistory` to make code a bit more reusable --- .../memory/chat_message_histories/sql.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/memory/chat_message_histories/sql.py b/libs/langchain/langchain/memory/chat_message_histories/sql.py index 610d049c51ed2..6af96bbbe27ff 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/sql.py +++ b/libs/langchain/langchain/memory/chat_message_histories/sql.py @@ -1,9 +1,9 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, List, Optional, Type -from sqlalchemy import Column, Integer, Text, create_engine +from sqlalchemy import Column, Integer, Select, Text, create_engine, select try: from sqlalchemy.orm import declarative_base @@ -22,6 +22,10 @@ class BaseMessageConverter(ABC): """The class responsible for converting BaseMessage to your SQLAlchemy model.""" + @abstractmethod + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + @abstractmethod def from_sql_model(self, sql_message: Any) -> BaseMessage: """Convert a SQLAlchemy model to a BaseMessage instance.""" @@ -51,7 +55,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore """ - # Model decleared inside a function to have a dynamic table name + # Model declared inside a function to have a dynamic table name class Message(DynamicBase): __tablename__ = table_name id = Column(Integer, primary_key=True) @@ -82,6 +86,8 @@ def get_sql_model_class(self) -> Any: class SQLChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in an SQL database.""" + DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter + def __init__( self, session_id: str, @@ -93,7 +99,9 @@ def __init__( self.connection_string = connection_string self.engine = 
create_engine(connection_string, echo=False) self.session_id_field_name = session_id_field_name - self.converter = custom_message_converter or DefaultMessageConverter(table_name) + self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER( + table_name + ) self.sql_model_class = self.converter.get_sql_model_class() if not hasattr(self.sql_model_class, session_id_field_name): raise ValueError("SQL model class must have session_id column") @@ -105,21 +113,25 @@ def __init__( def _create_table_if_not_exists(self) -> None: self.sql_model_class.metadata.create_all(self.engine) + def _messages_query(self) -> Select: + """Construct an SQLAlchemy selectable to query for messages""" + return ( + select(self.sql_model_class) + .where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + .order_by(self.sql_model_class.id.asc()) + ) + @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" with self.Session() as session: - result = ( - session.query(self.sql_model_class) - .where( - getattr(self.sql_model_class, self.session_id_field_name) - == self.session_id - ) - .order_by(self.sql_model_class.id.asc()) - ) + result = session.execute(self._messages_query()) messages = [] for record in result: - messages.append(self.converter.from_sql_model(record)) + messages.append(self.converter.from_sql_model(record[0])) return messages def add_message(self, message: BaseMessage) -> None: From 01a708a282e42daa08096791aa93594579849ecb Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 17 Sep 2023 19:44:57 +0200 Subject: [PATCH 08/17] CrateDB memory: Add conversational memory support The implementation is based on the generic `SQLChatMessageHistory`. --- .../memory/cratedb_chat_message_history.ipynb | 357 ++++++++++++++++++ docs/docs/integrations/providers/cratedb.mdx | 31 +- docs/vercel.json | 8 + .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/cratedb.py | 113 ++++++ .../integration_tests/memory/test_cratedb.py | 170 +++++++++ 6 files changed, 679 insertions(+), 2 deletions(-) create mode 100644 docs/docs/integrations/memory/cratedb_chat_message_history.ipynb create mode 100644 libs/langchain/langchain/memory/chat_message_histories/cratedb.py create mode 100644 libs/langchain/tests/integration_tests/memory/test_cratedb.py diff --git a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb new file mode 100644 index 0000000000000..f51f5f1d63fca --- /dev/null +++ b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# CrateDB Chat Message History\n", + "\n", + "This notebook demonstrates how to use the `CrateDBChatMessageHistory`\n", + "to manage chat history in CrateDB, for supporting conversational memory." + ], + "metadata": { + "collapsed": false + }, + "id": "f22eab3f84cbeb37" + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!#pip install 'langchain[cratedb]'" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuration\n", + "\n", + "To use the storage wrapper, you will need to configure two details.\n", + "\n", + "1. 
Session Id: a unique identifier of the session, such as a user name, email address, or chat id.\n",
+    "2. Database connection string: An SQLAlchemy-compatible URI that specifies the database\n",
+    "   connection. It will be passed to the SQLAlchemy `create_engine` function."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "f8f2830ee9ca1e01"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "outputs": [],
+   "source": [
+    "from langchain.memory.chat_message_histories import CrateDBChatMessageHistory\n",
+    "\n",
+    "CONNECTION_STRING = \"crate://crate@localhost:4200/?schema=example\"\n",
+    "\n",
+    "chat_message_history = CrateDBChatMessageHistory(\n",
+    "\tsession_id=\"test_session\",\n",
+    "\tconnection_string=CONNECTION_STRING\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Basic Usage"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "outputs": [],
+   "source": [
+    "chat_message_history.add_user_message(\"Hello\")\n",
+    "chat_message_history.add_ai_message(\"Hi\")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-08-28T10:04:38.077748Z",
+     "start_time": "2023-08-28T10:04:36.105894Z"
+    }
+   },
+   "id": "4576e914a866fb40"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "outputs": [
+    {
+     "data": {
+      "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
+     },
+     "execution_count": 61,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_message_history.messages"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-08-28T10:04:38.929396Z",
+     "start_time": "2023-08-28T10:04:38.915727Z"
+    }
+   },
+   "id": "b476688cbb32ba90"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Custom Storage Model\n",
+    "\n",
+    "The default data model, which stores information about conversation messages,\n",
+    "only has two slots for storing message details: the session id, and the message dictionary.\n",
+    "\n",
+    "If you want to store additional information, such as message date, author, or language,\n",
+    "please provide an implementation for a custom message converter.\n",
+    "\n",
+    "This example demonstrates how to create a custom message converter by implementing\n",
+    "the `BaseMessageConverter` interface."
+ ], + "metadata": { + "collapsed": false + }, + "id": "2e5337719d5614fd" + }, + { + "cell_type": "code", + "execution_count": 55, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from typing import Any\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n", + "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", + "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", + "\n", + "import sqlalchemy as sa\n", + "from sqlalchemy.orm import declarative_base\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "\n", + "class CustomMessage(Base):\n", + "\t__tablename__ = \"custom_message_store\"\n", + "\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tsession_id = sa.Column(sa.Text)\n", + "\ttype = sa.Column(sa.Text)\n", + "\tcontent = sa.Column(sa.Text)\n", + "\tcreated_at = sa.Column(sa.DateTime)\n", + "\tauthor_email = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverter(BaseMessageConverter):\n", + "\tdef __init__(self, author_email: str):\n", + "\t\tself.author_email = author_email\n", + "\t\n", + "\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n", + "\t\tif sql_message.type == \"human\":\n", + "\t\t\treturn HumanMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"ai\":\n", + "\t\t\treturn AIMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"system\":\n", + "\t\t\treturn SystemMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telse:\n", + "\t\t\traise ValueError(f\"Unknown message type: {sql_message.type}\")\n", + "\t\n", + "\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n", + "\t\tnow = datetime.now()\n", + "\t\treturn CustomMessage(\n", + "\t\t\tsession_id=session_id,\n", + "\t\t\ttype=message.type,\n", + "\t\t\tcontent=message.content,\n", + "\t\t\tcreated_at=now,\n", + "\t\t\tauthor_email=self.author_email\n", + "\t\t)\n", + "\t\n", + "\tdef get_sql_model_class(self) -> Any:\n", + "\t\treturn CustomMessage\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverter(\n", + "\t\t\tauthor_email=\"test@example.com\"\n", + "\t\t)\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:41.510498Z", + "start_time": "2023-08-28T10:04:41.494912Z" + } + }, + "id": "fdfde84c07d071bb" + }, + { + "cell_type": "code", + "execution_count": 60, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:43.497990Z", + "start_time": "2023-08-28T10:04:43.492517Z" + } + }, + "id": "4a6a54d8a9e2856f" + }, + { + 
"cell_type": "markdown", + "source": [ + "## Custom Name for Session Column\n", + "\n", + "The session id, a unique token identifying the session, is an important property of\n", + "this subsystem. If your database table stores it in a different column, you can use\n", + "the `session_id_field_name` keyword argument to adjust the name correspondingly." + ], + "metadata": { + "collapsed": false + }, + "id": "622aded629a1adeb" + }, + { + "cell_type": "code", + "execution_count": 57, + "outputs": [], + "source": [ + "import json\n", + "import typing as t\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n", + "from langchain.schema import _message_to_dict\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "class MessageWithDifferentSessionIdColumn(Base):\n", + "\t__tablename__ = \"message_store_different_session_id\"\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tcustom_session_id = sa.Column(sa.Text)\n", + "\tmessage = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverterWithDifferentSessionIdColumn(CrateDBMessageConverter):\n", + " def __init__(self):\n", + " self.model_class = MessageWithDifferentSessionIdColumn\n", + "\n", + " def to_sql_model(self, message: BaseMessage, custom_session_id: str) -> t.Any:\n", + " return self.model_class(\n", + " custom_session_id=custom_session_id, message=json.dumps(_message_to_dict(message))\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverterWithDifferentSessionIdColumn(),\n", + "\t\tsession_id_field_name=\"custom_session_id\",\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 58, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 94a16b8935537..5472f875f05da 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -38,8 +38,8 @@ data, and query it using SQL. ## Features -The CrateDB adapter supports the _Vector Store_ and _Document Loader_ -subsystems of LangChain. +The CrateDB adapter supports the _Vector Store_, _Document Loader_, +and _Conversational Memory_ subsystems of LangChain. 
### Vector Store @@ -57,6 +57,10 @@ Supports: `CrateDBLoader` provides loading documents from a database table by an SQL query expression or an SQLAlchemy selectable instance. +### Conversational Memory + +`CrateDBChatMessageHistory` uses CrateDB to manage conversation history. + ## Installation and Setup @@ -157,6 +161,29 @@ if __name__ == "__main__": ``` +## Usage » Conversational Memory + +For a more detailed walkthrough of the `CrateDBChatMessageHistory`, there is also a corresponding +[Jupyter notebook](/docs/extras/integrations/memory/cratedb_chat_message_history.html). + +```python +from langchain.memory.chat_message_histories import CrateDBChatMessageHistory +from pprint import pprint + +def main(): + chat_message_history = CrateDBChatMessageHistory( + session_id="test_session", + connection_string="crate://crate@localhost/", + ) + chat_message_history.add_user_message("Hello") + chat_message_history.add_ai_message("Hi") + pprint(chat_message_history) + +if __name__ == "__main__": + main() +``` + + [CrateDB]: https://github.com/crate/crate [CrateDB Cloud]: https://crate.io/product [CrateDB Cloud Console]: https://console.cratedb.cloud/ diff --git a/docs/vercel.json b/docs/vercel.json index 22130dfa09a85..076882d796bc7 100644 --- a/docs/vercel.json +++ b/docs/vercel.json @@ -2952,6 +2952,14 @@ "source": "/docs/integrations/memory/entity_memory_with_sqlite", "destination": "/docs/integrations/memory/sqlite" }, + { + "source": "/en/latest/modules/memory/examples/cratedb_chat_message_history.html", + "destination": "/docs/integrations/memory/cratedb_chat_message_history" + }, + { + "source": "/docs/modules/memory/integrations/cratedb_chat_message_history", + "destination": "/docs/integrations/memory/cratedb_chat_message_history" + }, { "source": "/en/latest/modules/memory/examples/dynamodb_chat_message_history.html", "destination": "/docs/integrations/memory/dynamodb_chat_message_history" diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index a1497e8a12225..88409aff86402 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -2,6 +2,7 @@ CassandraChatMessageHistory, ) from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory +from langchain.memory.chat_message_histories.cratedb import CrateDBChatMessageHistory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.elasticsearch import ( ElasticsearchChatMessageHistory, @@ -34,6 +35,7 @@ "ChatMessageHistory", "CassandraChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py new file mode 100644 index 0000000000000..19007176cb193 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py @@ -0,0 +1,113 @@ +import json +import typing as t + +import sqlalchemy as sa +from cratedb_toolkit.sqlalchemy import ( + patch_inspector, + polyfill_refresh_after_dml, + refresh_table, +) + +from langchain.memory.chat_message_histories.sql import ( + BaseMessageConverter, + SQLChatMessageHistory, +) +from langchain.schema import BaseMessage, _message_to_dict, 
messages_from_dict
+
+
+def create_message_model(table_name, DynamicBase):  # type: ignore
+    """
+    Create a message model for a given table name.
+
+    This is a specialized version for CrateDB, generating integer-based primary keys.
+    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
+          returning its integer value.
+
+    Args:
+        table_name: The name of the table to use.
+        DynamicBase: The base class to use for the model.
+
+    Returns:
+        The model class.
+    """
+
+    # Model is declared inside a function to be able to use a dynamic table name.
+    class Message(DynamicBase):
+        __tablename__ = table_name
+        id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
+        session_id = sa.Column(sa.Text)
+        message = sa.Column(sa.Text)
+
+    return Message
+
+
+class CrateDBMessageConverter(BaseMessageConverter):
+    """
+    The default message converter for `CrateDBChatMessageHistory`.
+
+    It is the same as the generic `SQLChatMessageHistory` converter,
+    but swaps in a different `create_message_model` function.
+    """
+
+    def __init__(self, table_name: str):
+        self.model_class = create_message_model(table_name, sa.orm.declarative_base())
+
+    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
+        return messages_from_dict([json.loads(sql_message.message)])[0]
+
+    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
+        return self.model_class(
+            session_id=session_id, message=json.dumps(_message_to_dict(message))
+        )
+
+    def get_sql_model_class(self) -> t.Any:
+        return self.model_class
+
+
+class CrateDBChatMessageHistory(SQLChatMessageHistory):
+    """
+    It is the same as the generic `SQLChatMessageHistory` implementation,
+    but swaps in a different message converter by default.
+    """
+
+    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter
+
+    def __init__(
+        self,
+        session_id: str,
+        connection_string: str,
+        table_name: str = "message_store",
+        session_id_field_name: str = "session_id",
+        custom_message_converter: t.Optional[BaseMessageConverter] = None,
+    ):
+        # FIXME: Refactor elsewhere.
+        patch_inspector()
+
+        super().__init__(
+            session_id,
+            connection_string,
+            table_name=table_name,
+            session_id_field_name=session_id_field_name,
+            custom_message_converter=custom_message_converter,
+        )
+
+        # TODO: Check how this can be improved.
+        polyfill_refresh_after_dml(self.Session)
+
+    def _messages_query(self) -> sa.Select:
+        """
+        Construct an SQLAlchemy selectable to query for messages.
+        For CrateDB, add an `ORDER BY` clause on the primary key.
+        """
+        selectable = super()._messages_query()
+        selectable = selectable.order_by(self.sql_model_class.id)
+        return selectable
+
+    def clear(self) -> None:
+        """
+        Needed for CrateDB to synchronize data because `on_flush` does not catch it.
+        """
+        outcome = super().clear()
+        with self.Session() as session:
+            refresh_table(session, self.sql_model_class)
+        return outcome
diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
new file mode 100644
index 0000000000000..2c00b5d2b200b
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
@@ -0,0 +1,170 @@
+import json
+import os
+from typing import Any, Generator, Tuple
+
+import pytest
+import sqlalchemy as sa
+from sqlalchemy import Column, Integer, Text
+from sqlalchemy.orm import DeclarativeBase
+
+from langchain.memory import ConversationBufferMemory
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
+from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict
+
+
+@pytest.fixture()
+def connection_string() -> str:
+    return os.environ.get(
+        "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive"
+    )
+
+
+@pytest.fixture()
+def engine(connection_string: str) -> sa.Engine:
+    """
+    Return an SQLAlchemy engine object.
+    """
+    return sa.create_engine(connection_string, echo=True)
+
+
+@pytest.fixture(autouse=True)
+def reset_database(engine: sa.Engine) -> None:
+    """
+    Reset the database by dropping the test table, if it exists.
+    """
+    with engine.connect() as connection:
+        connection.execute(sa.text("DROP TABLE IF EXISTS test_table;"))
+        connection.commit()
+
+
+@pytest.fixture()
+def sql_histories(
+    connection_string: str,
+) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]:
+    """
+    Provide the test cases with data fixtures.
+    """
+    message_history = CrateDBChatMessageHistory(
+        session_id="123", connection_string=connection_string, table_name="test_table"
+    )
+    # Create history for other session
+    other_history = CrateDBChatMessageHistory(
+        session_id="456", connection_string=connection_string, table_name="test_table"
+    )
+
+    yield message_history, other_history
+    message_history.clear()
+    other_history.clear()
+
+
+def test_add_messages(
+    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    history1, _ = sql_histories
+    history1.add_user_message("Hello!")
+    history1.add_ai_message("Hi there!")
+
+    messages = history1.messages
+    assert len(messages) == 2
+    assert isinstance(messages[0], HumanMessage)
+    assert isinstance(messages[1], AIMessage)
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+
+
+def test_multiple_sessions(
    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    history1, history2 = sql_histories
+
+    # first session
+    history1.add_user_message("Hello!")
+    history1.add_ai_message("Hi there!")
+    history1.add_user_message("Whats cracking?")
+
+    # second session
+    history2.add_user_message("Hellox")
+
+    messages1 = history1.messages
+    messages2 = history2.messages
+
+    # Ensure the messages are added correctly in the first session
+    assert len(messages1) == 3
+    assert messages1[0].content == "Hello!"
+    assert messages1[1].content == "Hi there!"
+    assert messages1[2].content == "Whats cracking?"
+
+    # Ensure the second session does not leak into the first one
+    assert len(messages2) == 1
+    assert len(messages1) == 3
+    assert messages2[0].content == "Hellox"
+    assert messages1[0].content == "Hello!"
+    assert messages1[1].content == "Hi there!"
+    assert messages1[2].content == "Whats cracking?"
+ + +def test_clear_messages( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + assert len(sql_history.messages) == 2 + # Now create another history with different session id + other_history.add_user_message("Hellox") + assert len(other_history.messages) == 1 + assert len(sql_history.messages) == 2 + # Now clear the first history + sql_history.clear() + assert len(sql_history.messages) == 0 + assert len(other_history.messages) == 1 + + +def test_model_no_session_id_field_error(connection_string: str) -> None: + class Base(DeclarativeBase): + pass + + class Model(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + test_field = Column(Text) + + class CustomMessageConverter(DefaultMessageConverter): + def get_sql_model_class(self) -> Any: + return Model + + with pytest.raises(ValueError): + CrateDBChatMessageHistory( + "test", + connection_string, + custom_message_converter=CustomMessageConverter("test_table"), + ) + + +def test_memory_with_message_store(connection_string: str) -> None: + """ + Test ConversationBufferMemory with a message store. + """ + # Setup CrateDB as a message store. + message_history = CrateDBChatMessageHistory( + connection_string=connection_string, session_id="test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # Add a few messages. + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # Get the message history from the memory store and turn it into JSON. + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + # Verify the outcome. + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Clear the conversation history, and verify that. + memory.chat_memory.clear() + assert memory.chat_memory.messages == [] From c461c4c8fcf014ff73f69bddac577ac6a9cdddfc Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 27 Oct 2023 16:46:24 +0200 Subject: [PATCH 09/17] CrateDB vector: Fix usage when only reading, and not storing When not adding any embeddings upfront, the runtime model factory was not able to derive the vector dimension size, because the SQLAlchemy models have not been initialized correctly. --- .../langchain/vectorstores/cratedb/base.py | 23 ++++- .../vectorstores/test_cratedb.py | 95 ++++++++++++++++++- 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index ca1c21fad68a7..b6158dd0d4769 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -112,9 +112,26 @@ def __post_init__( # TODO: See what can be improved here. polyfill_refresh_after_dml(self.Session) + # Need to defer initialization, because dimension size + # can only be figured out at runtime. self.CollectionStore = None self.EmbeddingStore = None + def _init_models(self, embedding: List[float]): + """ + Create SQLAlchemy models at runtime, when not established yet. 
+ """ + if self.CollectionStore is not None and self.EmbeddingStore is not None: + return + + size = len(embedding) + self._init_models_with_dimensionality(size=size) + + def _init_models_with_dimensionality(self, size: int): + from langchain.vectorstores.cratedb.model import model_factory + + self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size) + def get_collection( self, session: sqlalchemy.orm.Session ) -> Optional["CollectionStore"]: @@ -140,10 +157,7 @@ def add_embeddings( metadatas: List of metadatas associated with the texts. kwargs: vectorstore specific parameters """ - from langchain.vectorstores.cratedb.model import model_factory - - dimensions = len(embeddings[0]) - self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=dimensions) + self._init_models(embeddings[0]) if self.pre_delete_collection: self.delete_collection() self.create_tables_if_not_exists() @@ -223,6 +237,7 @@ def _query_collection( filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" + self._init_models(embedding) with self.Session() as session: collection = self.get_collection(session) if not collection: diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py index d62f0a125f661..8f62919842fc0 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -14,7 +14,10 @@ from langchain.docstore.document import Document from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch -from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.integration_tests.vectorstores.fake_embeddings import ( + ConsistentFakeEmbeddings, + FakeEmbeddings, +) CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), @@ -40,6 +43,13 @@ def engine() -> sa.Engine: return sa.create_engine(CONNECTION_STRING, echo=False) +@pytest.fixture +def session(engine) -> sa.orm.Session: + with engine.connect() as conn: + with Session(conn) as session: + yield session + + @pytest.fixture(autouse=True) def drop_tables(engine: sa.Engine) -> None: """ @@ -89,6 +99,46 @@ def decode_output( return documents, scores +def ensure_collection(session: sa.orm.Session, name: str): + """ + Create a (fake) collection item. + """ + session.execute( + sa.text( + f""" + CREATE TABLE IF NOT EXISTS collection ( + uuid TEXT, + name TEXT, + cmetadata OBJECT + ); + """ + ) + ) + session.execute( + sa.text( + f""" + CREATE TABLE IF NOT EXISTS embedding ( + uuid TEXT, + collection_id TEXT, + embedding FLOAT_VECTOR(123), + document TEXT, + cmetadata OBJECT, + custom_id TEXT + ); + """ + ) + ) + try: + session.execute( + sa.text( + f"INSERT INTO collection (uuid, name, cmetadata) VALUES ('uuid-{name}', '{name}', {{}});" + ) + ) + session.execute(sa.text("REFRESH TABLE collection")) + except sa.exc.IntegrityError: + pass + + class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): """Fake embeddings functionality for testing.""" @@ -103,6 +153,19 @@ def embed_query(self, text: str) -> List[float]: return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] +class ConsistentFakeEmbeddingsWithAdaDimension( + FakeEmbeddingsWithAdaDimension, ConsistentFakeEmbeddings +): + """ + Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts. 
+ + Other than this, they also have a dimensionality, which is important in this case. + """ + + pass + + def test_cratedb_texts() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -274,6 +337,36 @@ def test_cratedb_collection_no_embedding_dimension() -> None: ) +def test_cratedb_collection_read_only(session) -> None: + """ + Test using a collection, without adding any embeddings upfront. + + This happens when just invoking the "retrieval" case. + + In this scenario, embedding dimensionality needs to be figured out + from the supplied `embedding_function`. + """ + + # Create a fake collection item. + ensure_collection(session, "baz2") + + # This test case needs an embedding _with_ dimensionality. + # Otherwise, the data access layer is unable to figure it + # out at runtime. + embedding = ConsistentFakeEmbeddingsWithAdaDimension() + + vectorstore = CrateDBVectorSearch( + collection_name="baz2", + connection_string=CONNECTION_STRING, + embedding_function=embedding, + ) + output = vectorstore.similarity_search("foo", k=1) + + # No documents/embeddings have been loaded, the collection is empty. + # This is why there are also no results. + assert output == [] + + def test_cratedb_with_filter_in_set() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] From 93e897051e991f09f3497e31134fd1c13801fce1 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 27 Oct 2023 22:16:51 +0200 Subject: [PATCH 10/17] CrateDB vector: Unable to invoke `add_embeddings` without embeddings --- libs/langchain/langchain/vectorstores/cratedb/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index b6158dd0d4769..e7c651bea9822 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -157,6 +157,8 @@ def add_embeddings( metadatas: List of metadatas associated with the texts. kwargs: vectorstore specific parameters """ + if not embeddings: + return [] self._init_models(embeddings[0]) if self.pre_delete_collection: self.delete_collection() From a104222e98edbb7c94cd8c3946f544010a340d1a Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 20 Nov 2023 21:34:09 +0100 Subject: [PATCH 11/17] CrateDB vector: Improve SQLAlchemy model factory From now on, _all_ instances of SQLAlchemy model types will be created at runtime through the `ModelFactory` utility. By using `__table_args__ = {"keep_existing": True}` on the ORM entity definitions, this seems to work well, even with multiple invocations of `CrateDBVectorSearch.from_texts()` using different `collection_name` argument values. While being at it, this patch also fixes a few linter errors. 
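
To illustrate the effect, a minimal sketch (the connection string and the
`FakeEmbeddings` stand-in are placeholders, not part of this patch):

    from langchain.embeddings import FakeEmbeddings
    from langchain.vectorstores.cratedb import CrateDBVectorSearch

    CONNECTION_STRING = "crate://crate@localhost:4200/?schema=testdrive"
    embeddings = FakeEmbeddings(size=1536)

    # Each store creates its SQLAlchemy models at runtime through
    # `ModelFactory`. With `keep_existing=True`, the repeated declarations
    # of the "collection" and "embedding" tables tolerate each other
    # within one process.
    store_foo = CrateDBVectorSearch.from_texts(
        texts=["foo"],
        embedding=embeddings,
        collection_name="collection_foo",
        connection_string=CONNECTION_STRING,
    )
    store_bar = CrateDBVectorSearch.from_texts(
        texts=["bar"],
        embedding=embeddings,
        collection_name="collection_bar",
        connection_string=CONNECTION_STRING,
    )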
--- .../vectorstores/cratedb/__init__.py | 3 +- .../langchain/vectorstores/cratedb/base.py | 68 +++---- .../langchain/vectorstores/cratedb/model.py | 168 ++++++++++-------- .../vectorstores/test_cratedb.py | 37 ++-- 4 files changed, 154 insertions(+), 122 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py index 303a52babeaea..14b02ad126867 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/__init__.py +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -1,6 +1,5 @@ -from .base import BaseModel, CrateDBVectorSearch +from .base import CrateDBVectorSearch __all__ = [ - "BaseModel", "CrateDBVectorSearch", ] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index e7c651bea9822..ec3c4c19d70a6 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -2,7 +2,6 @@ import enum import math -import uuid from typing import ( Any, Callable, @@ -20,11 +19,12 @@ polyfill_refresh_after_dml, refresh_table, ) -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import sessionmaker from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.cratedb.model import ModelFactory from langchain.vectorstores.pgvector import PGVector @@ -38,23 +38,10 @@ class DistanceStrategy(str, enum.Enum): DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN -Base = declarative_base() # type: Any -# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" -def generate_uuid() -> str: - return str(uuid.uuid4()) - - -class BaseModel(Base): - """Base model for the SQL stores.""" - - __abstract__ = True - uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid) - - def _results_to_docs(docs_and_scores: Any) -> List[Document]: """Return docs from docs and scores.""" return [doc for doc, _ in docs_and_scores] @@ -114,30 +101,47 @@ def __post_init__( # Need to defer initialization, because dimension size # can only be figured out at runtime. - self.CollectionStore = None - self.EmbeddingStore = None + self.BaseModel = None + self.CollectionStore = None # type: ignore[assignment] + self.EmbeddingStore = None # type: ignore[assignment] + + def __del__(self) -> None: + """ + Work around premature session close. + + sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound + to a Session; lazy load operation of attribute 'embeddings' cannot proceed. + -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3 + + TODO: Review! + """ # noqa: E501 + pass - def _init_models(self, embedding: List[float]): + def _init_models(self, embedding: List[float]) -> None: """ Create SQLAlchemy models at runtime, when not established yet. """ + + # TODO: Use a better way to run this only once. 
if self.CollectionStore is not None and self.EmbeddingStore is not None: return size = len(embedding) self._init_models_with_dimensionality(size=size) - def _init_models_with_dimensionality(self, size: int): - from langchain.vectorstores.cratedb.model import model_factory - - self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size) + def _init_models_with_dimensionality(self, size: int) -> None: + mf = ModelFactory(dimensions=size) + self.BaseModel, self.CollectionStore, self.EmbeddingStore = ( + mf.BaseModel, # type: ignore[assignment] + mf.CollectionStore, + mf.EmbeddingStore, + ) - def get_collection( - self, session: sqlalchemy.orm.Session - ) -> Optional["CollectionStore"]: + def get_collection(self, session: sqlalchemy.orm.Session) -> Any: if self.CollectionStore is None: raise RuntimeError( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) return self.CollectionStore.get_by_name(session, self.collection_name) @@ -170,15 +174,17 @@ def add_embeddings( def create_tables_if_not_exists(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. """ - Base.metadata.create_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.create_all(self._engine) def drop_tables(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. """ - Base.metadata.drop_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.drop_all(self._engine) def delete( self, diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index ee42e7269dc9d..b8b14c05010f5 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -1,84 +1,112 @@ -from functools import lru_cache -from typing import Optional, Tuple +import uuid +from typing import Any, Optional, Tuple import sqlalchemy from crate.client.sqlalchemy.types import ObjectType -from sqlalchemy.orm import Session, relationship +from sqlalchemy.orm import Session, declarative_base, relationship -from langchain.vectorstores.cratedb.base import BaseModel from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector -@lru_cache -def model_factory(dimensions: int): - class CollectionStore(BaseModel): - """Collection store.""" - - __tablename__ = "collection" - - name = sqlalchemy.Column(sqlalchemy.String) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name( - cls, session: Session, name: str - ) -> Optional["CollectionStore"]: - try: - return ( - session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 - ) - except sqlalchemy.exc.ProgrammingError as ex: - if "RelationUnknown" not in str(ex): - raise - return None - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """ - Get or create a collection. - Returns [Collection, bool] where the bool is True if the collection was created. 
- """ - created = False - collection = cls.get_by_name(session, name) - if collection: +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class ModelFactory: + """Provide SQLAlchemy model objects at runtime.""" + + def __init__(self, dimensions: Optional[int] = None): + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for operations like deleting records. + self.dimensions = dimensions or 1024 + + Base: Any = declarative_base() + + # Optional: Use a custom schema for the langchain tables. + # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any + + class BaseModel(Base): + """Base model for the SQL stores.""" + + __abstract__ = True + uuid = sqlalchemy.Column( + sqlalchemy.String, primary_key=True, default=generate_uuid + ) + + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = "collection" + __table_args__ = {"keep_existing": True} + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + try: + return ( + session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True + if the collection was created. 
+ """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True return collection, created - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() - created = True - return collection, created + class EmbeddingStore(BaseModel): + """Embedding store.""" - class EmbeddingStore(BaseModel): - """Embedding store.""" + __tablename__ = "embedding" + __table_args__ = {"keep_existing": True} - __tablename__ = "embedding" + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") - collection_id = sqlalchemy.Column( - sqlalchemy.String, - sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", - ondelete="CASCADE", - ), - ) - collection = relationship("CollectionStore", back_populates="embeddings") + embedding = sqlalchemy.Column(FloatVector(self.dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) - embedding = sqlalchemy.Column(FloatVector(dimensions)) - document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - # custom_id : any user defined id - custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - - return CollectionStore, EmbeddingStore + self.Base = Base + self.BaseModel = BaseModel + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py index 8f62919842fc0..8f054fc07a0b3 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -5,7 +5,7 @@ docker-compose -f cratedb.yml up """ import os -from typing import List, Tuple +from typing import Generator, List, Tuple import pytest import sqlalchemy as sa @@ -13,7 +13,8 @@ from sqlalchemy.orm import Session from langchain.docstore.document import Document -from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch +from langchain.vectorstores.cratedb import CrateDBVectorSearch +from langchain.vectorstores.cratedb.model import ModelFactory from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, @@ -44,7 +45,7 @@ def engine() -> sa.Engine: @pytest.fixture -def session(engine) -> sa.orm.Session: +def session(engine: sa.Engine) -> Generator[sa.orm.Session, None, None]: with engine.connect() as conn: with Session(conn) as session: yield session @@ -56,7 +57,8 @@ def drop_tables(engine: sa.Engine) -> None: Drop database tables. 
""" try: - BaseModel.metadata.drop_all(engine, checkfirst=False) + mf = ModelFactory() + mf.BaseModel.metadata.drop_all(engine, checkfirst=False) except Exception as ex: if "RelationUnknown" not in str(ex): raise @@ -69,18 +71,13 @@ def prune_tables(engine: sa.Engine) -> None: """ with engine.connect() as conn: with Session(conn) as session: - from langchain.vectorstores.cratedb.model import model_factory - - # While it does not have any function here, you will still need to supply a - # dummy dimension size value for deleting records from tables. - CollectionStore, EmbeddingStore = model_factory(dimensions=1024) - + mf = ModelFactory() try: - session.query(CollectionStore).delete() + session.query(mf.CollectionStore).delete() except ProgrammingError: pass try: - session.query(EmbeddingStore).delete() + session.query(mf.EmbeddingStore).delete() except ProgrammingError: pass @@ -99,13 +96,13 @@ def decode_output( return documents, scores -def ensure_collection(session: sa.orm.Session, name: str): +def ensure_collection(session: sa.orm.Session, name: str) -> None: """ Create a (fake) collection item. """ session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS collection ( uuid TEXT, name TEXT, @@ -116,7 +113,7 @@ def ensure_collection(session: sa.orm.Session, name: str): ) session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS embedding ( uuid TEXT, collection_id TEXT, @@ -131,7 +128,8 @@ def ensure_collection(session: sa.orm.Session, name: str): try: session.execute( sa.text( - f"INSERT INTO collection (uuid, name, cmetadata) VALUES ('uuid-{name}', '{name}', {{}});" + f"INSERT INTO collection (uuid, name, cmetadata) " + f"VALUES ('uuid-{name}', '{name}', {{}});" ) ) session.execute(sa.text("REFRESH TABLE collection")) @@ -325,7 +323,7 @@ def test_cratedb_collection_with_metadata() -> None: def test_cratedb_collection_no_embedding_dimension() -> None: """Test end to end collection construction""" cratedb_vector = CrateDBVectorSearch( - embedding_function=None, + embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, pre_delete_collection=True, ) @@ -333,11 +331,12 @@ def test_cratedb_collection_no_embedding_dimension() -> None: with pytest.raises(RuntimeError) as ex: cratedb_vector.get_collection(session) assert ex.match( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) -def test_cratedb_collection_read_only(session) -> None: +def test_cratedb_collection_read_only(session: Session) -> None: """ Test using a collection, without adding any embeddings upfront. From fde3486f37e94ed40ec42175cf9e3c01bc7b3206 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 00:25:57 +0100 Subject: [PATCH 12/17] CrateDB vector: Fix cascading deletes When deleting a collection, also delete its associated embeddings. 
--- .../langchain/vectorstores/cratedb/model.py | 3 +- .../vectorstores/test_cratedb.py | 46 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index b8b14c05010f5..656de41bf4d45 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -45,7 +45,8 @@ class CollectionStore(BaseModel): embeddings = relationship( "EmbeddingStore", back_populates="collection", - passive_deletes=True, + cascade="all, delete-orphan", + passive_deletes=False, ) @classmethod diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py index 8f054fc07a0b3..d573843d2f02f 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -299,6 +299,52 @@ def test_cratedb_with_filter_no_match() -> None: assert output == [] +def test_cratedb_collection_delete() -> None: + """ + Test end to end collection construction and deletion. + Uses two different collections of embeddings. + """ + + store_foo = CrateDBVectorSearch.from_texts( + texts=["foo"], + collection_name="test_collection_foo", + collection_metadata={"category": "foo"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "foo"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store_bar = CrateDBVectorSearch.from_texts( + texts=["bar"], + collection_name="test_collection_bar", + collection_metadata={"category": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "bar"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + session = store_foo.Session() + + # Verify data in database. + collection_foo = store_foo.get_collection(session) + collection_bar = store_bar.get_collection(session) + assert collection_foo.embeddings[0].cmetadata == {"document": "foo"} + assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} + + # Delete first collection. + store_foo.delete_collection() + + # Verify that the "foo" collection has been deleted. + collection_foo = store_foo.get_collection(session) + collection_bar = store_bar.get_collection(session) + assert collection_foo is None + assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} + + # Verify that associated embeddings also have been deleted. + embeddings_count = session.query(store_foo.EmbeddingStore).count() + assert embeddings_count == 1 + + def test_cratedb_collection_with_metadata() -> None: """Test end to end collection construction""" texts = ["foo", "bar", "baz"] From 624bf2ea3bac4f6ca0a3f4023f59cf84eed5862c Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 13:12:21 +0100 Subject: [PATCH 13/17] CrateDB vector: Add CrateDBVectorSearchMultiCollection It is a special adapter which provides similarity search across multiple collections. It can not be used for indexing documents. 
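
A sketch of how it composes with the per-collection adapter (the embedding
function and connection string are placeholders):

    from langchain.embeddings import FakeEmbeddings
    from langchain.vectorstores.cratedb import (
        CrateDBVectorSearch,
        CrateDBVectorSearchMultiCollection,
    )

    CONNECTION_STRING = "crate://crate@localhost:4200/?schema=testdrive"
    embeddings = FakeEmbeddings(size=1536)

    # Index documents per collection, using the regular adapter.
    for name, texts in [("collection_foo", ["foo"]), ("collection_bar", ["bar"])]:
        CrateDBVectorSearch.from_texts(
            texts=texts,
            embedding=embeddings,
            collection_name=name,
            connection_string=CONNECTION_STRING,
        )

    # Search across both collections at once. Indexing through this
    # adapter is rejected with `NotImplementedError`.
    multisearch = CrateDBVectorSearchMultiCollection(
        collection_names=["collection_foo", "collection_bar"],
        embedding_function=embeddings,
        connection_string=CONNECTION_STRING,
    )
    docs_with_score = multisearch.similarity_search_with_score("foo")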
--- docs/docs/integrations/providers/cratedb.mdx | 3 + .../integrations/vectorstores/cratedb.ipynb | 88 +++++++----- .../vectorstores/cratedb/__init__.py | 2 + .../langchain/vectorstores/cratedb/base.py | 22 ++- .../vectorstores/cratedb/extended.py | 92 ++++++++++++ .../langchain/vectorstores/cratedb/model.py | 15 +- .../vectorstores/fake_embeddings.py | 1 - .../vectorstores/test_cratedb.py | 131 +++++++++++++----- 8 files changed, 281 insertions(+), 73 deletions(-) create mode 100644 libs/langchain/langchain/vectorstores/cratedb/extended.py diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 5472f875f05da..4764a7ad92369 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -106,6 +106,9 @@ export OPENAI_API_KEY=foobar # FIXME export CRATEDB_CONNECTION_STRING=crate://crate@localhost ``` +### Example + +Load and index documents, and invoke query. ```python from langchain.document_loaders import UnstructuredURLLoader from langchain.embeddings.openai import OpenAIEmbeddings diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb index 462e721bfff40..06430e6355ae9 100644 --- a/docs/docs/integrations/vectorstores/cratedb.ipynb +++ b/docs/docs/integrations/vectorstores/cratedb.ipynb @@ -182,7 +182,11 @@ { "cell_type": "markdown", "source": [ - "Next, you will read input data, and tokenize it." + "## Load and Index Documents\n", + "\n", + "Next, you will read input data, and tokenize it. The module will create a table\n", + "with the name of the collection. Make sure the collection name is unique, and\n", + "that you have the permission to create a table." ], "metadata": { "collapsed": false @@ -196,7 +200,18 @@ "loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", - "docs = text_splitter.split_documents(documents)" + "docs = text_splitter.split_documents(documents)\n", + "\n", + "COLLECTION_NAME = \"state_of_the_union_test\"\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "db = CrateDBVectorSearch.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + ")" ], "metadata": { "collapsed": false, @@ -208,39 +223,15 @@ { "cell_type": "markdown", "source": [ - "## Similarity Search with Euclidean Distance (Default)\n", + "## Search Documents\n", "\n", - "The module will create a table with the name of the collection. Make sure\n", - "the collection name is unique and that you have the permission to create\n", - "a table." + "### Similarity Search with Euclidean Distance\n", + "Searching by euclidean distance is the default." 
], "metadata": { "collapsed": false } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-09T08:04:16.696625Z", - "start_time": "2023-09-09T08:02:31.817790Z" - } - }, - "outputs": [], - "source": [ - "COLLECTION_NAME = \"state_of_the_union_test\"\n", - "\n", - "embeddings = OpenAIEmbeddings()\n", - "\n", - "db = CrateDBVectorSearch.from_documents(\n", - " embedding=embeddings,\n", - " documents=docs,\n", - " collection_name=COLLECTION_NAME,\n", - " connection_string=CONNECTION_STRING,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -277,7 +268,7 @@ { "cell_type": "markdown", "source": [ - "## Maximal Marginal Relevance Search (MMR)\n", + "### Maximal Marginal Relevance Search (MMR)\n", "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents." ], "metadata": { @@ -318,11 +309,40 @@ } } }, + { + "cell_type": "markdown", + "source": [ + "### Searching in Multiple Collections\n", + "`CrateDBVectorSearchMultiCollection` is a special adapter which provides similarity search across\n", + "multiple collections. It can not be used for indexing documents." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection\n", + "\n", + "multisearch = CrateDBVectorSearchMultiCollection(\n", + " collection_names=[\"test_collection_1\", \"test_collection_2\"],\n", + " embedding_function=embeddings,\n", + " connection_string=CONNECTION_STRING,\n", + ")\n", + "docs_with_score = multisearch.similarity_search_with_score(query)" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Working with the vector store\n", + "## Working with the Vector Store\n", "\n", "In the example above, you created a vector store from scratch. When\n", "aiming to work with an existing vector store, you can initialize it directly." @@ -345,7 +365,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Add documents\n", + "### Add Documents\n", "\n", "You can also add documents to an existing vector store." ] @@ -390,7 +410,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Overwriting a vector store\n", + "### Overwriting a Vector Store\n", "\n", "If you have an existing collection, you can overwrite it by using `from_documents`,\n", "aad setting `pre_delete_collection = True`." 
@@ -433,7 +453,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Using a vector store as a retriever" + "### Using a Vector Store as a Retriever" ] }, { diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py index 14b02ad126867..62462bce1eba9 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/__init__.py +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -1,5 +1,7 @@ from .base import CrateDBVectorSearch +from .extended import CrateDBVectorSearchMultiCollection __all__ = [ "CrateDBVectorSearch", + "CrateDBVectorSearchMultiCollection", ] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index ec3c4c19d70a6..922ba2ed659d6 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -250,8 +250,26 @@ def _query_collection( collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") + return self._query_collection_multi( + collections=[collection], embedding=embedding, k=k, filter=filter + ) - filter_by = self.EmbeddingStore.collection_id == collection.uuid + def _query_collection_multi( + self, + collections: List[Any], + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + self._init_models(embedding) + + collection_names = [coll.name for coll in collections] + collection_uuids = [coll.uuid for coll in collections] + self.logger.info(f"Querying collections: {collection_names}") + + with self.Session() as session: + filter_by = self.EmbeddingStore.collection_id.in_(collection_uuids) if filter is not None: filter_clauses = [] @@ -271,7 +289,7 @@ def _query_collection( ) # type: ignore[assignment] filter_clauses.append(filter_by_metadata) - filter_by = sqlalchemy.and_(filter_by, *filter_clauses) + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) # type: ignore[assignment] _type = self.EmbeddingStore diff --git a/libs/langchain/langchain/vectorstores/cratedb/extended.py b/libs/langchain/langchain/vectorstores/cratedb/extended.py new file mode 100644 index 0000000000000..9266438787368 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/extended.py @@ -0,0 +1,92 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, +) + +import sqlalchemy +from sqlalchemy.orm import sessionmaker + +from langchain.schema.embeddings import Embeddings +from langchain.vectorstores.cratedb.base import ( + DEFAULT_DISTANCE_STRATEGY, + CrateDBVectorSearch, + DistanceStrategy, +) +from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME + + +class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch): + """ + Provide functionality for searching multiple collections. + It can not be used for indexing documents. + + To use it, you should have the ``crate[sqlalchemy]`` Python package installed. 
+ + Synopsis:: + + from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection + + multisearch = CrateDBVectorSearchMultiCollection( + collection_names=["collection_foo", "collection_bar"], + embedding_function=embeddings, + connection_string=CONNECTION_STRING, + ) + docs_with_score = multisearch.similarity_search_with_score(query) + """ + + def __init__( + self, + connection_string: str, + embedding_function: Embeddings, + collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME], + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, # type: ignore[arg-type] + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + *, + connection: Optional[sqlalchemy.engine.Connection] = None, + engine_args: Optional[dict[str, Any]] = None, + ) -> None: + self.connection_string = connection_string + self.embedding_function = embedding_function + self.collection_names = collection_names + self._distance_strategy = distance_strategy # type: ignore[assignment] + self.logger = logger or logging.getLogger(__name__) + self.override_relevance_score_fn = relevance_score_fn + self.engine_args = engine_args or {} + # Create a connection if not provided, otherwise use the provided connection + self._engine = self.create_engine() + self.Session = sessionmaker(self._engine) + self._conn = connection if connection else self.connect() + self.__post_init__() + + @classmethod + def _from(cls, *args: List, **kwargs: Dict): # type: ignore[no-untyped-def,override] + raise NotImplementedError("This adapter can not be used for indexing documents") + + def get_collections(self, session: sqlalchemy.orm.Session) -> Any: + if self.CollectionStore is None: + raise RuntimeError( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + return self.CollectionStore.get_by_names(session, self.collection_names) + + def _query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query multiple collections.""" + self._init_models(embedding) + with self.Session() as session: + collections = self.get_collections(session) + if not collections: + raise ValueError("No collections found") + return self._query_collection_multi( + collections=collections, embedding=embedding, k=k, filter=filter + ) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index 656de41bf4d45..1aec9b49a7260 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import sqlalchemy from crate.client.sqlalchemy.types import ObjectType @@ -62,6 +62,19 @@ def get_by_name( raise return None + @classmethod + def get_by_names( + cls, session: Session, names: List[str] + ) -> Optional["List[CollectionStore]"]: + try: + return ( + session.query(cls).filter(cls.name.in_(names)).all() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + @classmethod def get_or_create( cls, diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py index 87ea1edc6a00b..209e933b24b61 100644 --- 
+++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
@@ -52,7 +52,6 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
     def embed_query(self, text: str) -> List[float]:
         """Return consistent embeddings for the text, if seen before,
         or a constant one if the text is unknown."""
-        return self.embed_documents([text])[0]
         if text not in self.known_texts:
             return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
         return [float(1.0)] * (self.dimensionality - 1) + [
diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
index d573843d2f02f..5a732ca5332f9 100644
--- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
@@ -5,15 +5,17 @@
 docker-compose -f cratedb.yml up
 """
 import os
-from typing import Generator, List, Tuple
+from typing import Dict, Generator, List, Tuple

 import pytest
 import sqlalchemy as sa
+import sqlalchemy.orm
 from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.orm import Session

 from langchain.docstore.document import Document
 from langchain.vectorstores.cratedb import CrateDBVectorSearch
+from langchain.vectorstores.cratedb.extended import CrateDBVectorSearchMultiCollection
 from langchain.vectorstores.cratedb.model import ModelFactory
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
@@ -151,17 +153,17 @@ def embed_query(self, text: str) -> List[float]:
         return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]


-class ConsistentFakeEmbeddingsWithAdaDimension(
-    FakeEmbeddingsWithAdaDimension, ConsistentFakeEmbeddings
-):
+class ConsistentFakeEmbeddingsWithAdaDimension(ConsistentFakeEmbeddings):
     """
-    Fake embeddings which remember all the texts seen so far to return consistent
-    vectors for the same texts.
+    Fake embeddings which remember all the texts seen so far to return
+    consistent vectors for the same texts.

-    Other than this, they also have a dimensionality, which is important in this case.
+    Other than this, they also have a fixed dimensionality, which is
+    important in this case.
""" - pass + def __init__(self, *args: List, **kwargs: Dict) -> None: + super().__init__(dimensionality=ADA_TOKEN_COUNT) def test_cratedb_texts() -> None: @@ -223,12 +225,7 @@ def test_cratedb_with_metadatas_with_scores() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1) - # TODO: Original: - # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 - assert output in [ - [(Document(page_content="foo", metadata={"page": "0"}), 1.0828735)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.1307646)], - ] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)] def test_cratedb_with_filter_match() -> None: @@ -247,9 +244,8 @@ def test_cratedb_with_filter_match() -> None: # TODO: Original: # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 assert output in [ - [(Document(page_content="foo", metadata={"page": "0"}), 1.2615292)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.3979403)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.5065275)], + [(Document(page_content="foo", metadata={"page": "0"}), 2.1307645)], + [(Document(page_content="foo", metadata={"page": "0"}), 2.3150668)], ] @@ -265,10 +261,9 @@ def test_cratedb_with_filter_distant_match() -> None: connection_string=CONNECTION_STRING, pre_delete_collection=True, ) + output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"}) # TODO: Original: - # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) - output = docsearch.similarity_search_with_score("foo", k=3, filter={"page": "2"}) - # TODO: Original: + # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) # noqa: E501 # assert output == [ # (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501 # ] @@ -277,9 +272,10 @@ def test_cratedb_with_filter_distant_match() -> None: Document(page_content="baz", metadata={"page": "2"}), ] assert scores in [ - [0.5], - [0.6], - [0.7], + [1.3], + [1.5], + [1.6], + [1.7], ] @@ -439,7 +435,7 @@ def test_cratedb_with_filter_in_set() -> None: Document(page_content="foo", metadata={"page": "0"}), Document(page_content="baz", metadata={"page": "2"}), ] - assert scores == [2.1, 1.3] + assert scores == [3.0, 2.2] def test_cratedb_delete_docs() -> None: @@ -498,7 +494,7 @@ def test_cratedb_relevance_score() -> None: Document(page_content="bar", metadata={"page": "1"}), Document(page_content="baz", metadata={"page": "2"}), ] - assert scores == [0.8, 0.4, 0.2] + assert scores == [1.4, 1.1, 0.8] def test_cratedb_retriever_search_threshold() -> None: @@ -516,9 +512,7 @@ def test_cratedb_retriever_search_threshold() -> None: retriever = docsearch.as_retriever( search_type="similarity_score_threshold", - # TODO: Original: - # search_kwargs={"k": 3, "score_threshold": 0.999}, - search_kwargs={"k": 3, "score_threshold": 0.333}, + search_kwargs={"k": 3, "score_threshold": 0.999}, ) output = retriever.get_relevant_documents("summer") assert output == [ @@ -574,10 +568,77 @@ def test_cratedb_max_marginal_relevance_search_with_score() -> None: pre_delete_collection=True, ) output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) - # TODO: Original: - # assert output == [(Document(page_content="foo"), 0.0)] - assert output in [ - [(Document(page_content="foo"), 1.0606961)], - [(Document(page_content="foo"), 1.0828735)], - 
-    ]
+    assert output == [(Document(page_content="foo"), 2.0)]
+
+
+def test_cratedb_multicollection_search_success() -> None:
+    """
+    `CrateDBVectorSearchMultiCollection` provides functionality for
+    searching multiple collections.
+    """
+
+    store_1 = CrateDBVectorSearch.from_texts(
+        texts=["Räuber", "Hotzenplotz"],
+        collection_name="test_collection_1",
+        embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    _ = CrateDBVectorSearch.from_texts(
+        texts=["John", "Doe"],
+        collection_name="test_collection_2",
+        embedding=ConsistentFakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    # Probe the first store.
+    output = store_1.similarity_search("Räuber", k=1)
+    assert Document(page_content="Räuber") in output[:2]
+    output = store_1.similarity_search("Hotzenplotz", k=1)
+    assert Document(page_content="Hotzenplotz") in output[:2]
+    output = store_1.similarity_search("John Doe", k=1)
+    assert Document(page_content="Räuber") in output[:2]
+
+    # Probe the multi-store.
+    multisearch = CrateDBVectorSearchMultiCollection(
+        collection_names=["test_collection_1", "test_collection_2"],
+        embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+    )
+    output = multisearch.similarity_search("Räuber Hotzenplotz", k=2)
+    assert Document(page_content="Räuber") in output[:2]
+    output = multisearch.similarity_search("John Doe", k=2)
+    assert Document(page_content="John") in output[:2]
+
+
+def test_cratedb_multicollection_fail_indexing_not_permitted() -> None:
+    """
+    `CrateDBVectorSearchMultiCollection` does not provide functionality for
+    indexing documents.
+    """
+
+    with pytest.raises(NotImplementedError) as ex:
+        CrateDBVectorSearchMultiCollection.from_texts(
+            texts=["foo"],
+            collection_name="test_collection",
+            embedding=FakeEmbeddingsWithAdaDimension(),
+            connection_string=CONNECTION_STRING,
+        )
+    assert ex.match("This adapter can not be used for indexing documents")
+
+
+def test_cratedb_multicollection_search_no_collections() -> None:
+    """
+    `CrateDBVectorSearchMultiCollection` will fail when it cannot identify
+    collections to search in.
+    """
+
+    store = CrateDBVectorSearchMultiCollection(
+        collection_names=["unknown"],
+        embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(),
+        connection_string=CONNECTION_STRING,
+    )
+    with pytest.raises(ValueError) as ex:
+        store.similarity_search("foo")
+    assert ex.match("No collections found")

From cf22e81efec3322bda7fbc601d031fac176ca5f4 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 21 Nov 2023 16:32:21 +0100
Subject: [PATCH 14/17] CrateDB vector: Improve SQLAlchemy data model query utility functions

The CrateDB adapter works a bit differently than the pgvector adapter it
builds upon: because the dimensionality of the vector field must be
specified at table creation time, while it only becomes known at runtime
in LangChain, table creation needs to be delayed.

In some cases, the tables therefore do not exist yet. This is only
relevant when the user requests to pre-delete the collection, using the
`pre_delete_collection` argument, so do the error handling only there,
and _not_ in the generic data model utility functions.
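
As an illustrative sketch (a hypothetical helper, condensing the
`add_embeddings` hunk below; not the literal diff), the resulting
control flow is::

    import sqlalchemy

    def prepare_collection(store) -> None:
        # Tolerate a missing table when pre-deletion was requested:
        # the table only comes into existence once the dimensionality
        # of the embedding vectors is known.
        if store.pre_delete_collection:
            try:
                store.delete_collection()
            except sqlalchemy.exc.ProgrammingError as ex:
                if "RelationUnknown" not in str(ex):
                    raise
        # Table creation is deferred to runtime, because the vector
        # field needs the dimensionality of the embeddings at hand.
        store.create_tables_if_not_exists()
        store.create_collection()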
---
 .../langchain/vectorstores/cratedb/base.py    | 16 ++++++++++-
 .../langchain/vectorstores/cratedb/model.py   | 18 ++-----------
 .../vectorstores/test_cratedb.py              | 27 ++++++++++++++++++-
 3 files changed, 43 insertions(+), 18 deletions(-)

diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py
index 922ba2ed659d6..166371325e190 100644
--- a/libs/langchain/langchain/vectorstores/cratedb/base.py
+++ b/libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -164,10 +164,24 @@ def add_embeddings(
         if not embeddings:
             return []
         self._init_models(embeddings[0])
+
+        # When the user requested to delete the collection before running subsequent
+        # operations on it, run the deletion gracefully if the table does not exist
+        # yet.
         if self.pre_delete_collection:
-            self.delete_collection()
+            try:
+                self.delete_collection()
+            except sqlalchemy.exc.ProgrammingError as ex:
+                if "RelationUnknown" not in str(ex):
+                    raise
+
+        # Tables need to be created at runtime, because the `EmbeddingStore.embedding`
+        # field, a `FloatVector`, needs to be initialized with a dimensionality
+        # parameter, which is only obtained at runtime.
         self.create_tables_if_not_exists()
         self.create_collection()
+
+        # After setting up the table/collection at runtime, add embeddings.
         return super().add_embeddings(
             texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
         )
diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py
index 1aec9b49a7260..c540ba2eb217f 100644
--- a/libs/langchain/langchain/vectorstores/cratedb/model.py
+++ b/libs/langchain/langchain/vectorstores/cratedb/model.py
@@ -53,27 +53,13 @@ class CollectionStore(BaseModel):
     def get_by_name(
         cls, session: Session, name: str
     ) -> Optional["CollectionStore"]:
-        try:
-            return (
-                session.query(cls).filter(cls.name == name).first()  # type: ignore[attr-defined] # noqa: E501
-            )
-        except sqlalchemy.exc.ProgrammingError as ex:
-            if "RelationUnknown" not in str(ex):
-                raise
-        return None
+        return session.query(cls).filter(cls.name == name).first()  # type: ignore[attr-defined]

     @classmethod
     def get_by_names(
         cls, session: Session, names: List[str]
     ) -> Optional["List[CollectionStore]"]:
-        try:
-            return (
-                session.query(cls).filter(cls.name.in_(names)).all()  # type: ignore[attr-defined] # noqa: E501
-            )
-        except sqlalchemy.exc.ProgrammingError as ex:
-            if "RelationUnknown" not in str(ex):
-                raise
-        return None
+        return session.query(cls).filter(cls.name.in_(names)).all()  # type: ignore[attr-defined]

     @classmethod
     def get_or_create(
         cls,
diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
index 5a732ca5332f9..1862ca895733b 100644
--- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py
@@ -5,6 +5,7 @@
 docker-compose -f cratedb.yml up
 """
 import os
+import re
 from typing import Dict, Generator, List, Tuple

@@ -628,12 +629,36 @@ def test_cratedb_multicollection_fail_indexing_not_permitted() -> None:
     assert ex.match("This adapter can not be used for indexing documents")


-def test_cratedb_multicollection_search_no_collections() -> None:
+def test_cratedb_multicollection_search_table_does_not_exist() -> None:
+    """
+    `CrateDBVectorSearchMultiCollection` will fail when the `collection`
+    table does not exist.
+ """ + + store = CrateDBVectorSearchMultiCollection( + collection_names=["unknown"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + with pytest.raises(ProgrammingError) as ex: + store.similarity_search("foo") + assert ex.match(re.escape("RelationUnknown[Relation 'collection' unknown]")) + + +def test_cratedb_multicollection_search_unknown_collection() -> None: """ `CrateDBVectorSearchMultiCollection` will fail when not able to identify collections to search in. """ + CrateDBVectorSearch.from_texts( + texts=["Räuber", "Hotzenplotz"], + collection_name="test_collection", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store = CrateDBVectorSearchMultiCollection( collection_names=["unknown"], embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), From ef485de7287ec3139eaac0890e87538913088f2a Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 16:45:01 +0100 Subject: [PATCH 15/17] CrateDB vector: Improve testing when initialized without dimensionality --- .../vectorstores/test_cratedb.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py index 1862ca895733b..86370439e6dce 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cratedb.py @@ -364,11 +364,12 @@ def test_cratedb_collection_with_metadata() -> None: def test_cratedb_collection_no_embedding_dimension() -> None: - """Test end to end collection construction""" + """ + Verify that addressing collections fails when not specifying dimensions. + """ cratedb_vector = CrateDBVectorSearch( embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, - pre_delete_collection=True, ) session = Session(cratedb_vector.connect()) with pytest.raises(RuntimeError) as ex: @@ -667,3 +668,20 @@ def test_cratedb_multicollection_search_unknown_collection() -> None: with pytest.raises(ValueError) as ex: store.similarity_search("foo") assert ex.match("No collections found") + + +def test_cratedb_multicollection_no_embedding_dimension() -> None: + """ + Verify that addressing collections fails when not specifying dimensions. + """ + store = CrateDBVectorSearchMultiCollection( + embedding_function=None, # type: ignore[arg-type] + connection_string=CONNECTION_STRING, + ) + session = Session(store.connect()) + with pytest.raises(RuntimeError) as ex: + store.get_collection(session) + assert ex.match( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) From 8574c91c9854ed110f9808d70374e85d94abf3d8 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 15:11:16 +0100 Subject: [PATCH 16/17] pgvector: Use SA's `bulk_save_objects` method for inserting embeddings The performance gains can be substantially. 
---
 libs/langchain/langchain/vectorstores/cratedb/base.py | 4 +++-
 libs/langchain/langchain/vectorstores/pgvector.py     | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py
index 166371325e190..6abf327e940e1 100644
--- a/libs/langchain/langchain/vectorstores/cratedb/base.py
+++ b/libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -182,9 +182,11 @@ def add_embeddings(
         self.create_collection()

         # After setting up the table/collection at runtime, add embeddings.
-        return super().add_embeddings(
+        embedding_ids = super().add_embeddings(
             texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
         )
+        refresh_table(self.Session(), self.EmbeddingStore)
+        return embedding_ids

     def create_tables_if_not_exists(self) -> None:
         """
diff --git a/libs/langchain/langchain/vectorstores/pgvector.py b/libs/langchain/langchain/vectorstores/pgvector.py
index 0b2d97d6277e6..e9ec56c71df3c 100644
--- a/libs/langchain/langchain/vectorstores/pgvector.py
+++ b/libs/langchain/langchain/vectorstores/pgvector.py
@@ -299,6 +299,7 @@ def add_embeddings(
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
+            documents = []
             for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
                 embedding_store = self.EmbeddingStore(
                     embedding=embedding,
@@ -307,7 +308,8 @@ def add_embeddings(
                     custom_id=id,
                     collection_id=collection.uuid,
                 )
-                session.add(embedding_store)
+                documents.append(embedding_store)
+            session.bulk_save_objects(documents)
             session.commit()

         return ids

From 8e0bea6b6a1990836508fd539f90afcb4ba7d2f7 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 21 Nov 2023 21:47:15 +0100
Subject: [PATCH 17/17] CrateDB: Add support for SQLRecordManager

It does not work yet, because this subsystem uses a composite unique key
together with an `ON CONFLICT DO UPDATE` operation, on behalf of the
model entity definition `UpsertionRecord`. Because the composite
uniqueness constraint is already being emulated, the dialect cannot
easily emulate the ON CONFLICT behaviour on top of it as well.

    __table_args__ = (
        UniqueConstraint("key", "namespace", name="uix_key_namespace"),
        Index("ix_key_namespace", "key", "namespace"),
    )

    stmt = insert_stmt.on_conflict_do_update(
        [UpsertionRecord.key, UpsertionRecord.namespace],
        ...
    )
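
For reference, a rough sketch of the kind of upsert statement this
subsystem builds; the helper name and the `set_` payload here are
illustrative assumptions, not the verbatim implementation::

    from langchain.indexes._sql_record_manager import UpsertionRecord
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    def build_upsert(records: list):
        # Sketch only: the `set_` payload is an assumed example.
        insert_stmt = sqlite_insert(UpsertionRecord).values(records)
        # The conflict target is the composite unique constraint on
        # (key, namespace), which CrateDB's dialect cannot compile.
        return insert_stmt.on_conflict_do_update(
            [UpsertionRecord.key, UpsertionRecord.namespace],
            set_=dict(updated_at=insert_stmt.excluded.updated_at),
        )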
---
 .../langchain/indexes/_sql_record_manager.py  |   4 +-
 .../integration_tests/indexes/__init__.py     |   0
 .../test_cratedb_sql_record_manager.py        | 117 ++++++++++++++++++
 3 files changed, 119 insertions(+), 2 deletions(-)
 create mode 100644 libs/langchain/tests/integration_tests/indexes/__init__.py
 create mode 100644 libs/langchain/tests/integration_tests/indexes/test_cratedb_sql_record_manager.py

diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py
index 14e2355af223a..542d51456967b 100644
--- a/libs/langchain/langchain/indexes/_sql_record_manager.py
+++ b/libs/langchain/langchain/indexes/_sql_record_manager.py
@@ -199,7 +199,7 @@ def get_time(self) -> float:
         # in a day (24 hours * 60 minutes * 60 seconds)
         if self.dialect == "sqlite":
             query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
-        elif self.dialect == "postgresql":
+        elif self.dialect in ["crate", "postgresql"]:
             query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
         else:
             raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
@@ -283,7 +283,7 @@ def update(
         ]

         with self._make_session() as session:
-            if self.dialect == "sqlite":
+            if self.dialect in ["crate", "sqlite"]:
                 from sqlalchemy.dialects.sqlite import insert as sqlite_insert

                 # Note: uses SQLite insert to make on_conflict_do_update work.
diff --git a/libs/langchain/tests/integration_tests/indexes/__init__.py b/libs/langchain/tests/integration_tests/indexes/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/libs/langchain/tests/integration_tests/indexes/test_cratedb_sql_record_manager.py b/libs/langchain/tests/integration_tests/indexes/test_cratedb_sql_record_manager.py
new file mode 100644
index 0000000000000..8b6a3857f9e49
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/indexes/test_cratedb_sql_record_manager.py
@@ -0,0 +1,117 @@
+import os
+import typing as t
+
+import pytest
+import sqlalchemy as sa
+
+from langchain.indexes._sql_record_manager import Base, SQLRecordManager
+from langchain.vectorstores.cratedb import CrateDBVectorSearch
+
+CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(
+    driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"),
+    host=os.environ.get("TEST_CRATEDB_HOST", "localhost"),
+    port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")),
+    database=os.environ.get("TEST_CRATEDB_DATABASE", "testdrive"),
+    user=os.environ.get("TEST_CRATEDB_USER", "crate"),
+    password=os.environ.get("TEST_CRATEDB_PASSWORD", ""),
+)
+
+
+@pytest.fixture
+def engine() -> sa.Engine:
+    """
+    Return an SQLAlchemy engine object.
+    """
+    return sa.create_engine(CONNECTION_STRING, echo=False)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def dialect_patch_session(session_mocker: t.Any) -> None:
+    """
+    Patch the CrateDB SQLAlchemy dialect to ignore INDEX constraints.
+    """
+    import warnings
+
+    from crate.client.sqlalchemy.compiler import CrateDDLCompiler
+
+    def visit_create_index(
+        self: t.Type[CrateDDLCompiler], *args: t.List, **kwargs: t.Dict
+    ) -> str:
+        """
+        CrateDB does not support index constraints.
+
+        CREATE INDEX ix_upsertion_record_group_id ON upsertion_record (group_id)
+        """
+        warnings.warn(
+            "CrateDB does not support index constraints, "
+            "they will be omitted when generating DDL statements."
+        )
+        return "SELECT 1;"
+
+    session_mocker.patch(
+        "crate.client.sqlalchemy.compiler.CrateDDLCompiler.visit_create_index",
+        visit_create_index,
+    )
+
+
+@pytest.fixture(autouse=True)
+def dialect_patch_function(monkeypatch: t.Any) -> None:
+    """
+    Patch the CrateDB SQLAlchemy dialect to handle `INSERT ... ON CONFLICT`
+    operations like PostgreSQL.
+    """
+    from crate.client.sqlalchemy.compiler import CrateCompiler
+    from sqlalchemy.dialects.postgresql.base import PGCompiler
+
+    monkeypatch.setattr(
+        CrateCompiler,
+        "_on_conflict_target",
+        PGCompiler._on_conflict_target,
+        raising=False,
+    )
+    monkeypatch.setattr(
+        CrateCompiler,
+        "visit_on_conflict_do_nothing",
+        PGCompiler.visit_on_conflict_do_nothing,
+        raising=False,
+    )
+    monkeypatch.setattr(
+        CrateCompiler,
+        "visit_on_conflict_do_update",
+        PGCompiler.visit_on_conflict_do_update,
+        raising=False,
+    )
+
+
+@pytest.fixture(autouse=True)
+def drop_tables(engine: sa.Engine) -> None:
+    """
+    Drop database tables before invoking the test case function.
+    """
+    try:
+        Base.metadata.drop_all(engine, checkfirst=False)
+    except Exception as ex:
+        if "RelationUnknown" not in str(ex):
+            raise
+
+
+@pytest.fixture()
+def manager() -> SQLRecordManager:
+    """Provide an initialized `SQLRecordManager` instance."""
+    # Initialize and return the record manager instance.
+    record_manager = SQLRecordManager("kittens", db_url=CONNECTION_STRING)
+    record_manager.create_schema()
+    return record_manager
+
+
+def test_update(manager: SQLRecordManager) -> None:
+    """Test updating records in the database."""
+    # No keys should be present in the store yet.
+    read_keys = manager.list_keys()
+    assert read_keys == []
+    # Insert records.
+    keys = ["key1", "key2", "key3"]
+    manager.update(keys)
+    # Retrieve the records.
+    read_keys = manager.list_keys()
+    assert read_keys == ["key1", "key2", "key3"]