From 284c5d95bf81e1a68d685a4f6ebe2bf2c67a9dd5 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 15 Sep 2023 11:49:40 +0200 Subject: [PATCH 01/28] CrateDB vector: Add vector store support The implementation is based on the `pgvector` adapter, as both PostgreSQL and CrateDB share similar attributes, and can be wrapped well by using the same SQLAlchemy layer on top. --- libs/community/extended_testing_deps.txt | 2 + .../vectorstores/docker-compose/cratedb.yml | 20 + .../vectorstores/test_cratedb.py | 445 ++++++++++++++++++ .../langchain/vectorstores/__init__.py | 3 + .../vectorstores/cratedb/__init__.py | 6 + .../langchain/vectorstores/cratedb/base.py | 396 ++++++++++++++++ .../langchain/vectorstores/cratedb/model.py | 84 ++++ .../vectorstores/cratedb/sqlalchemy_type.py | 84 ++++ 8 files changed, 1040 insertions(+) create mode 100644 libs/community/tests/integration_tests/vectorstores/docker-compose/cratedb.yml create mode 100644 libs/community/tests/integration_tests/vectorstores/test_cratedb.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/__init__.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/base.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/model.py create mode 100644 libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index b2548b2219394..5bc78eaddb021 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -14,6 +14,8 @@ chardet>=5.1.0,<6 cloudpathlib>=0.18,<0.19 cloudpickle>=2.0.0 cohere>=4,<6 +crate>=0.34.0,<1 +cratedb-toolkit==0.0.12 databricks-vectorsearch>=0.21,<0.22 datasets>=2.15.0,<3 dgml-utils>=0.3.0,<0.4 diff --git a/libs/community/tests/integration_tests/vectorstores/docker-compose/cratedb.yml b/libs/community/tests/integration_tests/vectorstores/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/community/tests/integration_tests/vectorstores/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py new file mode 100644 index 0000000000000..d62f0a125f661 --- /dev/null +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -0,0 +1,445 @@ +""" +Test CrateDB `FLOAT_VECTOR` / `KNN_MATCH` functionality. 
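
Use Docker Compose to run a single-node CrateDB instance for testing: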
+ +cd tests/integration_tests/vectorstores/docker-compose +docker-compose -f cratedb.yml up +""" +import os +from typing import List, Tuple + +import pytest +import sqlalchemy as sa +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.orm import Session + +from langchain.docstore.document import Document +from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings + +CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( + driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), + host=os.environ.get("TEST_CRATEDB_HOST", "localhost"), + port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")), + database=os.environ.get("TEST_CRATEDB_DATABASE", "testdrive"), + user=os.environ.get("TEST_CRATEDB_USER", "crate"), + password=os.environ.get("TEST_CRATEDB_PASSWORD", ""), +) + + +# TODO: Try 1536 after https://github.com/crate/crate/pull/14699. +# ADA_TOKEN_COUNT = 14 +ADA_TOKEN_COUNT = 1024 +# ADA_TOKEN_COUNT = 1536 + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture(autouse=True) +def drop_tables(engine: sa.Engine) -> None: + """ + Drop database tables. + """ + try: + BaseModel.metadata.drop_all(engine, checkfirst=False) + except Exception as ex: + if "RelationUnknown" not in str(ex): + raise + + +@pytest.fixture +def prune_tables(engine: sa.Engine) -> None: + """ + Delete data from database tables. + """ + with engine.connect() as conn: + with Session(conn) as session: + from langchain.vectorstores.cratedb.model import model_factory + + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for deleting records from tables. + CollectionStore, EmbeddingStore = model_factory(dimensions=1024) + + try: + session.query(CollectionStore).delete() + except ProgrammingError: + pass + try: + session.query(EmbeddingStore).delete() + except ProgrammingError: + pass + + +def decode_output( + output: List[Tuple[Document, float]] +) -> Tuple[List[Document], List[float]]: + """ + Decode a typical API result into separate `documents` and `scores`. + It is needed as utility function in some test cases to compensate + for different and/or flaky score values, when compared to the + original implementation. 
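    For example, `[(document, 1.0828735)]` is decoded into
    `([document], [1.1])`.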
+ """ + documents = [item[0] for item in output] + scores = [round(item[1], 1) for item in output] + return documents, scores + + +class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [ + [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(i)] for i in range(len(texts)) + ] + + def embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] + + +def test_cratedb_texts() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = CrateDBVectorSearch.from_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cratedb_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + +def test_cratedb_with_metadatas_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1) + # TODO: Original: + # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 + assert output in [ + [(Document(page_content="foo", metadata={"page": "0"}), 1.0828735)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.1307646)], + ] + + +def test_cratedb_with_filter_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) + # TODO: Original: + # assert output == 
[(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 + assert output in [ + [(Document(page_content="foo", metadata={"page": "0"}), 1.2615292)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.3979403)], + [(Document(page_content="foo", metadata={"page": "0"}), 1.5065275)], + ] + + +def test_cratedb_with_filter_distant_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + # TODO: Original: + # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) + output = docsearch.similarity_search_with_score("foo", k=3, filter={"page": "2"}) + # TODO: Original: + # assert output == [ + # (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501 + # ] + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores in [ + [0.5], + [0.6], + [0.7], + ] + + +def test_cratedb_with_filter_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) + assert output == [] + + +def test_cratedb_collection_with_metadata() -> None: + """Test end to end collection construction""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + cratedb_vector = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + collection = cratedb_vector.get_collection(cratedb_vector.Session()) + if collection is None: + assert False, "Expected a CollectionStore object but received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} + + +def test_cratedb_collection_no_embedding_dimension() -> None: + """Test end to end collection construction""" + cratedb_vector = CrateDBVectorSearch( + embedding_function=None, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + session = Session(cratedb_vector.connect()) + with pytest.raises(RuntimeError) as ex: + cratedb_vector.get_collection(session) + assert ex.match( + "Collection can't be accessed without specifying dimension size of embedding vectors" + ) + + +def test_cratedb_with_filter_in_set() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = 
docsearch.similarity_search_with_score( + "foo", k=2, filter={"page": {"IN": ["0", "2"]}} + ) + # TODO: Original: + """ + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 0.0), + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406), + ] + """ + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores == [2.1, 1.3] + + +def test_cratedb_delete_docs() -> None: + """Add and delete documents.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + docsearch.delete(["1", "2"]) + with docsearch._make_session() as session: + records = list(session.query(docsearch.EmbeddingStore).all()) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.custom_id for record in records) == ["3"] # type: ignore + + docsearch.delete(["2", "3"]) # Should not raise on missing ids + with docsearch._make_session() as session: + records = list(session.query(docsearch.EmbeddingStore).all()) + # ignoring type error since mypy cannot determine whether + # the list is sortable + assert sorted(record.custom_id for record in records) == [] # type: ignore + + +def test_cratedb_relevance_score() -> None: + """Test to make sure the relevance score is scaled to 0-1.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + output = docsearch.similarity_search_with_relevance_scores("foo", k=3) + """ + # TODO: Original code, where the `distance` is stable. 
+ assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065), + (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621), + ] + """ + documents, scores = decode_output(output) + assert documents == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + Document(page_content="baz", metadata={"page": "2"}), + ] + assert scores == [0.8, 0.4, 0.2] + + +def test_cratedb_retriever_search_threshold() -> None: + """Test using retriever for searching with threshold.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + # TODO: Original: + # search_kwargs={"k": 3, "score_threshold": 0.999}, + search_kwargs={"k": 3, "score_threshold": 0.333}, + ) + output = retriever.get_relevant_documents("summer") + assert output == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + ] + + +def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None: + """Test searching with threshold and custom normalization function""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + relevance_score_fn=lambda d: d * 0, + ) + + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": 3, "score_threshold": 0.5}, + ) + output = retriever.get_relevant_documents("foo") + assert output == [] + + +def test_cratedb_max_marginal_relevance_search() -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3) + assert output == [Document(page_content="foo")] + + +def test_cratedb_max_marginal_relevance_search_with_score() -> None: + """Test max marginal relevance search with relevance scores.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) + # TODO: Original: + # assert output == [(Document(page_content="foo"), 0.0)] + assert output in [ + [(Document(page_content="foo"), 1.0606961)], + [(Document(page_content="foo"), 1.0828735)], + [(Document(page_content="foo"), 1.1307646)], + ] diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 603421aad0e08..27934e93a7b33 100644 --- a/libs/langchain/langchain/vectorstores/__init__.py +++ 
b/libs/langchain/langchain/vectorstores/__init__.py
@@ -43,6 +43,7 @@
     Clarifai,
     Clickhouse,
     ClickhouseSettings,
+    CrateDBVectorSearch,
     DashVector,
     DatabricksVectorSearch,
     DeepLake,
@@ -119,6 +120,7 @@
     "Clarifai": "langchain_community.vectorstores",
     "Clickhouse": "langchain_community.vectorstores",
     "ClickhouseSettings": "langchain_community.vectorstores",
+    "CrateDBVectorSearch": "langchain_community.vectorstores",
     "DashVector": "langchain_community.vectorstores",
     "DatabricksVectorSearch": "langchain_community.vectorstores",
     "DeepLake": "langchain_community.vectorstores",
@@ -202,6 +204,7 @@ def __getattr__(name: str) -> Any:
     "Clarifai",
     "Clickhouse",
     "ClickhouseSettings",
+    "CrateDBVectorSearch",
     "DashVector",
     "DatabricksVectorSearch",
     "DeepLake",
diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py
new file mode 100644
index 0000000000000..303a52babeaea
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py
@@ -0,0 +1,6 @@
+from .base import BaseModel, CrateDBVectorSearch
+
+__all__ = [
+    "BaseModel",
+    "CrateDBVectorSearch",
+]
diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py
new file mode 100644
index 0000000000000..ca1c21fad68a7
--- /dev/null
+++ b/libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -0,0 +1,396 @@
from __future__ import annotations

import enum
import math
import uuid
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import sqlalchemy
from cratedb_toolkit.sqlalchemy.patch import patch_inspector
from cratedb_toolkit.sqlalchemy.polyfill import (
    polyfill_refresh_after_dml,
    refresh_table,
)
from sqlalchemy.orm import declarative_base, sessionmaker

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.pgvector import PGVector


class DistanceStrategy(str, enum.Enum):
    """Enumerator of the Distance strategies."""

    EUCLIDEAN = "euclidean"
    COSINE = "cosine"
    MAX_INNER_PRODUCT = "inner"


DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN

Base = declarative_base()  # type: Any
# Base = declarative_base(metadata=MetaData(schema="langchain"))  # type: Any

_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


def generate_uuid() -> str:
    return str(uuid.uuid4())


class BaseModel(Base):
    """Base model for the SQL stores."""

    __abstract__ = True
    uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid)


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
    """Return docs from docs and scores."""
    return [doc for doc, _ in docs_and_scores]


class CrateDBVectorSearch(PGVector):
    """`CrateDB` vector store.

    To use it, you need the ``crate[sqlalchemy]`` Python package installed.

    Args:
        connection_string: Database connection string.
        embedding_function: Any embedding function implementing
            `langchain.embeddings.base.Embeddings` interface.
        collection_name: The name of the collection to use. (default: langchain)
            NOTE: This is not the name of the table, but the name of the collection.
            The tables will be created when initializing the store (if they do not
            exist), so make sure the user has the right permissions to create tables.
        distance_strategy: The distance strategy to use.
            (default: EUCLIDEAN)
        pre_delete_collection: If True, will delete the collection if it exists.
            (default: False). Useful for testing.

    Example:
        .. code-block:: python

            from langchain.vectorstores import CrateDBVectorSearch
            from langchain.embeddings.openai import OpenAIEmbeddings

            CONNECTION_STRING = "crate://crate@localhost:4200/test3"
            COLLECTION_NAME = "state_of_the_union_test"
            embeddings = OpenAIEmbeddings()
            vectorstore = CrateDBVectorSearch.from_documents(
                embedding=embeddings,
                documents=docs,
                collection_name=COLLECTION_NAME,
                connection_string=CONNECTION_STRING,
            )


    """

    def __post_init__(
        self,
    ) -> None:
        """
        Initialize the store.
        """

        # FIXME: Could be a bug in CrateDB SQLAlchemy dialect.
        patch_inspector()

        self._engine = self.create_engine()
        self.Session = sessionmaker(self._engine)

        # TODO: See what can be improved here.
        polyfill_refresh_after_dml(self.Session)

        self.CollectionStore = None
        self.EmbeddingStore = None

    def get_collection(
        self, session: sqlalchemy.orm.Session
    ) -> Optional["CollectionStore"]:
        if self.CollectionStore is None:
            raise RuntimeError(
                "Collection can't be accessed without specifying dimension size of embedding vectors"
            )
        return self.CollectionStore.get_by_name(session, self.collection_name)

    def add_embeddings(
        self,
        texts: Iterable[str],
        embeddings: List[List[float]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add embeddings to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            embeddings: List of list of embedding vectors.
            metadatas: List of metadatas associated with the texts.
            kwargs: vectorstore specific parameters
        """
        from langchain.vectorstores.cratedb.model import model_factory

        dimensions = len(embeddings[0])
        self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=dimensions)
        if self.pre_delete_collection:
            self.delete_collection()
        self.create_tables_if_not_exists()
        self.create_collection()
        return super().add_embeddings(
            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
        )

    def create_tables_if_not_exists(self) -> None:
        """
        Need to overwrite because `Base` is different from upstream.
        """
        Base.metadata.create_all(self._engine)

    def drop_tables(self) -> None:
        """
        Need to overwrite because `Base` is different from upstream.
        """
        Base.metadata.drop_all(self._engine)

    def delete(
        self,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Delete vectors by ids or uuids.

        Args:
            ids: List of ids to delete.

        Remark: The CrateDB adapter needs to overwrite this method in order
                to run a "REFRESH TABLE" statement afterwards, synchronizing
                the data. The other patch, which listens to `after_delete`
                events, does not seem to be able to catch it.
        """
        super().delete(ids=ids, **kwargs)

        # CrateDB: Synchronize data because `on_flush` does not catch it.
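        # `refresh_table` runs a `REFRESH TABLE` statement, making the deleted
        # records disappear from subsequent reads immediately. Otherwise, the
        # changes would only become visible after CrateDB's periodic refresh.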
+ with self.Session() as session: + refresh_table(session, self.EmbeddingStore) + + @property + def distance_strategy(self) -> Any: + if self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self.EmbeddingStore.embedding.euclidean_distance + elif self._distance_strategy == DistanceStrategy.COSINE: + raise NotImplementedError("Cosine similarity not implemented yet") + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + raise NotImplementedError("Dot-product similarity not implemented yet") + else: + raise ValueError( + f"Got unexpected value for distance: {self._distance_strategy}. " + f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." + ) + + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: + """Return docs and scores from results.""" + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result._score if self.embedding_function is not None else None, + ) + for result in results + ] + return docs + + def _query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + with self.Session() as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + + filter_by = self.EmbeddingStore.collection_id == collection.uuid + + if filter is not None: + filter_clauses = [] + for key, value in filter.items(): + IN = "in" + if isinstance(value, dict) and IN in map(str.lower, value): + value_case_insensitive = { + k.lower(): v for k, v in value.items() + } + filter_by_metadata = self.EmbeddingStore.cmetadata[key].in_( + value_case_insensitive[IN] + ) + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = self.EmbeddingStore.cmetadata[key] == str( + value + ) # type: ignore[assignment] + filter_clauses.append(filter_by_metadata) + + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) + + _type = self.EmbeddingStore + + results: List[Any] = ( + session.query( # type: ignore[attr-defined] + self.EmbeddingStore, + # TODO: Original pgvector code uses `self.distance_strategy`. + # CrateDB currently only supports EUCLIDEAN. + # self.distance_strategy(embedding).label("distance") + sqlalchemy.literal_column( + f"{self.EmbeddingStore.__tablename__}._score" + ).label("_score"), + ) + .filter(filter_by) + # CrateDB applies `KNN_MATCH` within the `WHERE` clause. + .filter( + sqlalchemy.func.knn_match( + self.EmbeddingStore.embedding, embedding, k + ) + ) + .order_by(sqlalchemy.desc("_score")) + .join( + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + ) + return results + + @classmethod + def from_texts( # type: ignore[override] + cls: Type[CrateDBVectorSearch], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> CrateDBVectorSearch: + """ + Return VectorStore initialized from texts and embeddings. + Database connection string is required. + + Either pass it as a parameter, or set the CRATEDB_CONNECTION_STRING + environment variable. + + Remark: Needs to be overwritten, because CrateDB uses a different + DEFAULT_DISTANCE_STRATEGY. 
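        CrateDB currently only implements the Euclidean distance strategy,
        while the upstream `pgvector` adapter defaults to cosine distance.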
+ """ + return super().from_texts( # type: ignore[return-value] + texts, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, # type: ignore[arg-type] + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + @classmethod + def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: + connection_string: str = get_from_dict_or_env( + data=kwargs, + key="connection_string", + env_key="CRATEDB_CONNECTION_STRING", + ) + + if not connection_string: + raise ValueError( + "Database connection string is required." + "Either pass it as a parameter, or set the " + "CRATEDB_CONNECTION_STRING environment variable." + ) + + return connection_string + + @classmethod + def connection_string_from_db_params( + cls, + driver: str, + host: str, + port: int, + database: str, + user: str, + password: str, + ) -> str: + """Return connection string from database parameters.""" + return str( + sqlalchemy.URL.create( + drivername=driver, + host=host, + port=port, + username=user, + password=password, + query={"schema": database}, + ) + ) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + """ + if self.override_relevance_score_fn is not None: + return self.override_relevance_score_fn + + # Default strategy is to rely on distance strategy provided + # in vectorstore constructor + if self._distance_strategy == DistanceStrategy.COSINE: + return self._cosine_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: + return self._euclidean_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self._max_inner_product_relevance_score_fn + else: + raise ValueError( + "No supported normalization function for distance_strategy of " + "{self._distance_strategy}. Consider providing relevance_score_fn to " + "CrateDBVectorSearch constructor." + ) + + @staticmethod + def _euclidean_relevance_score_fn(score: float) -> float: + """Return a similarity score on a scale [0, 1].""" + # The 'correct' relevance function + # may differ depending on a few things, including: + # - the distance / similarity metric used by the VectorStore + # - the scale of your embeddings (OpenAI's are unit normed. Many + # others are not!) + # - embedding dimensionality + # - etc. 
+ # This function converts the euclidean norm of normalized embeddings + # (0 is most similar, sqrt(2) most dissimilar) + # to a similarity function (0 to 1) + + # Original: + # return 1.0 - distance / math.sqrt(2) + return score / math.sqrt(2) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py new file mode 100644 index 0000000000000..ee42e7269dc9d --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -0,0 +1,84 @@ +from functools import lru_cache +from typing import Optional, Tuple + +import sqlalchemy +from crate.client.sqlalchemy.types import ObjectType +from sqlalchemy.orm import Session, relationship + +from langchain.vectorstores.cratedb.base import BaseModel +from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector + + +@lru_cache +def model_factory(dimensions: int): + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = "collection" + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + try: + return ( + session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True + return collection, created + + class EmbeddingStore(BaseModel): + """Embedding store.""" + + __tablename__ = "embedding" + + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") + + embedding = sqlalchemy.Column(FloatVector(dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) + + return CollectionStore, EmbeddingStore diff --git a/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py new file mode 100644 index 0000000000000..e784c3013a3d9 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py @@ -0,0 +1,84 @@ +# TODO: Refactor to CrateDB SQLAlchemy dialect. 
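# Provides `FloatVector`, an SQLAlchemy `UserDefinedType` which maps Python
# sequences to CrateDB's `FLOAT_VECTOR` column type, with bind and result
# processing adapted from `pgvector.utils`.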
+import typing as t + +import numpy as np +import numpy.typing as npt +import sqlalchemy as sa +from sqlalchemy.types import UserDefinedType + +__all__ = ["FloatVector"] + + +def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: + # from `pgvector.utils` + # could be ndarray if already cast by lower-level driver + if value is None or isinstance(value, np.ndarray): + return value + + return np.array(value, dtype=np.float32) + + +def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: + # from `pgvector.utils` + if value is None: + return value + + if isinstance(value, np.ndarray): + if value.ndim != 1: + raise ValueError("expected ndim to be 1") + + if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype( + value.dtype, np.floating + ): + raise ValueError("dtype must be numeric") + + value = value.tolist() + + if dim is not None and len(value) != dim: + raise ValueError("expected %d dimensions, not %d" % (dim, len(value))) + + return value + + +class FloatVector(UserDefinedType): + """ + https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector + https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match + """ + + cache_ok = True + + def __init__(self, dim: t.Optional[int] = None) -> None: + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw: t.Any) -> str: + if self.dim is None: + return "FLOAT_VECTOR" + return "FLOAT_VECTOR(%d)" % self.dim + + def bind_processor(self, dialect: sa.Dialect) -> t.Callable: + def process(value: t.Iterable) -> t.Optional[t.List]: + return to_db(value, self.dim) + + return process + + def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: + def process(value: t.Any) -> t.Optional[npt.ArrayLike]: + return from_db(value) + + return process + + """ + CrateDB currently only supports similarity function `VectorSimilarityFunction.EUCLIDEAN`. + -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55 + + On the other hand, pgvector use a comparator to apply different similarity functions as operators, + see `pgvector.sqlalchemy.Vector.comparator_factory`. + + <->: l2/euclidean_distance + <#>: max_inner_product + <=>: cosine_distance + + TODO: Discuss. + """ # noqa: E501 From ba95bde5448abdbc9b79ba46f077aa24c21c5021 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 16 Sep 2023 00:15:12 +0200 Subject: [PATCH 02/28] CrateDB vector: Add documentation --- docs/docs/integrations/providers/cratedb.mdx | 136 +++++ .../integrations/vectorstores/cratedb.ipynb | 479 ++++++++++++++++++ 2 files changed, 615 insertions(+) create mode 100644 docs/docs/integrations/providers/cratedb.mdx create mode 100644 docs/docs/integrations/vectorstores/cratedb.ipynb diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx new file mode 100644 index 0000000000000..2321a51912d09 --- /dev/null +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -0,0 +1,136 @@ +# CrateDB + +This documentation section shows how to use the CrateDB vector store +functionality around [`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn +how to use it for similarity search and other purposes. + + +## What is CrateDB? + +[CrateDB] is an open-source, distributed, and scalable SQL analytics database +for storing and analyzing massive amounts of data in near real-time, even with +complex queries. 
It is PostgreSQL-compatible, based on [Lucene], and inherits
the shared-nothing distribution layer of [Elasticsearch].

It provides a distributed, multi-tenant-capable relational database and search
engine with HTTP and PostgreSQL interfaces, and schema-free objects. It supports
sharding, partitioning, and replication out of the box.

CrateDB enables you to efficiently store billions of records and terabytes of
data, and to query it using SQL.

- Provides a standards-based SQL interface for querying relational data, nested
  documents, geospatial constraints, and vector embeddings at the same time.
- Improves your operations by storing time-series data, relational metadata,
  and vector embeddings within a single database.
- Builds upon proven technologies from Lucene and Elasticsearch.


## CrateDB Cloud

- Offers on-demand CrateDB clusters without operational overhead,
  with enterprise-grade features and [ISO 27001] certification.
- The entrypoint to [CrateDB Cloud] is the [CrateDB Cloud Console].
- Crate.io offers a free tier via [CrateDB Cloud CRFREE].
- To get started, [sign up] to CrateDB Cloud, deploy a database cluster,
  and follow the instructions below.


## Features

The CrateDB adapter supports the Vector Store subsystem of LangChain.

### Vector Store

`CrateDBVectorSearch` is an API wrapper around CrateDB's `FLOAT_VECTOR` type
and the corresponding `KNN_MATCH` function, based on SQLAlchemy and CrateDB's
SQLAlchemy dialect. It provides an interface to store and retrieve floating
point vectors, and to conduct similarity searches.

Supports:
- Approximate nearest neighbor search.
- Euclidean distance.


## Installation and Setup

There are multiple ways to get started with CrateDB.

### Install CrateDB on your local machine

You can [download CrateDB], or use the [OCI image] to run CrateDB on Docker or Podman.
Note that this is not recommended for production use.

```shell
docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \
    --env=CRATE_HEAP_SIZE=4g crate/crate:nightly \
    -Cdiscovery.type=single-node
```

### Deploy a cluster on CrateDB Cloud

[CrateDB Cloud] is a managed CrateDB service. Sign up for a [free trial].

### Install Client

```bash
pip install 'crate[sqlalchemy]' 'langchain[openai]'
```


## Usage

For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, see the
corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html).

### Acquire text file
The example uses the canonical `state_of_the_union.txt`.
```shell
wget https://raw.githubusercontent.com/langchain-ai/langchain/v0.0.291/docs/extras/modules/state_of_the_union.txt
```

### Set environment variables
Use a valid OpenAI API key and SQL connection string. The values below fit a
local instance of CrateDB.
```shell
export OPENAI_API_KEY=foobar
export CRATEDB_CONNECTION_STRING=crate://crate@localhost
```

```python
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import CrateDBVectorSearch


def main():
    # Load the document, split it into chunks, embed each chunk,
    # and load everything into the vector store.
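    # Without an explicit `connection_string` argument, `from_documents`
    # obtains the database address from the `CRATEDB_CONNECTION_STRING`
    # environment variable set above.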
    raw_documents = TextLoader("state_of_the_union.txt").load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(raw_documents)
    db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())

    query = "What did the president say about Ketanji Brown Jackson"
    docs = db.similarity_search(query)
    print(docs[0].page_content)


if __name__ == "__main__":
    main()
```


[CrateDB]: https://github.com/crate/crate
[CrateDB Cloud]: https://crate.io/product
[CrateDB Cloud Console]: https://console.cratedb.cloud/
[CrateDB Cloud CRFREE]: https://community.crate.io/t/new-cratedb-cloud-edge-feature-cratedb-cloud-free-tier/1402
[CrateDB SQLAlchemy dialect]: https://crate.io/docs/python/en/latest/sqlalchemy.html
[download CrateDB]: https://crate.io/download
[Elasticsearch]: https://github.com/elastic/elasticsearch
[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
[free trial]: https://crate.io/lp-crfree?utm_source=langchain
[ISO 27001]: https://crate.io/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification
[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
[Lucene]: https://github.com/apache/lucene
[OCI image]: https://hub.docker.com/_/crate
[sign up]: https://console.cratedb.cloud/
diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb
new file mode 100644
index 0000000000000..462e721bfff40
--- /dev/null
+++ b/docs/docs/integrations/vectorstores/cratedb.ipynb
@@ -0,0 +1,479 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CrateDB\n",
    "\n",
    "This notebook shows how to use the CrateDB vector store functionality around\n",
    "[`FLOAT_VECTOR`] and [`KNN_MATCH`]. You will learn how to use it for similarity\n",
    "search and other purposes.\n",
    "\n",
    "It supports:\n",
    "- Similarity Search with Euclidean Distance\n",
    "- Maximal Marginal Relevance Search (MMR)\n",
    "\n",
    "## What is CrateDB?\n",
    "\n",
    "[CrateDB] is an open-source, distributed, and scalable SQL analytics database\n",
    "for storing and analyzing massive amounts of data in near real-time, even with\n",
    "complex queries. It is PostgreSQL-compatible, based on [Lucene], and inherits\n",
    "the shared-nothing distribution layer of [Elasticsearch].\n",
    "\n",
    "This example uses the [Python client driver for CrateDB].
For more documentation,\n", + "see also [LangChain with CrateDB].\n", + "\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[Elasticsearch]: https://github.com/elastic/elasticsearch\n", + "[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector\n", + "[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match\n", + "[LangChain with CrateDB]: /docs/extras/integrations/providers/cratedb.html\n", + "[Lucene]: https://github.com/apache/lucene\n", + "[Python client driver for CrateDB]: https://crate.io/docs/python/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Getting Started" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [], + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "# Install required packages: LangChain, OpenAI SDK, and the CrateDB Python driver.\n", + "!pip install 'langchain[cratedb,openai]'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to provide an OpenAI API key, optionally using the environment\n", + "variable `OPENAI_API_KEY`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:16.802456Z", + "start_time": "2023-09-09T08:02:07.065604Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "# Run `export OPENAI_API_KEY=sk-YOUR_OPENAI_API_KEY`.\n", + "# Get OpenAI api key from `.env` file.\n", + "# Otherwise, prompt for it.\n", + "_ = load_dotenv(find_dotenv())\n", + "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', getpass.getpass(\"OpenAI API key:\"))\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You also need to provide a connection string to your CrateDB database cluster,\n", + "optionally using the environment variable `CRATEDB_CONNECTION_STRING`.\n", + "\n", + "This example uses a CrateDB instance on your workstation, which you can start by\n", + "running [CrateDB using Docker]. 
Alternatively, you can also connect to a cluster\n", + "running on [CrateDB Cloud].\n", + "\n", + "[CrateDB Cloud]: https://console.cratedb.cloud/\n", + "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import os\n", + "\n", + "CONNECTION_STRING = os.environ.get(\n", + " \"CRATEDB_CONNECTION_STRING\",\n", + " \"crate://crate@localhost:4200/?schema=langchain\",\n", + ")\n", + "\n", + "# For CrateDB Cloud, use:\n", + "# CONNECTION_STRING = os.environ.get(\n", + "# \"CRATEDB_CONNECTION_STRING\",\n", + "# \"crate://username:password@hostname:4200/?ssl=true&schema=langchain\",\n", + "# )" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:02:28.174088Z", + "start_time": "2023-09-09T08:02:28.162698Z" + } + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "# Alternatively, the connection string can be assembled from individual\n", + "# environment variables.\n", + "import os\n", + "\n", + "CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params(\n", + " driver=os.environ.get(\"CRATEDB_DRIVER\", \"crate\"),\n", + " host=os.environ.get(\"CRATEDB_HOST\", \"localhost\"),\n", + " port=int(os.environ.get(\"CRATEDB_PORT\", \"4200\")),\n", + " database=os.environ.get(\"CRATEDB_DATABASE\", \"langchain\"),\n", + " user=os.environ.get(\"CRATEDB_USER\", \"crate\"),\n", + " password=os.environ.get(\"CRATEDB_PASSWORD\", \"\"),\n", + ")\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "You will start by importing all required modules." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import CrateDBVectorSearch\n", + "from langchain.document_loaders import UnstructuredURLLoader\n", + "from langchain.docstore.document import Document" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Next, you will read input data, and tokenize it." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n", + "documents = loader.load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "is_executing": true + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Similarity Search with Euclidean Distance (Default)\n", + "\n", + "The module will create a table with the name of the collection. Make sure\n", + "the collection name is unique and that you have the permission to create\n", + "a table." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:04:16.696625Z", + "start_time": "2023-09-09T08:02:31.817790Z" + } + }, + "outputs": [], + "source": [ + "COLLECTION_NAME = \"state_of_the_union_test\"\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "db = CrateDBVectorSearch.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:05:11.104135Z", + "start_time": "2023-09-09T08:05:10.548998Z" + } + }, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs_with_score = db.similarity_search_with_score(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2023-09-09T08:05:13.532334Z", + "start_time": "2023-09-09T08:05:13.523191Z" + } + }, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Maximal Marginal Relevance Search (MMR)\n", + "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "docs_with_score = db.max_marginal_relevance_search_with_score(query)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-09T08:05:23.276819Z", + "start_time": "2023-09-09T08:05:21.972256Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-09T08:05:27.478580Z", + "start_time": "2023-09-09T08:05:27.470138Z" + } + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Working with the vector store\n", + "\n", + "In the example above, you created a vector store from scratch. When\n", + "aiming to work with an existing vector store, you can initialize it directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "store = CrateDBVectorSearch(\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + " embedding_function=embeddings,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Add documents\n", + "\n", + "You can also add documents to an existing vector store." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "store.add_documents([Document(page_content=\"foo\")])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score = db.similarity_search_with_score(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Overwriting a vector store\n", + "\n", + "If you have an existing collection, you can overwrite it by using `from_documents`,\n", + "aad setting `pre_delete_collection = True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db = CrateDBVectorSearch.from_documents(\n", + " documents=docs,\n", + " embedding=embeddings,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + " pre_delete_collection=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score = db.similarity_search_with_score(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "docs_with_score[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using a vector store as a retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = store.as_retriever()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(retriever)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 00159ff8cc7d56c05c5e3f6be0f8842e14b7e00e Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 16 Sep 2023 20:00:09 +0200 Subject: [PATCH 03/28] CrateDB loader: Add SQLAlchemy document loader This will become SQLDatabaseLoader later. 
--- docs/docs/how_to/sqlalchemy.mdx | 155 ++++++++++++ .../document_loaders/sqlalchemy.ipynb | 237 ++++++++++++++++++ .../docker-compose/postgresql.yml | 19 ++ .../test_sqlalchemy_postgresql.py | 177 +++++++++++++ .../test_sqlalchemy_sqlite.py | 181 +++++++++++++ .../langchain/document_loaders/__init__.py | 4 + .../langchain/document_loaders/sqlalchemy.py | 112 +++++++++ libs/langchain/tests/data.py | 4 + .../examples/mlb_teams_2012.csv | 32 +++ .../examples/mlb_teams_2012.sql | 38 +++ 10 files changed, 959 insertions(+) create mode 100644 docs/docs/how_to/sqlalchemy.mdx create mode 100644 docs/docs/integrations/document_loaders/sqlalchemy.ipynb create mode 100644 libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml create mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py create mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py create mode 100644 libs/langchain/langchain/document_loaders/sqlalchemy.py create mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv create mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql diff --git a/docs/docs/how_to/sqlalchemy.mdx b/docs/docs/how_to/sqlalchemy.mdx new file mode 100644 index 0000000000000..9f7e663db075e --- /dev/null +++ b/docs/docs/how_to/sqlalchemy.mdx @@ -0,0 +1,155 @@ +# SQLAlchemy + + +## About + +The [SQLAlchemy] document loader loads records from any supported database, +see [SQLAlchemy dialects] for all supported SQL databases and dialects. + +You can either use plain SQL for querying, or use an SQLAlchemy `Select` +statement object, if you are using SQLAlchemy-Core or -ORM. + +You can select which columns to place into the document, which columns +to place into its metadata, which columns to use as a `source` attribute +in metadata, and whether to include the result row number and/or the SQL +query expression into the metadata. + + +## Example + +This example uses PostgreSQL, and the `psycopg2` driver. + + +### Prerequisites + +```shell +psql postgresql://postgres@localhost/ --command "CREATE DATABASE testdrive;" +psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +``` + + +### Basic loading + +```python +from langchain.document_loaders import SQLAlchemyLoader +from pprint import pprint + + +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={})] +``` + + + + +## Enriching metadata + +Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to +optionally populate the `metadata` dictionary with corresponding information. + +Having the `query` within metadata is useful when using documents loaded from +database tables for chains that answer questions using their origin queries. 
+
+```python
+loader = SQLAlchemyLoader(
+    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+    include_rownum_into_metadata=True,
+    include_query_into_metadata=True,
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+```
+[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
+ Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
+ Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'row': 2, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'})]
+```
+
+
+## Customizing metadata
+
+Use the `page_content_columns` and `metadata_columns` options to select which
+columns are rendered into the document's `page_content`, and which ones are
+placed into its `metadata` dictionary. When `page_content_columns` is empty,
+all columns will be used for the page content.
+
+```python
+loader = SQLAlchemyLoader(
+    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
+    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
+    page_content_columns=["Payroll (millions)", "Wins"],
+    metadata_columns=["Team"],
+)
+docs = loader.load()
+```
+
+```python
+pprint(docs)
+```
+
+```
+[Document(page_content='Payroll (millions): 81.34\nWins: 98', metadata={'Team': 'Nationals'}),
+ Document(page_content='Payroll (millions): 82.2\nWins: 97', metadata={'Team': 'Reds'}),
+ Document(page_content='Payroll (millions): 197.96\nWins: 95', metadata={'Team': 'Yankees'})]
+```
+
+
+## Specify column(s) to identify the document source
+
+Use the `source_columns` option to specify the columns to use as a "source" for
+the document created from each row. This is useful for identifying documents
+through their metadata. Typically, you may use the primary key column(s) for
+that purpose.
+ +```python +loader = SQLAlchemyLoader( + query="SELECT * FROM mlb_teams_2012 LIMIT 3;", + url="postgresql+psycopg2://postgres@localhost:5432/testdrive", + source_columns="Team", +) +docs = loader.load() +``` + +```python +pprint(docs) +``` + + + +``` +[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'source': 'Nationals'}), + Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'source': 'Reds'}), + Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'source': 'Yankees'})] +``` + + + + +[SQLAlchemy]: https://www.sqlalchemy.org/ +[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/20/dialects/ diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb new file mode 100644 index 0000000000000..7bb141592296a --- /dev/null +++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQLAlchemy\n", + "\n", + "This notebook demonstrates how to load documents from an [SQLite] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/\n", + "[SQLite]: https://sqlite.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install langchain termsql" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Provide input data as SQLite database." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting example.csv\n" + ] + } + ], + "source": [ + "%%file example.csv\n", + "Team,Payroll\n", + "Nationals,81.34\n", + "Reds,82.20" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nationals|81.34\r\n", + "Reds|82.2\r\n" + ] + } + ], + "source": [ + "!termsql --infile=example.csv --head --delimiter=\",\" --outfile=example.sqlite --table=payroll" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader\n", + "from pprint import pprint\n", + "\n", + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + 
"metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals', metadata={'Payroll': 81.34}),\n", + " Document(page_content='Team: Reds', metadata={'Payroll': 82.2})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLAlchemyLoader(\n", + " \"SELECT * FROM payroll\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={'source': 'Nationals'}),\n", + " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={'source': 'Reds'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml b/libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml new file mode 100644 index 0000000000000..f8ab2cfdeb418 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml @@ -0,0 +1,19 @@ +version: "3" + +services: + postgresql: + image: postgres:16 + environment: + - POSTGRES_HOST_AUTH_METHOD=trust + ports: + - "5432:5432" + command: | + postgres -c log_statement=all + healthcheck: + test: + [ + "CMD-SHELL", + "psql postgresql://postgres@localhost --command 'SELECT 1;' || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py new file mode 100644 index 0000000000000..29f52cb9f7a33 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py @@ -0,0 +1,177 @@ +""" +Test SQLAlchemy/PostgreSQL document loader functionality. 
+ +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f postgresql.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import psycopg2 # noqa: F401 + + psycopg2_installed = True +except ImportError: + psycopg2_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_POSTGRESQL_CONNECTION_STRING", + "postgresql+psycopg2://postgres@localhost:5432/", +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. + """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_no_options() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_rownum_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_include_query_into_metadata() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def 
test_postgresql_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") +def test_postgresql_loader_real_data_with_selectable(provision_database: None) -> None: + """Test SQLAlchemy loader with psycopg2.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. + select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py new file mode 100644 index 0000000000000..f1fac2cecbc00 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py @@ -0,0 +1,181 @@ +""" +Test SQLAlchemy/SQLite document loader functionality. +""" +import logging +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse +from _pytest.tmpdir import TempPathFactory + +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + + +try: + import sqlite3 # noqa: F401 + + sqlite3_installed = True +except ImportError: + sqlite3_installed = False + + +@pytest.fixture(scope="module") +def db_uri(tmp_path_factory: TempPathFactory) -> str: + """ + Return an SQLAlchemy URI for a temporary SQLite database. + """ + db_path = tmp_path_factory.getbasetemp().joinpath("testdrive.sqlite") + return f"sqlite:///{db_path}" + + +@pytest.fixture(scope="module") +def engine(db_uri: str) -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(db_uri, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. 
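+
+    The SQL file contains multiple statements; they are split using `sqlparse`
+    and executed individually.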
+ """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.commit() + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_no_options(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=db_uri) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_rownum_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + include_rownum_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"row": 0} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_include_query_into_metadata(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", url=db_uri, include_query_into_metadata=True + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_page_content_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=db_uri, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_metadata_columns(db_uri: str) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + "SELECT 1 AS a, 2 AS b", + url=db_uri, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_sql( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + loader = SQLAlchemyLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=db_uri, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") +def test_sqlite_loader_real_data_with_selectable( + db_uri: str, provision_database: None +) -> None: + """Test SQLAlchemy loader with sqlite3.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. 
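+    # With `include_query_into_metadata=True`, the loader compiles the selectable
+    # into its SQL string and records it in the `query` slot of the metadata.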
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = SQLAlchemyLoader( + query=select, + url=db_uri, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index d5a6b9726ce3c..058993541ab5b 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -57,6 +57,7 @@ ConfluenceLoader, CoNLLULoader, CouchbaseLoader, + CrateDBLoader, CSVLoader, CubeSemanticLoader, DatadogLogsLoader, @@ -240,6 +241,7 @@ "ConcurrentLoader": "langchain_community.document_loaders", "ConfluenceLoader": "langchain_community.document_loaders", "CouchbaseLoader": "langchain_community.document_loaders", + "CrateDBLoader": "langchain_community.document_loaders", "CubeSemanticLoader": "langchain_community.document_loaders", "DataFrameLoader": "langchain_community.document_loaders", "DatadogLogsLoader": "langchain_community.document_loaders", @@ -421,6 +423,7 @@ def __getattr__(name: str) -> Any: "ConcurrentLoader", "ConfluenceLoader", "CouchbaseLoader", + "CrateDBLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -509,6 +512,7 @@ def __getattr__(name: str) -> Any: "SlackDirectoryLoader", "SnowflakeLoader", "SpreedlyLoader", + "SQLAlchemyLoader", "StripeLoader", "TelegramChatApiLoader", "TelegramChatFileLoader", diff --git a/libs/langchain/langchain/document_loaders/sqlalchemy.py b/libs/langchain/langchain/document_loaders/sqlalchemy.py new file mode 100644 index 0000000000000..787c9f339b686 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/sqlalchemy.py @@ -0,0 +1,112 @@ +from typing import Dict, List, Optional, Union + +import sqlalchemy as sa + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + + +class SQLAlchemyLoader(BaseLoader): + """ + Load documents by querying database tables supported by SQLAlchemy. + Each document represents one row of the result. + """ + + def __init__( + self, + query: Union[str, sa.Select], + url: str, + page_content_columns: Optional[List[str]] = None, + metadata_columns: Optional[List[str]] = None, + source_columns: Optional[List[str]] = None, + include_rownum_into_metadata: bool = False, + include_query_into_metadata: bool = False, + sqlalchemy_kwargs: Optional[Dict] = None, + ): + """ + + Args: + query: The query to execute. + url: The SQLAlchemy connection string of the database to connect to. + page_content_columns: The columns to write into the `page_content` + of the document. Optional. + metadata_columns: The columns to write into the `metadata` of the document. + Optional. + source_columns: The names of the columns to use as the `source` within the + metadata dictionary. Optional. + include_rownum_into_metadata: Whether to include the row number into the + metadata dictionary. Optional. Default: False. + include_query_into_metadata: Whether to include the query expression into + the metadata dictionary. Optional. Default: False. + sqlalchemy_kwargs: More keyword arguments for SQLAlchemy's `create_engine`. 
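+                For example, `{"echo": True}` would enable SQL statement logging.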
+ """ + self.query = query + self.url = url + self.page_content_columns = page_content_columns + self.metadata_columns = metadata_columns + self.source_columns = source_columns + self.include_rownum_into_metadata = include_rownum_into_metadata + self.include_query_into_metadata = include_query_into_metadata + self.sqlalchemy_kwargs = sqlalchemy_kwargs or {} + + def load(self) -> List[Document]: + try: + import sqlalchemy as sa + except ImportError: + raise ImportError( + "Could not import sqlalchemy python package. " + "Please install it with `pip install sqlalchemy`." + ) + + engine = sa.create_engine(self.url, **self.sqlalchemy_kwargs) + + docs = [] + with engine.connect() as conn: + if isinstance(self.query, sa.Select): + result = conn.execute(self.query) + query_sql = str(self.query.compile(bind=engine)) + elif isinstance(self.query, str): + result = conn.execute(sa.text(self.query)) + query_sql = self.query + else: + raise TypeError( + f"Unable to process query of unknown type: {self.query}" + ) + field_names = list(result.mappings().keys()) + + if self.page_content_columns is None: + page_content_columns = field_names + else: + page_content_columns = self.page_content_columns + + if self.metadata_columns is None: + metadata_columns = [] + else: + metadata_columns = self.metadata_columns + + for i, row in enumerate(result.mappings()): + page_content = "\n".join( + f"{column}: {value}" + for column, value in row.items() + if column in page_content_columns + ) + + metadata: Dict[str, Union[str, int]] = {} + if self.include_rownum_into_metadata: + metadata["row"] = i + if self.include_query_into_metadata: + metadata["query"] = query_sql + + source_values = [] + for column, value in row.items(): + if column in metadata_columns: + metadata[column] = value + if self.source_columns and column in self.source_columns: + source_values.append(value) + if source_values: + metadata["source"] = ",".join(source_values) + + doc = Document(page_content=page_content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/libs/langchain/tests/data.py b/libs/langchain/tests/data.py index b4f53baf356b4..c1206fb2ccbbe 100644 --- a/libs/langchain/tests/data.py +++ b/libs/langchain/tests/data.py @@ -10,3 +10,7 @@ HELLO_PDF = _EXAMPLES_DIR / "hello.pdf" LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf" DUPLICATE_CHARS = _EXAMPLES_DIR / "duplicate-chars.pdf" + +# Paths to data files +MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv" +MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql" diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv new file mode 100644 index 0000000000000..b22ae961a1331 --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv @@ -0,0 +1,32 @@ +"Team", "Payroll (millions)", "Wins" +"Nationals", 81.34, 98 +"Reds", 82.20, 97 +"Yankees", 197.96, 95 +"Giants", 117.62, 94 +"Braves", 83.31, 94 +"Athletics", 55.37, 94 +"Rangers", 120.51, 93 +"Orioles", 81.43, 93 +"Rays", 64.17, 90 +"Angels", 154.49, 89 +"Tigers", 132.30, 88 +"Cardinals", 110.30, 88 +"Dodgers", 95.14, 86 +"White Sox", 96.92, 85 +"Brewers", 97.65, 83 +"Phillies", 174.54, 81 +"Diamondbacks", 74.28, 81 +"Pirates", 63.43, 79 +"Padres", 55.24, 76 +"Mariners", 81.97, 75 +"Mets", 93.35, 74 +"Blue Jays", 75.48, 73 +"Royals", 60.91, 72 +"Marlins", 118.07, 69 +"Red Sox", 173.18, 69 +"Indians", 78.43, 68 +"Twins", 94.08, 66 +"Rockies", 78.06, 64 +"Cubs", 88.19, 61 +"Astros", 
60.65, 55 + diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql new file mode 100644 index 0000000000000..91029ddcd3563 --- /dev/null +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql @@ -0,0 +1,38 @@ +-- psql postgresql://postgres@localhost < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; From 473b66a936383c4dd3a1e5ea10e47dd53ca02eaa Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 16 Sep 2023 20:01:15 +0200 Subject: [PATCH 04/28] CrateDB loader: Add document loader support The implementation is based on the generic SQLAlchemy document loader. --- .../document_loaders/cratedb.ipynb | 232 ++++++++++++++++++ .../example_data/mlb_teams_2012.sql | 1 + .../document_loaders/sqlalchemy.ipynb | 2 +- docs/docs/integrations/providers/cratedb.mdx | 51 +++- .../docker-compose/cratedb.yml | 20 ++ .../test_sqlalchemy_cratedb.py | 146 +++++++++++ .../langchain/document_loaders/cratedb.py | 5 + .../examples/mlb_teams_2012.sql | 5 +- 8 files changed, 453 insertions(+), 9 deletions(-) create mode 100644 docs/docs/integrations/document_loaders/cratedb.ipynb create mode 100644 libs/community/tests/integration_tests/document_loaders/docker-compose/cratedb.yml create mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py create mode 100644 libs/langchain/langchain/document_loaders/cratedb.py diff --git a/docs/docs/integrations/document_loaders/cratedb.ipynb b/docs/docs/integrations/document_loaders/cratedb.ipynb new file mode 100644 index 0000000000000..78a0e19138703 --- /dev/null +++ b/docs/docs/integrations/document_loaders/cratedb.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CrateDB\n", + "\n", + "This notebook demonstrates how to load documents from a [CrateDB] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[CrateDB]: https://github.com/crate/crate\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install crash 'langchain[cratedb]'" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Populate database." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mPSQL OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[32mDELETE OK, 30 rows affected (0.008 sec)\r\n", + "\u001B[0m\u001B[32mINSERT OK, 30 rows affected (0.011 sec)\r\n", + "\u001B[0m\u001B[0m\u001B[32mCONNECT OK\r\n", + "\u001B[0m\u001B[32mREFRESH OK, 1 row affected (0.001 sec)\r\n", + "\u001B[0m\u001B[0m" + ] + } + ], + "source": [ + "!crash < ./example_data/mlb_teams_2012.sql\n", + "!crash --command \"REFRESH TABLE mlb_teams_2012;\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Usage" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import CrateDBLoader\n", + "from pprint import pprint\n", + "\n", + "CONNECTION_STRING = \"crate://crate@localhost/\"\n", + "\n", + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying Which Columns are Content vs Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " page_content_columns=[\"Team\"],\n", + " metadata_columns=[\"Payroll (millions)\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels', metadata={'Payroll (millions)': 154.49}),\n", + " Document(page_content='Team: Astros', metadata={'Payroll (millions)': 60.65}),\n", + " Document(page_content='Team: Athletics', metadata={'Payroll (millions)': 55.37}),\n", + " Document(page_content='Team: Blue Jays', metadata={'Payroll (millions)': 75.48}),\n", + " Document(page_content='Team: Braves', metadata={'Payroll (millions)': 83.31})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Source to Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "loader = CrateDBLoader(\n", + " 'SELECT * FROM mlb_teams_2012 ORDER BY \"Team\" LIMIT 5;',\n", + " url=CONNECTION_STRING,\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = 
loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Angels\\nPayroll (millions): 154.49\\nWins: 89', metadata={'source': 'Angels'}),\n", + " Document(page_content='Team: Astros\\nPayroll (millions): 60.65\\nWins: 55', metadata={'source': 'Astros'}),\n", + " Document(page_content='Team: Athletics\\nPayroll (millions): 55.37\\nWins: 94', metadata={'source': 'Athletics'}),\n", + " Document(page_content='Team: Blue Jays\\nPayroll (millions): 75.48\\nWins: 73', metadata={'source': 'Blue Jays'}),\n", + " Document(page_content='Team: Braves\\nPayroll (millions): 83.31\\nWins: 94', metadata={'source': 'Braves'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql index 33cb765a38ebe..6d94aeaa773b8 100644 --- a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql +++ b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql @@ -1,5 +1,6 @@ -- Provisioning table "mlb_teams_2012". -- +-- crash < mlb_teams_2012.sql -- psql postgresql://postgres@localhost < mlb_teams_2012.sql DROP TABLE IF EXISTS mlb_teams_2012; diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb index 7bb141592296a..5d603d7263c53 100644 --- a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb +++ b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb @@ -103,7 +103,7 @@ }, "outputs": [], "source": [ - "from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader\n", + "from langchain.document_loaders import SQLAlchemyLoader\n", "from pprint import pprint\n", "\n", "loader = SQLAlchemyLoader(\n", diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 2321a51912d09..948bdd85dee6a 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -38,7 +38,8 @@ data, and query it using SQL. ## Features -The CrateDB adapter supports the Vector Store subsystem of LangChain. +The CrateDB adapter supports the _Vector Store_ and _Document Loader_ +subsystems of LangChain. ### Vector Store @@ -51,6 +52,11 @@ Supports: - Approximate nearest neighbor search. - Euclidean distance. +### Document Loader + +`CrateDBLoader` provides loading documents from a database table by an SQL +query expression or an SQLAlchemy selectable instance. 
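+
+For example, a minimal sketch (the usage section below has the full walkthrough):
+
+```python
+from langchain.document_loaders import CrateDBLoader
+
+loader = CrateDBLoader(
+    'SELECT * FROM mlb_teams_2012 LIMIT 3;',
+    url="crate://crate@localhost/",
+)
+documents = loader.load()
+```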
+
+
 ## Installation and Setup
 
@@ -74,19 +80,19 @@ docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \
 ### Install Client
 
 ```bash
-pip install 'crate[sqlalchemy]' 'langchain[openai]'
+pip install 'crate[sqlalchemy]' 'langchain[openai]' 'crash'
 ```
 
 
-## Usage
+## Usage » Vector Store
 For a more detailed walkthrough of the `CrateDBVectorSearch` wrapper, there is also a corresponding [Jupyter notebook](/docs/extras/integrations/vectorstores/cratedb.html).
 
-### Acquire text file
+### Provide input data
 The example uses the canonical `state_of_the_union.txt`.
 ```shell
-wget https://raw.githubusercontent.com/langchain-ai/langchain/v0.0.291/docs/extras/modules/state_of_the_union.txt
+wget https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt
 ```
 
 ### Set environment variables
@@ -97,7 +103,7 @@ export CRATEDB_CONNECTION_STRING=crate://crate@localhost
 ```
 
 ```python
-from langchain.document_loaders import TextLoader
+from langchain.document_loaders import UnstructuredURLLoader
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import CrateDBVectorSearch
@@ -105,7 +111,7 @@ from langchain.vectorstores import CrateDBVectorSearch
 
 def main():
     # Load the document, split it into chunks, embed each chunk and load it into the vector store.
-    raw_documents = TextLoader("state_of_the_union.txt").load()
+    raw_documents = UnstructuredURLLoader(urls=["https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"]).load()
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     documents = text_splitter.split_documents(raw_documents)
     db = CrateDBVectorSearch.from_documents(documents, OpenAIEmbeddings())
@@ -120,6 +126,37 @@ if __name__ == "__main__":
 ```
 
 
+## Usage » Document Loader
+
+For a more detailed walkthrough of the `CrateDBLoader`, there is also a corresponding
+[Jupyter notebook](/docs/extras/integrations/document_loaders/cratedb.html).
+ + +### Provide input data +```shell +wget https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql +crash < ./example_data/mlb_teams_2012.sql +crash --command "REFRESH TABLE mlb_teams_2012;" +``` + +### Load documents by SQL query +```python +from langchain.document_loaders import CrateDBLoader +from pprint import pprint + +def main(): + loader = CrateDBLoader( + 'SELECT * FROM mlb_teams_2012 ORDER BY "Team" LIMIT 5;', + url="crate://crate@localhost/", + ) + documents = loader.load() + pprint(documents) + +if __name__ == "__main__": + main() +``` + + [CrateDB]: https://github.com/crate/crate [CrateDB Cloud]: https://crate.io/product [CrateDB Cloud Console]: https://console.cratedb.cloud/ diff --git a/libs/community/tests/integration_tests/document_loaders/docker-compose/cratedb.yml b/libs/community/tests/integration_tests/document_loaders/docker-compose/cratedb.yml new file mode 100644 index 0000000000000..b547b2c766f20 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/docker-compose/cratedb.yml @@ -0,0 +1,20 @@ +version: "3" + +services: + postgresql: + image: crate/crate:nightly + environment: + - CRATE_HEAP_SIZE=4g + ports: + - "4200:4200" + - "5432:5432" + command: | + crate -Cdiscovery.type=single-node + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:4200/ || exit 1", + ] + interval: 5s + retries: 60 diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py new file mode 100644 index 0000000000000..eec3a428a74e8 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py @@ -0,0 +1,146 @@ +""" +Test SQLAlchemy/CrateDB document loader functionality. + +cd tests/integration_tests/document_loaders/docker-compose +docker-compose -f cratedb.yml up +""" +import logging +import os +import unittest + +import pytest +import sqlalchemy as sa +import sqlparse + +from langchain.document_loaders import CrateDBLoader +from tests.data import MLB_TEAMS_2012_SQL + +logging.basicConfig(level=logging.DEBUG) + +try: + import crate.client.sqlalchemy # noqa: F401 + + crate_client_installed = True +except ImportError: + crate_client_installed = False + + +CONNECTION_STRING = os.environ.get( + "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" +) + + +@pytest.fixture +def engine() -> sa.Engine: + """ + Return an SQLAlchemy engine object. + """ + return sa.create_engine(CONNECTION_STRING, echo=False) + + +@pytest.fixture() +def provision_database(engine: sa.Engine) -> None: + """ + Provision database with table schema and data. 
+ """ + sql_statements = MLB_TEAMS_2012_SQL.read_text() + with engine.connect() as connection: + connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) + for statement in sqlparse.split(sql_statements): + connection.execute(sa.text(statement)) + connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;")) + connection.commit() + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_no_options() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1\nb: 2" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_page_content_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + ) + docs = loader.load() + + assert len(docs) == 2 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {} + + assert docs[1].page_content == "a: 3" + assert docs[1].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_metadata_columns() -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + "SELECT 1 AS a, 2 AS b", + url=CONNECTION_STRING, + page_content_columns=["a"], + metadata_columns=["b"], + ) + docs = loader.load() + + assert len(docs) == 1 + assert docs[0].page_content == "a: 1" + assert docs[0].metadata == {"b": 2} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_sql(provision_database: None) -> None: + """Test SQLAlchemy loader with CrateDB.""" + + loader = CrateDBLoader( + query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', + url=CONNECTION_STRING, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == {} + + +@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") +def test_cratedb_loader_real_data_with_selectable(provision_database: None) -> None: + """Test SQLAlchemy loader with CrateDB.""" + + # Define an SQLAlchemy table. + mlb_teams_2012 = sa.Table( + "mlb_teams_2012", + sa.MetaData(), + sa.Column("Team", sa.VARCHAR), + sa.Column("Payroll (millions)", sa.FLOAT), + sa.Column("Wins", sa.BIGINT), + ) + + # Query the database table using an SQLAlchemy selectable. 
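+    # Note: the `provision_database` fixture has already issued `REFRESH TABLE`,
+    # so this query observes all rows despite CrateDB's eventual consistency.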
+ select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) + loader = CrateDBLoader( + query=select, + url=CONNECTION_STRING, + include_query_into_metadata=True, + ) + docs = loader.load() + + assert len(docs) == 30 + assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" + assert docs[0].metadata == { + "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' + 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' + 'ORDER BY mlb_teams_2012."Team"' + } diff --git a/libs/langchain/langchain/document_loaders/cratedb.py b/libs/langchain/langchain/document_loaders/cratedb.py new file mode 100644 index 0000000000000..9e34b4d0cb9ec --- /dev/null +++ b/libs/langchain/langchain/document_loaders/cratedb.py @@ -0,0 +1,5 @@ +from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader + + +class CrateDBLoader(SQLAlchemyLoader): + pass diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql index 91029ddcd3563..9df72ef19954a 100644 --- a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +++ b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql @@ -1,4 +1,7 @@ --- psql postgresql://postgres@localhost < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +-- Provisioning table "mlb_teams_2012". +-- +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql +-- crash < mlb_teams_2012.sql DROP TABLE IF EXISTS mlb_teams_2012; CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); From b9015c906f6524356b287aaab8440cea5e53d4e8 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 17 Sep 2023 19:40:34 +0200 Subject: [PATCH 05/28] Community: Generalize `SQLChatMessageHistory` to improve code reusability --- .../chat_message_histories/sql.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index 2c3b2351c471d..14c0fe79db5d7 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -10,12 +10,13 @@ List, Optional, Sequence, + Type, Union, cast, ) from langchain_core._api import deprecated, warn_deprecated -from sqlalchemy import Column, Integer, Text, delete, select +from sqlalchemy import Column, Integer, Select, Text, create_engine, delete, select try: from sqlalchemy.orm import declarative_base @@ -27,7 +28,6 @@ message_to_dict, messages_from_dict, ) -from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -38,7 +38,6 @@ Session as SQLSession, ) from sqlalchemy.orm import ( - declarative_base, scoped_session, sessionmaker, ) @@ -55,6 +54,10 @@ class BaseMessageConverter(ABC): """Convert BaseMessage to the SQLAlchemy model.""" + @abstractmethod + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + @abstractmethod def from_sql_model(self, sql_message: Any) -> BaseMessage: """Convert a SQLAlchemy model to a BaseMessage instance.""" @@ -146,6 +149,8 @@ class SQLChatMessageHistory(BaseChatMessageHistory): """ + DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter + @property @deprecated("0.2.2", removal="1.0", alternative="session_maker") def Session(self) -> Union[scoped_session, 
async_sessionmaker]: @@ -220,7 +225,9 @@ def __init__( self.session_maker = scoped_session(sessionmaker(bind=self.engine)) self.session_id_field_name = session_id_field_name - self.converter = custom_message_converter or DefaultMessageConverter(table_name) + self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER( + table_name + ) self.sql_model_class = self.converter.get_sql_model_class() if not hasattr(self.sql_model_class, session_id_field_name): raise ValueError("SQL model class must have session_id column") @@ -241,6 +248,17 @@ async def _acreate_table_if_not_exists(self) -> None: await conn.run_sync(self.sql_model_class.metadata.create_all) self._table_created = True + def _messages_query(self) -> Select: + """Construct an SQLAlchemy selectable to query for messages""" + return ( + select(self.sql_model_class) + .where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id + ) + .order_by(self.sql_model_class.id.asc()) + ) + @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" From 8a0f3d6dd9ab6fc3e63e4f5e47473d4f7f6b5483 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 17 Sep 2023 19:44:57 +0200 Subject: [PATCH 06/28] CrateDB memory: Add conversational memory support The implementation is based on the generic `SQLChatMessageHistory`. --- .../memory/cratedb_chat_message_history.ipynb | 357 ++++++++++++++++++ docs/docs/integrations/providers/cratedb.mdx | 31 +- .../chat_message_histories/cratedb.py | 113 ++++++ .../memory/chat_message_histories/__init__.py | 3 + .../memory/chat_message_histories/cratedb.py | 113 ++++++ .../integration_tests/memory/test_cratedb.py | 170 +++++++++ 6 files changed, 785 insertions(+), 2 deletions(-) create mode 100644 docs/docs/integrations/memory/cratedb_chat_message_history.ipynb create mode 100644 libs/community/langchain_community/chat_message_histories/cratedb.py create mode 100644 libs/langchain/langchain/memory/chat_message_histories/cratedb.py create mode 100644 libs/langchain/tests/integration_tests/memory/test_cratedb.py diff --git a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb new file mode 100644 index 0000000000000..f51f5f1d63fca --- /dev/null +++ b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# CrateDB Chat Message History\n", + "\n", + "This notebook demonstrates how to use the `CrateDBChatMessageHistory`\n", + "to manage chat history in CrateDB, for supporting conversational memory." + ], + "metadata": { + "collapsed": false + }, + "id": "f22eab3f84cbeb37" + }, + { + "cell_type": "markdown", + "source": [ + "## Prerequisites" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!#pip install 'langchain[cratedb]'" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuration\n", + "\n", + "To use the storage wrapper, you will need to configure two details.\n", + "\n", + "1. Session Id - a unique identifier of the session, like user name, email, chat id etc.\n", + "2. Database connection string: An SQLAlchemy-compatible URI that specifies the database\n", + " connection. It will be passed to SQLAlchemy create_engine function." 
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "f8f2830ee9ca1e01"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "outputs": [],
+   "source": [
+    "from langchain.memory.chat_message_histories import CrateDBChatMessageHistory\n",
+    "\n",
+    "CONNECTION_STRING = \"crate://crate@localhost:4200/?schema=example\"\n",
+    "\n",
+    "chat_message_history = CrateDBChatMessageHistory(\n",
+    "\tsession_id=\"test_session\",\n",
+    "\tconnection_string=CONNECTION_STRING\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Basic Usage"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "outputs": [],
+   "source": [
+    "chat_message_history.add_user_message(\"Hello\")\n",
+    "chat_message_history.add_ai_message(\"Hi\")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-08-28T10:04:38.077748Z",
+     "start_time": "2023-08-28T10:04:36.105894Z"
+    }
+   },
+   "id": "4576e914a866fb40"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "outputs": [
+    {
+     "data": {
+      "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]"
+     },
+     "execution_count": 61,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_message_history.messages"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-08-28T10:04:38.929396Z",
+     "start_time": "2023-08-28T10:04:38.915727Z"
+    }
+   },
+   "id": "b476688cbb32ba90"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Custom Storage Model\n",
+    "\n",
+    "The default data model, which stores information about conversation messages,\n",
+    "only has two slots for storing message details: the session id, and the\n",
+    "message dictionary.\n",
+    "\n",
+    "If you want to store additional information, like message date, author, language etc.,\n",
+    "please provide an implementation for a custom message converter.\n",
+    "\n",
+    "This example demonstrates how to create a custom message converter, by implementing\n",
+    "the `BaseMessageConverter` interface."
+ ], + "metadata": { + "collapsed": false + }, + "id": "2e5337719d5614fd" + }, + { + "cell_type": "code", + "execution_count": 55, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from typing import Any\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n", + "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", + "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", + "\n", + "import sqlalchemy as sa\n", + "from sqlalchemy.orm import declarative_base\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "\n", + "class CustomMessage(Base):\n", + "\t__tablename__ = \"custom_message_store\"\n", + "\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tsession_id = sa.Column(sa.Text)\n", + "\ttype = sa.Column(sa.Text)\n", + "\tcontent = sa.Column(sa.Text)\n", + "\tcreated_at = sa.Column(sa.DateTime)\n", + "\tauthor_email = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverter(BaseMessageConverter):\n", + "\tdef __init__(self, author_email: str):\n", + "\t\tself.author_email = author_email\n", + "\t\n", + "\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n", + "\t\tif sql_message.type == \"human\":\n", + "\t\t\treturn HumanMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"ai\":\n", + "\t\t\treturn AIMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == \"system\":\n", + "\t\t\treturn SystemMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telse:\n", + "\t\t\traise ValueError(f\"Unknown message type: {sql_message.type}\")\n", + "\t\n", + "\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n", + "\t\tnow = datetime.now()\n", + "\t\treturn CustomMessage(\n", + "\t\t\tsession_id=session_id,\n", + "\t\t\ttype=message.type,\n", + "\t\t\tcontent=message.content,\n", + "\t\t\tcreated_at=now,\n", + "\t\t\tauthor_email=self.author_email\n", + "\t\t)\n", + "\t\n", + "\tdef get_sql_model_class(self) -> Any:\n", + "\t\treturn CustomMessage\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverter(\n", + "\t\t\tauthor_email=\"test@example.com\"\n", + "\t\t)\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:41.510498Z", + "start_time": "2023-08-28T10:04:41.494912Z" + } + }, + "id": "fdfde84c07d071bb" + }, + { + "cell_type": "code", + "execution_count": 60, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:43.497990Z", + "start_time": "2023-08-28T10:04:43.492517Z" + } + }, + "id": "4a6a54d8a9e2856f" + }, + { + 
"cell_type": "markdown", + "source": [ + "## Custom Name for Session Column\n", + "\n", + "The session id, a unique token identifying the session, is an important property of\n", + "this subsystem. If your database table stores it in a different column, you can use\n", + "the `session_id_field_name` keyword argument to adjust the name correspondingly." + ], + "metadata": { + "collapsed": false + }, + "id": "622aded629a1adeb" + }, + { + "cell_type": "code", + "execution_count": 57, + "outputs": [], + "source": [ + "import json\n", + "import typing as t\n", + "\n", + "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n", + "from langchain.schema import _message_to_dict\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "class MessageWithDifferentSessionIdColumn(Base):\n", + "\t__tablename__ = \"message_store_different_session_id\"\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tcustom_session_id = sa.Column(sa.Text)\n", + "\tmessage = sa.Column(sa.Text)\n", + "\n", + "\n", + "class CustomMessageConverterWithDifferentSessionIdColumn(CrateDBMessageConverter):\n", + " def __init__(self):\n", + " self.model_class = MessageWithDifferentSessionIdColumn\n", + "\n", + " def to_sql_model(self, message: BaseMessage, custom_session_id: str) -> t.Any:\n", + " return self.model_class(\n", + " custom_session_id=custom_session_id, message=json.dumps(_message_to_dict(message))\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\tBase.metadata.drop_all(bind=sa.create_engine(CONNECTION_STRING))\n", + "\n", + "\tchat_message_history = CrateDBChatMessageHistory(\n", + "\t\tsession_id=\"test_session\",\n", + "\t\tconnection_string=CONNECTION_STRING,\n", + "\t\tcustom_message_converter=CustomMessageConverterWithDifferentSessionIdColumn(),\n", + "\t\tsession_id_field_name=\"custom_session_id\",\n", + "\t)\n", + "\n", + "\tchat_message_history.add_user_message(\"Hello\")\n", + "\tchat_message_history.add_ai_message(\"Hi\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 58, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 948bdd85dee6a..1327719ce2b28 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -38,8 +38,8 @@ data, and query it using SQL. ## Features -The CrateDB adapter supports the _Vector Store_ and _Document Loader_ -subsystems of LangChain. +The CrateDB adapter supports the _Vector Store_, _Document Loader_, +and _Conversational Memory_ subsystems of LangChain. 
 ### Vector Store
 
@@ -57,6 +57,10 @@ Supports:
 
 `CrateDBLoader` provides loading documents from a database table by an SQL
 query expression or an SQLAlchemy selectable instance.
 
+### Conversational Memory
+
+`CrateDBChatMessageHistory` uses CrateDB to manage conversation history.
+
 
 ## Installation and Setup
 
@@ -157,6 +161,29 @@ if __name__ == "__main__":
 ```
 
 
+## Usage » Conversational Memory
+
+For a more detailed walkthrough of `CrateDBChatMessageHistory`, there is also
+a corresponding [Jupyter notebook](/docs/extras/integrations/memory/cratedb_chat_message_history.html).
+
+```python
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from pprint import pprint
+
+def main():
+    chat_message_history = CrateDBChatMessageHistory(
+        session_id="test_session",
+        connection_string="crate://crate@localhost/",
+    )
+    chat_message_history.add_user_message("Hello")
+    chat_message_history.add_ai_message("Hi")
+    pprint(chat_message_history.messages)
+
+if __name__ == "__main__":
+    main()
+```
+
+
 [CrateDB]: https://github.com/crate/crate
 [CrateDB Cloud]: https://crate.io/product
 [CrateDB Cloud Console]: https://console.cratedb.cloud/
diff --git a/libs/community/langchain_community/chat_message_histories/cratedb.py b/libs/community/langchain_community/chat_message_histories/cratedb.py
new file mode 100644
index 0000000000000..45e287ec1f344
--- /dev/null
+++ b/libs/community/langchain_community/chat_message_histories/cratedb.py
@@ -0,0 +1,113 @@
+import json
+import typing as t
+
+import sqlalchemy as sa
+from cratedb_toolkit.sqlalchemy import (
+    patch_inspector,
+    polyfill_refresh_after_dml,
+    refresh_table,
+)
+from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict
+
+from langchain_community.chat_message_histories.sql import (
+    BaseMessageConverter,
+    SQLChatMessageHistory,
+)
+
+
+def create_message_model(table_name, DynamicBase):  # type: ignore
+    """
+    Create a message model for a given table name.
+
+    This is a specialized version for CrateDB, generating integer-based
+    primary keys.
+    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
+          returning its integer value.
+
+    Args:
+        table_name: The name of the table to use.
+        DynamicBase: The base class to use for the model.
+
+    Returns:
+        The model class.
+    """
+
+    # Model is declared inside a function to be able to use a dynamic table name.
+    class Message(DynamicBase):
+        __tablename__ = table_name
+        id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
+        session_id = sa.Column(sa.Text)
+        message = sa.Column(sa.Text)
+
+    return Message
+
+
+class CrateDBMessageConverter(BaseMessageConverter):
+    """
+    The default message converter for `CrateDBChatMessageHistory`.
+
+    It is the same as the generic `SQLChatMessageHistory` converter,
+    but swaps in a different `create_message_model` function.
+ """ + + def __init__(self, table_name: str): + self.model_class = create_message_model(table_name, sa.orm.declarative_base()) + + def from_sql_model(self, sql_message: t.Any) -> BaseMessage: + return messages_from_dict([json.loads(sql_message.message)])[0] + + def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any: + return self.model_class( + session_id=session_id, message=json.dumps(_message_to_dict(message)) + ) + + def get_sql_model_class(self) -> t.Any: + return self.model_class + + +class CrateDBChatMessageHistory(SQLChatMessageHistory): + """ + It is the same as the generic `SQLChatMessageHistory` implementation, + but swaps in a different message converter by default. + """ + + DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter + + def __init__( + self, + session_id: str, + connection_string: str, + table_name: str = "message_store", + session_id_field_name: str = "session_id", + custom_message_converter: t.Optional[BaseMessageConverter] = None, + ): + # FIXME: Refactor elsewhere. + patch_inspector() + + super().__init__( + session_id, + connection_string, + table_name=table_name, + session_id_field_name=session_id_field_name, + custom_message_converter=custom_message_converter, + ) + + # TODO: Check how this can be improved. + polyfill_refresh_after_dml(self.Session) + + def _messages_query(self) -> sa.Select: + """ + Construct an SQLAlchemy selectable to query for messages. + For CrateDB, add an `ORDER BY` clause on the primary key. + """ + selectable = super()._messages_query() + selectable = selectable.order_by(self.sql_model_class.id) + return selectable + + def clear(self) -> None: + """ + Needed for CrateDB to synchronize data because `on_flush` does not catch it. + """ + outcome = super().clear() + with self.Session() as session: + refresh_table(session, self.sql_model_class) + return outcome diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index 91910137f627f..ff95fd21563d0 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -8,6 +8,7 @@ CassandraChatMessageHistory, ChatMessageHistory, CosmosDBChatMessageHistory, + CrateDBChatMessageHistory, DynamoDBChatMessageHistory, ElasticsearchChatMessageHistory, FileChatMessageHistory, @@ -34,6 +35,7 @@ "CassandraChatMessageHistory": "langchain_community.chat_message_histories", "ChatMessageHistory": "langchain_community.chat_message_histories", "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories", + "CrateDBChatMessageHistory": "langchain_community.chat_message_histories", "DynamoDBChatMessageHistory": "langchain_community.chat_message_histories", "ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories", "FileChatMessageHistory": "langchain_community.chat_message_histories", @@ -65,6 +67,7 @@ def __getattr__(name: str) -> Any: "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py new file mode 100644 index 0000000000000..19007176cb193 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py @@ -0,0 +1,113 @@ +import 
json
+import typing as t
+
+import sqlalchemy as sa
+from cratedb_toolkit.sqlalchemy import (
+    patch_inspector,
+    polyfill_refresh_after_dml,
+    refresh_table,
+)
+
+from langchain.memory.chat_message_histories.sql import (
+    BaseMessageConverter,
+    SQLChatMessageHistory,
+)
+from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict
+
+
+def create_message_model(table_name, DynamicBase):  # type: ignore
+    """
+    Create a message model for a given table name.
+
+    This is a specialized version for CrateDB, generating integer-based
+    primary keys.
+    TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant
+          returning its integer value.
+
+    Args:
+        table_name: The name of the table to use.
+        DynamicBase: The base class to use for the model.
+
+    Returns:
+        The model class.
+    """
+
+    # Model is declared inside a function to be able to use a dynamic table name.
+    class Message(DynamicBase):
+        __tablename__ = table_name
+        id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())
+        session_id = sa.Column(sa.Text)
+        message = sa.Column(sa.Text)
+
+    return Message
+
+
+class CrateDBMessageConverter(BaseMessageConverter):
+    """
+    The default message converter for `CrateDBChatMessageHistory`.
+
+    It is the same as the generic `SQLChatMessageHistory` converter,
+    but swaps in a different `create_message_model` function.
+    """
+
+    def __init__(self, table_name: str):
+        self.model_class = create_message_model(table_name, sa.orm.declarative_base())
+
+    def from_sql_model(self, sql_message: t.Any) -> BaseMessage:
+        return messages_from_dict([json.loads(sql_message.message)])[0]
+
+    def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any:
+        return self.model_class(
+            session_id=session_id, message=json.dumps(_message_to_dict(message))
+        )
+
+    def get_sql_model_class(self) -> t.Any:
+        return self.model_class
+
+
+class CrateDBChatMessageHistory(SQLChatMessageHistory):
+    """
+    It is the same as the generic `SQLChatMessageHistory` implementation,
+    but swaps in a different message converter by default.
+    """
+
+    DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter
+
+    def __init__(
+        self,
+        session_id: str,
+        connection_string: str,
+        table_name: str = "message_store",
+        session_id_field_name: str = "session_id",
+        custom_message_converter: t.Optional[BaseMessageConverter] = None,
+    ):
+        # FIXME: Refactor elsewhere.
+        patch_inspector()
+
+        super().__init__(
+            session_id,
+            connection_string,
+            table_name=table_name,
+            session_id_field_name=session_id_field_name,
+            custom_message_converter=custom_message_converter,
+        )
+
+        # TODO: Check how this can be improved.
+        polyfill_refresh_after_dml(self.Session)
+
+    def _messages_query(self) -> sa.Select:
+        """
+        Construct an SQLAlchemy selectable to query for messages.
+        For CrateDB, add an `ORDER BY` clause on the primary key.
+        """
+        selectable = super()._messages_query()
+        selectable = selectable.order_by(self.sql_model_class.id)
+        return selectable
+
+    def clear(self) -> None:
+        """
+        Needed for CrateDB to synchronize data because `on_flush` does not catch it.
+        """
+        outcome = super().clear()
+        with self.Session() as session:
+            refresh_table(session, self.sql_model_class)
+        return outcome
diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
new file mode 100644
index 0000000000000..2c00b5d2b200b
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/memory/test_cratedb.py
@@ -0,0 +1,170 @@
+import json
+import os
+from typing import Any, Generator, Tuple
+
+import pytest
+import sqlalchemy as sa
+from sqlalchemy import Column, Integer, Text
+from sqlalchemy.orm import DeclarativeBase
+
+from langchain.memory import ConversationBufferMemory
+from langchain.memory.chat_message_histories import CrateDBChatMessageHistory
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
+from langchain.schema.messages import AIMessage, HumanMessage, _message_to_dict
+
+
+@pytest.fixture()
+def connection_string() -> str:
+    return os.environ.get(
+        "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive"
+    )
+
+
+@pytest.fixture()
+def engine(connection_string: str) -> sa.Engine:
+    """
+    Return an SQLAlchemy engine object.
+    """
+    return sa.create_engine(connection_string, echo=True)
+
+
+@pytest.fixture(autouse=True)
+def reset_database(engine: sa.Engine) -> None:
+    """
+    Reset the database by dropping the message store table, if it exists.
+    """
+    with engine.connect() as connection:
+        connection.execute(sa.text("DROP TABLE IF EXISTS test_table;"))
+        connection.commit()
+
+
+@pytest.fixture()
+def sql_histories(
+    connection_string: str,
+) -> Generator[Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], None, None]:
+    """
+    Provide the test cases with data fixtures.
+    """
+    message_history = CrateDBChatMessageHistory(
+        session_id="123", connection_string=connection_string, table_name="test_table"
+    )
+    # Create history for other session
+    other_history = CrateDBChatMessageHistory(
+        session_id="456", connection_string=connection_string, table_name="test_table"
+    )
+
+    yield message_history, other_history
+    message_history.clear()
+    other_history.clear()
+
+
+def test_add_messages(
+    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    history1, _ = sql_histories
+    history1.add_user_message("Hello!")
+    history1.add_ai_message("Hi there!")
+
+    messages = history1.messages
+    assert len(messages) == 2
+    assert isinstance(messages[0], HumanMessage)
+    assert isinstance(messages[1], AIMessage)
+    assert messages[0].content == "Hello!"
+    assert messages[1].content == "Hi there!"
+
+
+def test_multiple_sessions(
+    sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory]
+) -> None:
+    history1, history2 = sql_histories
+
+    # First session.
+    history1.add_user_message("Hello!")
+    history1.add_ai_message("Hi there!")
+    history1.add_user_message("Whats cracking?")
+
+    # Second session.
+    history2.add_user_message("Hellox")
+
+    messages1 = history1.messages
+    messages2 = history2.messages
+
+    # Ensure the messages have been added correctly in the first session.
+    assert len(messages1) == 3, "Expected three messages in the first session"
+    assert messages1[0].content == "Hello!"
+    assert messages1[1].content == "Hi there!"
+    assert messages1[2].content == "Whats cracking?"
+
+    # Ensure the second session has been kept separate from the first one.
+    assert len(messages2) == 1
+    assert messages2[0].content == "Hellox"
+ + +def test_clear_messages( + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + assert len(sql_history.messages) == 2 + # Now create another history with different session id + other_history.add_user_message("Hellox") + assert len(other_history.messages) == 1 + assert len(sql_history.messages) == 2 + # Now clear the first history + sql_history.clear() + assert len(sql_history.messages) == 0 + assert len(other_history.messages) == 1 + + +def test_model_no_session_id_field_error(connection_string: str) -> None: + class Base(DeclarativeBase): + pass + + class Model(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + test_field = Column(Text) + + class CustomMessageConverter(DefaultMessageConverter): + def get_sql_model_class(self) -> Any: + return Model + + with pytest.raises(ValueError): + CrateDBChatMessageHistory( + "test", + connection_string, + custom_message_converter=CustomMessageConverter("test_table"), + ) + + +def test_memory_with_message_store(connection_string: str) -> None: + """ + Test ConversationBufferMemory with a message store. + """ + # Setup CrateDB as a message store. + message_history = CrateDBChatMessageHistory( + connection_string=connection_string, session_id="test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # Add a few messages. + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # Get the message history from the memory store and turn it into JSON. + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + # Verify the outcome. + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Clear the conversation history, and verify that. + memory.chat_memory.clear() + assert memory.chat_memory.messages == [] From 3330b0d781871ee48acebc9a2f8206c23c3ef741 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 27 Oct 2023 16:46:24 +0200 Subject: [PATCH 07/28] CrateDB vector: Fix usage when only reading, and not storing When not adding any embeddings upfront, the runtime model factory was not able to derive the vector dimension size, because the SQLAlchemy models have not been initialized correctly. 
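
For illustration, a minimal sketch of the retrieval-only scenario this
patch enables (connection string, collection name, and embedding object
are placeholders, not part of this changeset):

    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores.cratedb import CrateDBVectorSearch

    # The embedding function must report a fixed dimensionality, so the
    # SQLAlchemy models can be initialized lazily at query time.
    store = CrateDBVectorSearch(
        collection_name="state_of_the_union_test",
        connection_string="crate://crate@localhost/",
        embedding_function=OpenAIEmbeddings(),
    )

    # No documents were added through this instance; the vector dimension
    # size is derived from the embedding of the query string instead.
    docs = store.similarity_search("What did the president say?", k=4)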
--- .../vectorstores/test_cratedb.py | 95 ++++++++++++++++++- .../langchain/vectorstores/cratedb/base.py | 23 ++++- 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index d62f0a125f661..8f62919842fc0 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -14,7 +14,10 @@ from langchain.docstore.document import Document from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch -from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.integration_tests.vectorstores.fake_embeddings import ( + ConsistentFakeEmbeddings, + FakeEmbeddings, +) CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), @@ -40,6 +43,13 @@ def engine() -> sa.Engine: return sa.create_engine(CONNECTION_STRING, echo=False) +@pytest.fixture +def session(engine) -> sa.orm.Session: + with engine.connect() as conn: + with Session(conn) as session: + yield session + + @pytest.fixture(autouse=True) def drop_tables(engine: sa.Engine) -> None: """ @@ -89,6 +99,46 @@ def decode_output( return documents, scores +def ensure_collection(session: sa.orm.Session, name: str): + """ + Create a (fake) collection item. + """ + session.execute( + sa.text( + f""" + CREATE TABLE IF NOT EXISTS collection ( + uuid TEXT, + name TEXT, + cmetadata OBJECT + ); + """ + ) + ) + session.execute( + sa.text( + f""" + CREATE TABLE IF NOT EXISTS embedding ( + uuid TEXT, + collection_id TEXT, + embedding FLOAT_VECTOR(123), + document TEXT, + cmetadata OBJECT, + custom_id TEXT + ); + """ + ) + ) + try: + session.execute( + sa.text( + f"INSERT INTO collection (uuid, name, cmetadata) VALUES ('uuid-{name}', '{name}', {{}});" + ) + ) + session.execute(sa.text("REFRESH TABLE collection")) + except sa.exc.IntegrityError: + pass + + class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): """Fake embeddings functionality for testing.""" @@ -103,6 +153,19 @@ def embed_query(self, text: str) -> List[float]: return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] +class ConsistentFakeEmbeddingsWithAdaDimension( + FakeEmbeddingsWithAdaDimension, ConsistentFakeEmbeddings +): + """ + Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts. + + Other than this, they also have a dimensionality, which is important in this case. + """ + + pass + + def test_cratedb_texts() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -274,6 +337,36 @@ def test_cratedb_collection_no_embedding_dimension() -> None: ) +def test_cratedb_collection_read_only(session) -> None: + """ + Test using a collection, without adding any embeddings upfront. + + This happens when just invoking the "retrieval" case. + + In this scenario, embedding dimensionality needs to be figured out + from the supplied `embedding_function`. + """ + + # Create a fake collection item. + ensure_collection(session, "baz2") + + # This test case needs an embedding _with_ dimensionality. + # Otherwise, the data access layer is unable to figure it + # out at runtime. 
+ embedding = ConsistentFakeEmbeddingsWithAdaDimension() + + vectorstore = CrateDBVectorSearch( + collection_name="baz2", + connection_string=CONNECTION_STRING, + embedding_function=embedding, + ) + output = vectorstore.similarity_search("foo", k=1) + + # No documents/embeddings have been loaded, the collection is empty. + # This is why there are also no results. + assert output == [] + + def test_cratedb_with_filter_in_set() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index ca1c21fad68a7..b6158dd0d4769 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -112,9 +112,26 @@ def __post_init__( # TODO: See what can be improved here. polyfill_refresh_after_dml(self.Session) + # Need to defer initialization, because dimension size + # can only be figured out at runtime. self.CollectionStore = None self.EmbeddingStore = None + def _init_models(self, embedding: List[float]): + """ + Create SQLAlchemy models at runtime, when not established yet. + """ + if self.CollectionStore is not None and self.EmbeddingStore is not None: + return + + size = len(embedding) + self._init_models_with_dimensionality(size=size) + + def _init_models_with_dimensionality(self, size: int): + from langchain.vectorstores.cratedb.model import model_factory + + self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size) + def get_collection( self, session: sqlalchemy.orm.Session ) -> Optional["CollectionStore"]: @@ -140,10 +157,7 @@ def add_embeddings( metadatas: List of metadatas associated with the texts. kwargs: vectorstore specific parameters """ - from langchain.vectorstores.cratedb.model import model_factory - - dimensions = len(embeddings[0]) - self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=dimensions) + self._init_models(embeddings[0]) if self.pre_delete_collection: self.delete_collection() self.create_tables_if_not_exists() @@ -223,6 +237,7 @@ def _query_collection( filter: Optional[Dict[str, str]] = None, ) -> List[Any]: """Query the collection.""" + self._init_models(embedding) with self.Session() as session: collection = self.get_collection(session) if not collection: From 38c23744f4cc937a0d3387bc382a9914391be97e Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 27 Oct 2023 22:16:51 +0200 Subject: [PATCH 08/28] CrateDB vector: Unable to invoke `add_embeddings` without embeddings --- libs/langchain/langchain/vectorstores/cratedb/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index b6158dd0d4769..e7c651bea9822 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -157,6 +157,8 @@ def add_embeddings( metadatas: List of metadatas associated with the texts. kwargs: vectorstore specific parameters """ + if not embeddings: + return [] self._init_models(embeddings[0]) if self.pre_delete_collection: self.delete_collection() From 0f6adf92ce69c1b57a2065e4be94ef51c0ecdcd3 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 20 Nov 2023 21:34:09 +0100 Subject: [PATCH 09/28] CrateDB vector: Improve SQLAlchemy model factory From now on, _all_ instances of SQLAlchemy model types will be created at runtime through the `ModelFactory` utility. 
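
A sketch of the resulting usage pattern, mirroring how the test suite
exercises the factory (an existing SQLAlchemy `engine` and `session` are
assumed, and the dimension value is just an example):

    from langchain.vectorstores.cratedb.model import ModelFactory

    # Model classes are produced per factory instance, parameterized by
    # the embedding dimensionality, which is only known at runtime.
    mf = ModelFactory(dimensions=1024)

    # All model classes hang off the factory instance.
    mf.BaseModel.metadata.drop_all(engine, checkfirst=False)
    session.query(mf.CollectionStore).delete()
    session.query(mf.EmbeddingStore).delete()
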
By using `__table_args__ = {"keep_existing": True}` on the ORM entity definitions, this seems to work well, even with multiple invocations of `CrateDBVectorSearch.from_texts()` using different `collection_name` argument values. While being at it, this patch also fixes a few linter errors. --- .../vectorstores/test_cratedb.py | 37 ++-- .../vectorstores/cratedb/__init__.py | 3 +- .../langchain/vectorstores/cratedb/base.py | 68 +++---- .../langchain/vectorstores/cratedb/model.py | 168 ++++++++++-------- 4 files changed, 154 insertions(+), 122 deletions(-) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index 8f62919842fc0..8f054fc07a0b3 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -5,7 +5,7 @@ docker-compose -f cratedb.yml up """ import os -from typing import List, Tuple +from typing import Generator, List, Tuple import pytest import sqlalchemy as sa @@ -13,7 +13,8 @@ from sqlalchemy.orm import Session from langchain.docstore.document import Document -from langchain.vectorstores.cratedb import BaseModel, CrateDBVectorSearch +from langchain.vectorstores.cratedb import CrateDBVectorSearch +from langchain.vectorstores.cratedb.model import ModelFactory from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, @@ -44,7 +45,7 @@ def engine() -> sa.Engine: @pytest.fixture -def session(engine) -> sa.orm.Session: +def session(engine: sa.Engine) -> Generator[sa.orm.Session, None, None]: with engine.connect() as conn: with Session(conn) as session: yield session @@ -56,7 +57,8 @@ def drop_tables(engine: sa.Engine) -> None: Drop database tables. """ try: - BaseModel.metadata.drop_all(engine, checkfirst=False) + mf = ModelFactory() + mf.BaseModel.metadata.drop_all(engine, checkfirst=False) except Exception as ex: if "RelationUnknown" not in str(ex): raise @@ -69,18 +71,13 @@ def prune_tables(engine: sa.Engine) -> None: """ with engine.connect() as conn: with Session(conn) as session: - from langchain.vectorstores.cratedb.model import model_factory - - # While it does not have any function here, you will still need to supply a - # dummy dimension size value for deleting records from tables. - CollectionStore, EmbeddingStore = model_factory(dimensions=1024) - + mf = ModelFactory() try: - session.query(CollectionStore).delete() + session.query(mf.CollectionStore).delete() except ProgrammingError: pass try: - session.query(EmbeddingStore).delete() + session.query(mf.EmbeddingStore).delete() except ProgrammingError: pass @@ -99,13 +96,13 @@ def decode_output( return documents, scores -def ensure_collection(session: sa.orm.Session, name: str): +def ensure_collection(session: sa.orm.Session, name: str) -> None: """ Create a (fake) collection item. 
""" session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS collection ( uuid TEXT, name TEXT, @@ -116,7 +113,7 @@ def ensure_collection(session: sa.orm.Session, name: str): ) session.execute( sa.text( - f""" + """ CREATE TABLE IF NOT EXISTS embedding ( uuid TEXT, collection_id TEXT, @@ -131,7 +128,8 @@ def ensure_collection(session: sa.orm.Session, name: str): try: session.execute( sa.text( - f"INSERT INTO collection (uuid, name, cmetadata) VALUES ('uuid-{name}', '{name}', {{}});" + f"INSERT INTO collection (uuid, name, cmetadata) " + f"VALUES ('uuid-{name}', '{name}', {{}});" ) ) session.execute(sa.text("REFRESH TABLE collection")) @@ -325,7 +323,7 @@ def test_cratedb_collection_with_metadata() -> None: def test_cratedb_collection_no_embedding_dimension() -> None: """Test end to end collection construction""" cratedb_vector = CrateDBVectorSearch( - embedding_function=None, + embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, pre_delete_collection=True, ) @@ -333,11 +331,12 @@ def test_cratedb_collection_no_embedding_dimension() -> None: with pytest.raises(RuntimeError) as ex: cratedb_vector.get_collection(session) assert ex.match( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) -def test_cratedb_collection_read_only(session) -> None: +def test_cratedb_collection_read_only(session: Session) -> None: """ Test using a collection, without adding any embeddings upfront. diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py index 303a52babeaea..14b02ad126867 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/__init__.py +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -1,6 +1,5 @@ -from .base import BaseModel, CrateDBVectorSearch +from .base import CrateDBVectorSearch __all__ = [ - "BaseModel", "CrateDBVectorSearch", ] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index e7c651bea9822..ec3c4c19d70a6 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -2,7 +2,6 @@ import enum import math -import uuid from typing import ( Any, Callable, @@ -20,11 +19,12 @@ polyfill_refresh_after_dml, refresh_table, ) -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import sessionmaker from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.cratedb.model import ModelFactory from langchain.vectorstores.pgvector import PGVector @@ -38,23 +38,10 @@ class DistanceStrategy(str, enum.Enum): DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN -Base = declarative_base() # type: Any -# Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" -def generate_uuid() -> str: - return str(uuid.uuid4()) - - -class BaseModel(Base): - """Base model for the SQL stores.""" - - __abstract__ = True - uuid = sqlalchemy.Column(sqlalchemy.String, primary_key=True, default=generate_uuid) - - def _results_to_docs(docs_and_scores: Any) -> List[Document]: """Return docs from docs and scores.""" return [doc for doc, _ in docs_and_scores] @@ 
-114,30 +101,47 @@ def __post_init__( # Need to defer initialization, because dimension size # can only be figured out at runtime. - self.CollectionStore = None - self.EmbeddingStore = None + self.BaseModel = None + self.CollectionStore = None # type: ignore[assignment] + self.EmbeddingStore = None # type: ignore[assignment] + + def __del__(self) -> None: + """ + Work around premature session close. + + sqlalchemy.orm.exc.DetachedInstanceError: Parent instance is not bound + to a Session; lazy load operation of attribute 'embeddings' cannot proceed. + -- https://docs.sqlalchemy.org/en/20/errors.html#error-bhk3 + + TODO: Review! + """ # noqa: E501 + pass - def _init_models(self, embedding: List[float]): + def _init_models(self, embedding: List[float]) -> None: """ Create SQLAlchemy models at runtime, when not established yet. """ + + # TODO: Use a better way to run this only once. if self.CollectionStore is not None and self.EmbeddingStore is not None: return size = len(embedding) self._init_models_with_dimensionality(size=size) - def _init_models_with_dimensionality(self, size: int): - from langchain.vectorstores.cratedb.model import model_factory - - self.CollectionStore, self.EmbeddingStore = model_factory(dimensions=size) + def _init_models_with_dimensionality(self, size: int) -> None: + mf = ModelFactory(dimensions=size) + self.BaseModel, self.CollectionStore, self.EmbeddingStore = ( + mf.BaseModel, # type: ignore[assignment] + mf.CollectionStore, + mf.EmbeddingStore, + ) - def get_collection( - self, session: sqlalchemy.orm.Session - ) -> Optional["CollectionStore"]: + def get_collection(self, session: sqlalchemy.orm.Session) -> Any: if self.CollectionStore is None: raise RuntimeError( - "Collection can't be accessed without specifying dimension size of embedding vectors" + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" ) return self.CollectionStore.get_by_name(session, self.collection_name) @@ -170,15 +174,17 @@ def add_embeddings( def create_tables_if_not_exists(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. """ - Base.metadata.create_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.create_all(self._engine) def drop_tables(self) -> None: """ - Need to overwrite because `Base` is different from upstream. + Need to overwrite because this `Base` is different from parent's `Base`. 
""" - Base.metadata.drop_all(self._engine) + mf = ModelFactory() + mf.Base.metadata.drop_all(self._engine) def delete( self, diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index ee42e7269dc9d..b8b14c05010f5 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -1,84 +1,112 @@ -from functools import lru_cache -from typing import Optional, Tuple +import uuid +from typing import Any, Optional, Tuple import sqlalchemy from crate.client.sqlalchemy.types import ObjectType -from sqlalchemy.orm import Session, relationship +from sqlalchemy.orm import Session, declarative_base, relationship -from langchain.vectorstores.cratedb.base import BaseModel from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector -@lru_cache -def model_factory(dimensions: int): - class CollectionStore(BaseModel): - """Collection store.""" - - __tablename__ = "collection" - - name = sqlalchemy.Column(sqlalchemy.String) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name( - cls, session: Session, name: str - ) -> Optional["CollectionStore"]: - try: - return ( - session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 - ) - except sqlalchemy.exc.ProgrammingError as ex: - if "RelationUnknown" not in str(ex): - raise - return None - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """ - Get or create a collection. - Returns [Collection, bool] where the bool is True if the collection was created. - """ - created = False - collection = cls.get_by_name(session, name) - if collection: +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class ModelFactory: + """Provide SQLAlchemy model objects at runtime.""" + + def __init__(self, dimensions: Optional[int] = None): + # While it does not have any function here, you will still need to supply a + # dummy dimension size value for operations like deleting records. + self.dimensions = dimensions or 1024 + + Base: Any = declarative_base() + + # Optional: Use a custom schema for the langchain tables. + # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any + + class BaseModel(Base): + """Base model for the SQL stores.""" + + __abstract__ = True + uuid = sqlalchemy.Column( + sqlalchemy.String, primary_key=True, default=generate_uuid + ) + + class CollectionStore(BaseModel): + """Collection store.""" + + __tablename__ = "collection" + __table_args__ = {"keep_existing": True} + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name( + cls, session: Session, name: str + ) -> Optional["CollectionStore"]: + try: + return ( + session.query(cls).filter(cls.name == name).first() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. 
+ Returns [Collection, bool] where the bool is True + if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True return collection, created - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() - created = True - return collection, created + class EmbeddingStore(BaseModel): + """Embedding store.""" - class EmbeddingStore(BaseModel): - """Embedding store.""" + __tablename__ = "embedding" + __table_args__ = {"keep_existing": True} - __tablename__ = "embedding" + collection_id = sqlalchemy.Column( + sqlalchemy.String, + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship("CollectionStore", back_populates="embeddings") - collection_id = sqlalchemy.Column( - sqlalchemy.String, - sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", - ondelete="CASCADE", - ), - ) - collection = relationship("CollectionStore", back_populates="embeddings") + embedding = sqlalchemy.Column(FloatVector(self.dimensions)) + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) - embedding = sqlalchemy.Column(FloatVector(dimensions)) - document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True) + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - # custom_id : any user defined id - custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - - return CollectionStore, EmbeddingStore + self.Base = Base + self.BaseModel = BaseModel + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore From 2d30228af926e2291f13f9d3b5be2a1586ed7fed Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 00:25:57 +0100 Subject: [PATCH 10/28] CrateDB vector: Fix cascading deletes When deleting a collection, also delete its associated embeddings. --- .../vectorstores/test_cratedb.py | 46 +++++++++++++++++++ .../langchain/vectorstores/cratedb/model.py | 3 +- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index 8f054fc07a0b3..d573843d2f02f 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -299,6 +299,52 @@ def test_cratedb_with_filter_no_match() -> None: assert output == [] +def test_cratedb_collection_delete() -> None: + """ + Test end to end collection construction and deletion. + Uses two different collections of embeddings. 
+ """ + + store_foo = CrateDBVectorSearch.from_texts( + texts=["foo"], + collection_name="test_collection_foo", + collection_metadata={"category": "foo"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "foo"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store_bar = CrateDBVectorSearch.from_texts( + texts=["bar"], + collection_name="test_collection_bar", + collection_metadata={"category": "bar"}, + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=[{"document": "bar"}], + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + session = store_foo.Session() + + # Verify data in database. + collection_foo = store_foo.get_collection(session) + collection_bar = store_bar.get_collection(session) + assert collection_foo.embeddings[0].cmetadata == {"document": "foo"} + assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} + + # Delete first collection. + store_foo.delete_collection() + + # Verify that the "foo" collection has been deleted. + collection_foo = store_foo.get_collection(session) + collection_bar = store_bar.get_collection(session) + assert collection_foo is None + assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} + + # Verify that associated embeddings also have been deleted. + embeddings_count = session.query(store_foo.EmbeddingStore).count() + assert embeddings_count == 1 + + def test_cratedb_collection_with_metadata() -> None: """Test end to end collection construction""" texts = ["foo", "bar", "baz"] diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index b8b14c05010f5..656de41bf4d45 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -45,7 +45,8 @@ class CollectionStore(BaseModel): embeddings = relationship( "EmbeddingStore", back_populates="collection", - passive_deletes=True, + cascade="all, delete-orphan", + passive_deletes=False, ) @classmethod From 9dfc828266ba95b5fd7f24ed4dfa5da7bd892965 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 21 Nov 2023 13:12:21 +0100 Subject: [PATCH 11/28] CrateDB vector: Add CrateDBVectorSearchMultiCollection It is a special adapter which provides similarity search across multiple collections. It can not be used for indexing documents. --- docs/docs/integrations/providers/cratedb.mdx | 3 + .../integrations/vectorstores/cratedb.ipynb | 88 +++++++----- .../vectorstores/test_cratedb.py | 131 +++++++++++++----- .../vectorstores/cratedb/__init__.py | 2 + .../langchain/vectorstores/cratedb/base.py | 22 ++- .../vectorstores/cratedb/extended.py | 92 ++++++++++++ .../langchain/vectorstores/cratedb/model.py | 15 +- .../cache/fake_embeddings.py | 6 +- 8 files changed, 286 insertions(+), 73 deletions(-) create mode 100644 libs/langchain/langchain/vectorstores/cratedb/extended.py diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 1327719ce2b28..220c35b86fd1c 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -106,6 +106,9 @@ export OPENAI_API_KEY=foobar export CRATEDB_CONNECTION_STRING=crate://crate@localhost ``` +### Example + +Load and index documents, and invoke query. 
```python from langchain.document_loaders import UnstructuredURLLoader from langchain.embeddings.openai import OpenAIEmbeddings diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb index 462e721bfff40..06430e6355ae9 100644 --- a/docs/docs/integrations/vectorstores/cratedb.ipynb +++ b/docs/docs/integrations/vectorstores/cratedb.ipynb @@ -182,7 +182,11 @@ { "cell_type": "markdown", "source": [ - "Next, you will read input data, and tokenize it." + "## Load and Index Documents\n", + "\n", + "Next, you will read input data, and tokenize it. The module will create a table\n", + "with the name of the collection. Make sure the collection name is unique, and\n", + "that you have the permission to create a table." ], "metadata": { "collapsed": false @@ -196,7 +200,18 @@ "loader = UnstructuredURLLoader(\"https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt\")\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", - "docs = text_splitter.split_documents(documents)" + "docs = text_splitter.split_documents(documents)\n", + "\n", + "COLLECTION_NAME = \"state_of_the_union_test\"\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "db = CrateDBVectorSearch.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=COLLECTION_NAME,\n", + " connection_string=CONNECTION_STRING,\n", + ")" ], "metadata": { "collapsed": false, @@ -208,39 +223,15 @@ { "cell_type": "markdown", "source": [ - "## Similarity Search with Euclidean Distance (Default)\n", + "## Search Documents\n", "\n", - "The module will create a table with the name of the collection. Make sure\n", - "the collection name is unique and that you have the permission to create\n", - "a table." + "### Similarity Search with Euclidean Distance\n", + "Searching by euclidean distance is the default." ], "metadata": { "collapsed": false } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-09T08:04:16.696625Z", - "start_time": "2023-09-09T08:02:31.817790Z" - } - }, - "outputs": [], - "source": [ - "COLLECTION_NAME = \"state_of_the_union_test\"\n", - "\n", - "embeddings = OpenAIEmbeddings()\n", - "\n", - "db = CrateDBVectorSearch.from_documents(\n", - " embedding=embeddings,\n", - " documents=docs,\n", - " collection_name=COLLECTION_NAME,\n", - " connection_string=CONNECTION_STRING,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -277,7 +268,7 @@ { "cell_type": "markdown", "source": [ - "## Maximal Marginal Relevance Search (MMR)\n", + "### Maximal Marginal Relevance Search (MMR)\n", "Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents." ], "metadata": { @@ -318,11 +309,40 @@ } } }, + { + "cell_type": "markdown", + "source": [ + "### Searching in Multiple Collections\n", + "`CrateDBVectorSearchMultiCollection` is a special adapter which provides similarity search across\n", + "multiple collections. It can not be used for indexing documents." 
+    ],
+    "metadata": {
+     "collapsed": false
+    }
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "outputs": [],
+    "source": [
+     "from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection\n",
+     "\n",
+     "multisearch = CrateDBVectorSearchMultiCollection(\n",
+     "    collection_names=[\"test_collection_1\", \"test_collection_2\"],\n",
+     "    embedding_function=embeddings,\n",
+     "    connection_string=CONNECTION_STRING,\n",
+     ")\n",
+     "docs_with_score = multisearch.similarity_search_with_score(query)"
+    ],
+    "metadata": {
+     "collapsed": false
+    }
+   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Working with the vector store\n",
+    "## Working with the Vector Store\n",
     "\n",
     "In the example above, you created a vector store from scratch. When\n",
     "aiming to work with an existing vector store, you can initialize it directly."
@@ -345,7 +365,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Add documents\n",
+    "### Add Documents\n",
     "\n",
     "You can also add documents to an existing vector store."
   ]
@@ -390,7 +410,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Overwriting a vector store\n",
+    "### Overwriting a Vector Store\n",
     "\n",
     "If you have an existing collection, you can overwrite it by using `from_documents`,\n",
     "and setting `pre_delete_collection = True`."
@@ -433,7 +453,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Using a vector store as a retriever"
+    "### Using a Vector Store as a Retriever"
   ]
  },
 {
diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
index d573843d2f02f..5a732ca5332f9 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
@@ -5,15 +5,17 @@
 docker-compose -f cratedb.yml up
 """
 import os
-from typing import Generator, List, Tuple
+from typing import Dict, Generator, List, Tuple
 
 import pytest
 import sqlalchemy as sa
+import sqlalchemy.orm
 from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.orm import Session
 
 from langchain.docstore.document import Document
 from langchain.vectorstores.cratedb import CrateDBVectorSearch
+from langchain.vectorstores.cratedb.extended import CrateDBVectorSearchMultiCollection
 from langchain.vectorstores.cratedb.model import ModelFactory
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
@@ -151,17 +153,17 @@ def embed_query(self, text: str) -> List[float]:
         return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]
 
 
-class ConsistentFakeEmbeddingsWithAdaDimension(
-    FakeEmbeddingsWithAdaDimension, ConsistentFakeEmbeddings
-):
+class ConsistentFakeEmbeddingsWithAdaDimension(ConsistentFakeEmbeddings):
     """
-    Fake embeddings which remember all the texts seen so far to return consistent
-    vectors for the same texts.
+    Fake embeddings which remember all the texts seen so far to return
+    consistent vectors for the same texts.
 
-    Other than this, they also have a dimensionality, which is important in this case.
+    Other than this, they also have a fixed dimensionality, which is
+    important in this case.
""" - pass + def __init__(self, *args: List, **kwargs: Dict) -> None: + super().__init__(dimensionality=ADA_TOKEN_COUNT) def test_cratedb_texts() -> None: @@ -223,12 +225,7 @@ def test_cratedb_with_metadatas_with_scores() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1) - # TODO: Original: - # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 - assert output in [ - [(Document(page_content="foo", metadata={"page": "0"}), 1.0828735)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.1307646)], - ] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)] def test_cratedb_with_filter_match() -> None: @@ -247,9 +244,8 @@ def test_cratedb_with_filter_match() -> None: # TODO: Original: # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 assert output in [ - [(Document(page_content="foo", metadata={"page": "0"}), 1.2615292)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.3979403)], - [(Document(page_content="foo", metadata={"page": "0"}), 1.5065275)], + [(Document(page_content="foo", metadata={"page": "0"}), 2.1307645)], + [(Document(page_content="foo", metadata={"page": "0"}), 2.3150668)], ] @@ -265,10 +261,9 @@ def test_cratedb_with_filter_distant_match() -> None: connection_string=CONNECTION_STRING, pre_delete_collection=True, ) + output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"}) # TODO: Original: - # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) - output = docsearch.similarity_search_with_score("foo", k=3, filter={"page": "2"}) - # TODO: Original: + # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) # noqa: E501 # assert output == [ # (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501 # ] @@ -277,9 +272,10 @@ def test_cratedb_with_filter_distant_match() -> None: Document(page_content="baz", metadata={"page": "2"}), ] assert scores in [ - [0.5], - [0.6], - [0.7], + [1.3], + [1.5], + [1.6], + [1.7], ] @@ -439,7 +435,7 @@ def test_cratedb_with_filter_in_set() -> None: Document(page_content="foo", metadata={"page": "0"}), Document(page_content="baz", metadata={"page": "2"}), ] - assert scores == [2.1, 1.3] + assert scores == [3.0, 2.2] def test_cratedb_delete_docs() -> None: @@ -498,7 +494,7 @@ def test_cratedb_relevance_score() -> None: Document(page_content="bar", metadata={"page": "1"}), Document(page_content="baz", metadata={"page": "2"}), ] - assert scores == [0.8, 0.4, 0.2] + assert scores == [1.4, 1.1, 0.8] def test_cratedb_retriever_search_threshold() -> None: @@ -516,9 +512,7 @@ def test_cratedb_retriever_search_threshold() -> None: retriever = docsearch.as_retriever( search_type="similarity_score_threshold", - # TODO: Original: - # search_kwargs={"k": 3, "score_threshold": 0.999}, - search_kwargs={"k": 3, "score_threshold": 0.333}, + search_kwargs={"k": 3, "score_threshold": 0.999}, ) output = retriever.get_relevant_documents("summer") assert output == [ @@ -574,10 +568,77 @@ def test_cratedb_max_marginal_relevance_search_with_score() -> None: pre_delete_collection=True, ) output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) - # TODO: Original: - # assert output == [(Document(page_content="foo"), 0.0)] - assert output in [ - [(Document(page_content="foo"), 1.0606961)], - [(Document(page_content="foo"), 1.0828735)], - 
[(Document(page_content="foo"), 1.1307646)], - ] + assert output == [(Document(page_content="foo"), 2.0)] + + +def test_cratedb_multicollection_search_success() -> None: + """ + `CrateDBVectorSearchMultiCollection` provides functionality for + searching multiple collections. + """ + + store_1 = CrateDBVectorSearch.from_texts( + texts=["Räuber", "Hotzenplotz"], + collection_name="test_collection_1", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + _ = CrateDBVectorSearch.from_texts( + texts=["John", "Doe"], + collection_name="test_collection_2", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + + # Probe the first store. + output = store_1.similarity_search("Räuber", k=1) + assert Document(page_content="Räuber") in output[:2] + output = store_1.similarity_search("Hotzenplotz", k=1) + assert Document(page_content="Hotzenplotz") in output[:2] + output = store_1.similarity_search("John Doe", k=1) + assert Document(page_content="Räuber") in output[:2] + + # Probe the multi-store. + multisearch = CrateDBVectorSearchMultiCollection( + collection_names=["test_collection_1", "test_collection_2"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + output = multisearch.similarity_search("Räuber Hotzenplotz", k=2) + assert Document(page_content="Räuber") in output[:2] + output = multisearch.similarity_search("John Doe", k=2) + assert Document(page_content="John") in output[:2] + + +def test_cratedb_multicollection_fail_indexing_not_permitted() -> None: + """ + `CrateDBVectorSearchMultiCollection` does not provide functionality for + indexing documents. + """ + + with pytest.raises(NotImplementedError) as ex: + CrateDBVectorSearchMultiCollection.from_texts( + texts=["foo"], + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + assert ex.match("This adapter can not be used for indexing documents") + + +def test_cratedb_multicollection_search_no_collections() -> None: + """ + `CrateDBVectorSearchMultiCollection` will fail when not able to identify + collections to search in. 
+ """ + + store = CrateDBVectorSearchMultiCollection( + collection_names=["unknown"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + with pytest.raises(ValueError) as ex: + store.similarity_search("foo") + assert ex.match("No collections found") diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/langchain/langchain/vectorstores/cratedb/__init__.py index 14b02ad126867..62462bce1eba9 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/__init__.py +++ b/libs/langchain/langchain/vectorstores/cratedb/__init__.py @@ -1,5 +1,7 @@ from .base import CrateDBVectorSearch +from .extended import CrateDBVectorSearchMultiCollection __all__ = [ "CrateDBVectorSearch", + "CrateDBVectorSearchMultiCollection", ] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index ec3c4c19d70a6..922ba2ed659d6 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -250,8 +250,26 @@ def _query_collection( collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") + return self._query_collection_multi( + collections=[collection], embedding=embedding, k=k, filter=filter + ) - filter_by = self.EmbeddingStore.collection_id == collection.uuid + def _query_collection_multi( + self, + collections: List[Any], + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query the collection.""" + self._init_models(embedding) + + collection_names = [coll.name for coll in collections] + collection_uuids = [coll.uuid for coll in collections] + self.logger.info(f"Querying collections: {collection_names}") + + with self.Session() as session: + filter_by = self.EmbeddingStore.collection_id.in_(collection_uuids) if filter is not None: filter_clauses = [] @@ -271,7 +289,7 @@ def _query_collection( ) # type: ignore[assignment] filter_clauses.append(filter_by_metadata) - filter_by = sqlalchemy.and_(filter_by, *filter_clauses) + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) # type: ignore[assignment] _type = self.EmbeddingStore diff --git a/libs/langchain/langchain/vectorstores/cratedb/extended.py b/libs/langchain/langchain/vectorstores/cratedb/extended.py new file mode 100644 index 0000000000000..9266438787368 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb/extended.py @@ -0,0 +1,92 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, +) + +import sqlalchemy +from sqlalchemy.orm import sessionmaker + +from langchain.schema.embeddings import Embeddings +from langchain.vectorstores.cratedb.base import ( + DEFAULT_DISTANCE_STRATEGY, + CrateDBVectorSearch, + DistanceStrategy, +) +from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME + + +class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch): + """ + Provide functionality for searching multiple collections. + It can not be used for indexing documents. + + To use it, you should have the ``crate[sqlalchemy]`` Python package installed. 
+ + Synopsis:: + + from langchain.vectorstores.cratedb import CrateDBVectorSearchMultiCollection + + multisearch = CrateDBVectorSearchMultiCollection( + collection_names=["collection_foo", "collection_bar"], + embedding_function=embeddings, + connection_string=CONNECTION_STRING, + ) + docs_with_score = multisearch.similarity_search_with_score(query) + """ + + def __init__( + self, + connection_string: str, + embedding_function: Embeddings, + collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME], + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, # type: ignore[arg-type] + logger: Optional[logging.Logger] = None, + relevance_score_fn: Optional[Callable[[float], float]] = None, + *, + connection: Optional[sqlalchemy.engine.Connection] = None, + engine_args: Optional[dict[str, Any]] = None, + ) -> None: + self.connection_string = connection_string + self.embedding_function = embedding_function + self.collection_names = collection_names + self._distance_strategy = distance_strategy # type: ignore[assignment] + self.logger = logger or logging.getLogger(__name__) + self.override_relevance_score_fn = relevance_score_fn + self.engine_args = engine_args or {} + # Create a connection if not provided, otherwise use the provided connection + self._engine = self.create_engine() + self.Session = sessionmaker(self._engine) + self._conn = connection if connection else self.connect() + self.__post_init__() + + @classmethod + def _from(cls, *args: List, **kwargs: Dict): # type: ignore[no-untyped-def,override] + raise NotImplementedError("This adapter can not be used for indexing documents") + + def get_collections(self, session: sqlalchemy.orm.Session) -> Any: + if self.CollectionStore is None: + raise RuntimeError( + "Collection can't be accessed without specifying " + "dimension size of embedding vectors" + ) + return self.CollectionStore.get_by_names(session, self.collection_names) + + def _query_collection( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, str]] = None, + ) -> List[Any]: + """Query multiple collections.""" + self._init_models(embedding) + with self.Session() as session: + collections = self.get_collections(session) + if not collections: + raise ValueError("No collections found") + return self._query_collection_multi( + collections=collections, embedding=embedding, k=k, filter=filter + ) diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py index 656de41bf4d45..1aec9b49a7260 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/langchain/langchain/vectorstores/cratedb/model.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import sqlalchemy from crate.client.sqlalchemy.types import ObjectType @@ -62,6 +62,19 @@ def get_by_name( raise return None + @classmethod + def get_by_names( + cls, session: Session, names: List[str] + ) -> Optional["List[CollectionStore]"]: + try: + return ( + session.query(cls).filter(cls.name.in_(names)).all() # type: ignore[attr-defined] # noqa: E501 + ) + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + return None + @classmethod def get_or_create( cls, diff --git a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py index 63394e78cbe84..1241e47e71e83 100644 --- 
a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
+++ b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
@@ -53,7 +53,11 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
     def embed_query(self, text: str) -> List[float]:
         """Return consistent embeddings for the text, if seen before, or a constant
         one if the text is unknown."""
-        return self.embed_documents([text])[0]
+        if text not in self.known_texts:
+            return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
+        return [float(1.0)] * (self.dimensionality - 1) + [
+            float(self.known_texts.index(text))
+        ]


 class AngularTwoDimensionalEmbeddings(Embeddings):

From b72a06c5d4852b27139e92d55510b5b62dafb74a Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 21 Nov 2023 16:32:21 +0100
Subject: [PATCH 12/28] CrateDB vector: Improve SQLAlchemy data model query
 utility functions

The CrateDB adapter works a bit differently than the pgvector adapter
it is building upon: the dimensionality of the vector field needs to be
specified at table creation time, but it is only known at runtime in
LangChain, so the table creation needs to be delayed.

As a consequence, the tables may not exist yet, but this is only relevant
when the user requests to pre-delete the collection, using the
`pre_delete_collection` argument. So, do the error handling only there,
and _not_ in the generic data model utility functions.
---
 .../vectorstores/test_cratedb.py              | 31 ++++++++++++++++++-
 .../langchain/vectorstores/cratedb/base.py    | 18 +++++++++--
 .../langchain/vectorstores/cratedb/model.py   | 24 +++-----------
 3 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
index 5a732ca5332f9..3351ecac7f7ef 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
@@ -5,6 +5,7 @@
 docker-compose -f cratedb.yml up
 """
 import os
+import re
 from typing import Dict, Generator, List, Tuple

 import pytest
@@ -324,6 +325,8 @@ def test_cratedb_collection_delete() -> None:
         # Verify data in database.
         collection_foo = store_foo.get_collection(session)
         collection_bar = store_bar.get_collection(session)
+        if collection_foo is None or collection_bar is None:
+            assert False, "Expected CollectionStore objects but received None"
         assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
         assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}

@@ -333,6 +336,8 @@ def test_cratedb_collection_delete() -> None:
         # Verify that the "foo" collection has been deleted.
         collection_foo = store_foo.get_collection(session)
         collection_bar = store_bar.get_collection(session)
+        if collection_bar is None:
+            assert False, "Expected CollectionStore object but received None"
         assert collection_foo is None
         assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}

@@ -628,12 +633,36 @@ def test_cratedb_multicollection_fail_indexing_not_permitted() -> None:
     assert ex.match("This adapter can not be used for indexing documents")


-def test_cratedb_multicollection_search_no_collections() -> None:
+def test_cratedb_multicollection_search_table_does_not_exist() -> None:
+    """
+    `CrateDBVectorSearchMultiCollection` will fail when the `collection`
+    table does not exist.
+ """ + + store = CrateDBVectorSearchMultiCollection( + collection_names=["unknown"], + embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + ) + with pytest.raises(ProgrammingError) as ex: + store.similarity_search("foo") + assert ex.match(re.escape("RelationUnknown[Relation 'collection' unknown]")) + + +def test_cratedb_multicollection_search_unknown_collection() -> None: """ `CrateDBVectorSearchMultiCollection` will fail when not able to identify collections to search in. """ + CrateDBVectorSearch.from_texts( + texts=["Räuber", "Hotzenplotz"], + collection_name="test_collection", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + store = CrateDBVectorSearchMultiCollection( collection_names=["unknown"], embedding_function=ConsistentFakeEmbeddingsWithAdaDimension(), diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index 922ba2ed659d6..da5f20702e2b2 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -164,10 +164,24 @@ def add_embeddings( if not embeddings: return [] self._init_models(embeddings[0]) + + # When the user requested to delete the collection before running subsequent + # operations on it, run the deletion gracefully if the table does not exist + # yet. if self.pre_delete_collection: - self.delete_collection() + try: + self.delete_collection() + except sqlalchemy.exc.ProgrammingError as ex: + if "RelationUnknown" not in str(ex): + raise + + # Tables need to be created at runtime, because the `EmbeddingStore.embedding` + # field, a `FloatVector`, needs to be initialized with a dimensionality + # parameter, which is only obtained at runtime. self.create_tables_if_not_exists() self.create_collection() + + # After setting up the table/collection at runtime, add embeddings. return super().add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) @@ -414,7 +428,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: else: raise ValueError( "No supported normalization function for distance_strategy of " - "{self._distance_strategy}. Consider providing relevance_score_fn to " + f"{self._distance_strategy}. Consider providing relevance_score_fn to " "CrateDBVectorSearch constructor." 
            )
diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/langchain/langchain/vectorstores/cratedb/model.py
index 1aec9b49a7260..0daea1ad44b5d 100644
--- a/libs/langchain/langchain/vectorstores/cratedb/model.py
+++ b/libs/langchain/langchain/vectorstores/cratedb/model.py
@@ -50,30 +50,14 @@ class CollectionStore(BaseModel):
     )

     @classmethod
-    def get_by_name(
-        cls, session: Session, name: str
-    ) -> Optional["CollectionStore"]:
-        try:
-            return (
-                session.query(cls).filter(cls.name == name).first()  # type: ignore[attr-defined] # noqa: E501
-            )
-        except sqlalchemy.exc.ProgrammingError as ex:
-            if "RelationUnknown" not in str(ex):
-                raise
-        return None
+    def get_by_name(cls, session: Session, name: str) -> "CollectionStore":
+        return session.query(cls).filter(cls.name == name).first()  # type: ignore[attr-defined]

     @classmethod
     def get_by_names(
         cls, session: Session, names: List[str]
-    ) -> Optional["List[CollectionStore]"]:
-        try:
-            return (
-                session.query(cls).filter(cls.name.in_(names)).all()  # type: ignore[attr-defined] # noqa: E501
-            )
-        except sqlalchemy.exc.ProgrammingError as ex:
-            if "RelationUnknown" not in str(ex):
-                raise
-        return None
+    ) -> List["CollectionStore"]:
+        return session.query(cls).filter(cls.name.in_(names)).all()  # type: ignore[attr-defined]

     @classmethod
     def get_or_create(

From f8317fe7dcd3b46dd1fadc8cba3d0c26dc956244 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 21 Nov 2023 16:45:01 +0100
Subject: [PATCH 13/28] CrateDB vector: Improve testing when initialized
 without dimensionality

---
 .../vectorstores/test_cratedb.py | 22 +++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
index 3351ecac7f7ef..44acb6652123e 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
@@ -368,11 +368,12 @@ def test_cratedb_collection_with_metadata() -> None:


 def test_cratedb_collection_no_embedding_dimension() -> None:
-    """Test end to end collection construction"""
+    """
+    Verify that addressing collections fails when not specifying dimensions.
+    """
     cratedb_vector = CrateDBVectorSearch(
         embedding_function=None,  # type: ignore[arg-type]
         connection_string=CONNECTION_STRING,
-        pre_delete_collection=True,
     )
     session = Session(cratedb_vector.connect())
     with pytest.raises(RuntimeError) as ex:
@@ -671,3 +672,20 @@ def test_cratedb_multicollection_search_unknown_collection() -> None:
     with pytest.raises(ValueError) as ex:
         store.similarity_search("foo")
     assert ex.match("No collections found")
+
+
+def test_cratedb_multicollection_no_embedding_dimension() -> None:
+    """
+    Verify that addressing collections fails when not specifying dimensions.
+    """
+    store = CrateDBVectorSearchMultiCollection(
+        embedding_function=None,  # type: ignore[arg-type]
+        connection_string=CONNECTION_STRING,
+    )
+    session = Session(store.connect())
+    with pytest.raises(RuntimeError) as ex:
+        store.get_collection(session)
+    assert ex.match(
+        "Collection can't be accessed without specifying "
+        "dimension size of embedding vectors"
+    )

From 53aee67e89f065b7f8a579176d9a384c05d19c1f Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 21 Nov 2023 15:11:16 +0100
Subject: [PATCH 14/28] CrateDB vector: Use SA's `bulk_save_objects` method for
 inserting embeddings

The performance gains can be substantial.
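For illustration, here is a minimal sketch of the bulk-insert pattern this
change adopts, assuming an `EmbeddingStore` ORM model and a `Session` factory
like the ones defined earlier in this series; the function name and signature
are placeholders, not part of the patch itself:

```python
from typing import List

from sqlalchemy.orm import Session


def add_embeddings_bulk(
    session: Session,
    EmbeddingStore: type,
    texts: List[str],
    embeddings: List[List[float]],
    ids: List[str],
) -> None:
    """Insert embedding records in a single batch instead of row by row."""
    records = [
        EmbeddingStore(uuid=uuid_, document=text, embedding=vector)
        for uuid_, text, vector in zip(ids, texts, embeddings)
    ]
    # `bulk_save_objects` batches the INSERT statements into
    # executemany-style operations, avoiding one round trip per record.
    session.bulk_save_objects(records)
    session.commit()
```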
---
 libs/langchain/langchain/vectorstores/cratedb/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py
index da5f20702e2b2..552cc6c8dee53 100644
--- a/libs/langchain/langchain/vectorstores/cratedb/base.py
+++ b/libs/langchain/langchain/vectorstores/cratedb/base.py
@@ -182,9 +182,11 @@ def add_embeddings(
         self.create_collection()

         # After setting up the table/collection at runtime, add embeddings.
-        return super().add_embeddings(
+        embedding_ids = super().add_embeddings(
             texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
         )
+        refresh_table(self.Session(), self.EmbeddingStore)
+        return embedding_ids

     def create_tables_if_not_exists(self) -> None:
         """

From 70685ceca47d19a4f458325884dfcafc9c84c82c Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Wed, 22 Nov 2023 16:05:23 +0100
Subject: [PATCH 15/28] CrateDB vector: Test non-deterministic values by using
 pytest.approx

The test cases can be written substantially more elegantly.
---
 .../vectorstores/test_cratedb.py | 71 ++++---------------
 1 file changed, 14 insertions(+), 57 deletions(-)

diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
index 44acb6652123e..bcfc9eebef6d0 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
@@ -6,7 +6,7 @@
 """
 import os
 import re
-from typing import Dict, Generator, List, Tuple
+from typing import Dict, Generator, List

 import pytest
 import sqlalchemy as sa
 from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.orm import Session
@@ -85,20 +85,6 @@ def prune_tables(engine: sa.Engine) -> None:
             pass


-def decode_output(
-    output: List[Tuple[Document, float]]
-) -> Tuple[List[Document], List[float]]:
-    """
-    Decode a typical API result into separate `documents` and `scores`.
-    It is needed as utility function in some test cases to compensate
-    for different and/or flaky score values, when compared to the
-    original implementation.
-    """
-    documents = [item[0] for item in output]
-    scores = [round(item[1], 1) for item in output]
-    return documents, scores
-
-
 def ensure_collection(session: sa.orm.Session, name: str) -> None:
     """
     Create a (fake) collection item.
@@ -241,12 +227,11 @@ def test_cratedb_with_filter_match() -> None: connection_string=CONNECTION_STRING, pre_delete_collection=True, ) - output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) # TODO: Original: # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 - assert output in [ - [(Document(page_content="foo", metadata={"page": "0"}), 2.1307645)], - [(Document(page_content="foo", metadata={"page": "0"}), 2.3150668)], + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.1)) ] @@ -263,20 +248,9 @@ def test_cratedb_with_filter_distant_match() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"}) - # TODO: Original: - # output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) # noqa: E501 - # assert output == [ - # (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) # noqa: E501 - # ] - documents, scores = decode_output(output) - assert documents == [ - Document(page_content="baz", metadata={"page": "2"}), - ] - assert scores in [ - [1.3], - [1.5], - [1.6], - [1.7], + # Original score value: 0.0013003906671379406 + assert output == [ + (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2)) ] @@ -429,19 +403,11 @@ def test_cratedb_with_filter_in_set() -> None: output = docsearch.similarity_search_with_score( "foo", k=2, filter={"page": {"IN": ["0", "2"]}} ) - # TODO: Original: - """ + # Original score values: 0.0, 0.0013003906671379406 assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), 0.0), - (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406), - ] - """ - documents, scores = decode_output(output) - assert documents == [ - Document(page_content="foo", metadata={"page": "0"}), - Document(page_content="baz", metadata={"page": "2"}), + (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)), + (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)), ] - assert scores == [3.0, 2.2] def test_cratedb_delete_docs() -> None: @@ -486,21 +452,12 @@ def test_cratedb_relevance_score() -> None: ) output = docsearch.similarity_search_with_relevance_scores("foo", k=3) - """ - # TODO: Original code, where the `distance` is stable. 
+ # Original score values: 1.0, 0.9996744261675065, 0.9986996093328621 assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), 1.0), - (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065), - (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621), - ] - """ - documents, scores = decode_output(output) - assert documents == [ - Document(page_content="foo", metadata={"page": "0"}), - Document(page_content="bar", metadata={"page": "1"}), - Document(page_content="baz", metadata={"page": "2"}), + (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)), + (Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)), + (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)), ] - assert scores == [1.4, 1.1, 0.8] def test_cratedb_retriever_search_threshold() -> None: From ccd2a2573c8f7db1479b3adc0d7919a394899962 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 28 Nov 2023 00:21:20 +0100 Subject: [PATCH 16/28] CrateDB vector: Fix initialization of vector dimensionality --- .../vectorstores/test_cratedb.py | 29 +++++++++++++++---- .../langchain/vectorstores/cratedb/base.py | 5 ++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index bcfc9eebef6d0..0b1e44ab31aa1 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -23,20 +23,18 @@ FakeEmbeddings, ) +SCHEMA_NAME = os.environ.get("TEST_CRATEDB_DATABASE", "testdrive") + CONNECTION_STRING = CrateDBVectorSearch.connection_string_from_db_params( driver=os.environ.get("TEST_CRATEDB_DRIVER", "crate"), host=os.environ.get("TEST_CRATEDB_HOST", "localhost"), port=int(os.environ.get("TEST_CRATEDB_PORT", "4200")), - database=os.environ.get("TEST_CRATEDB_DATABASE", "testdrive"), + database=SCHEMA_NAME, user=os.environ.get("TEST_CRATEDB_USER", "crate"), password=os.environ.get("TEST_CRATEDB_PASSWORD", ""), ) - -# TODO: Try 1536 after https://github.com/crate/crate/pull/14699. 
-# ADA_TOKEN_COUNT = 14 -ADA_TOKEN_COUNT = 1024 -# ADA_TOKEN_COUNT = 1536 +ADA_TOKEN_COUNT = 1536 @pytest.fixture @@ -167,6 +165,25 @@ def test_cratedb_texts() -> None: assert output == [Document(page_content="foo")] +def test_cratedb_embedding_dimension() -> None: + """Verify the `embedding` column uses the correct vector dimensionality.""" + texts = ["foo", "bar", "baz"] + docsearch = CrateDBVectorSearch.from_texts( + texts=texts, + collection_name="test_collection", + embedding=ConsistentFakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + with docsearch.Session() as session: + result = session.execute(sa.text(f"SHOW CREATE TABLE {SCHEMA_NAME}.embedding")) + record = result.first() + if not record: + raise ValueError("No data found") + ddl = record[0] + assert f'"embedding" FLOAT_VECTOR({ADA_TOKEN_COUNT})' in ddl + + def test_cratedb_embeddings() -> None: """Test end to end construction with embeddings and search.""" texts = ["foo", "bar", "baz"] diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/langchain/langchain/vectorstores/cratedb/base.py index 552cc6c8dee53..f2f0f29c47757 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/langchain/langchain/vectorstores/cratedb/base.py @@ -192,8 +192,9 @@ def create_tables_if_not_exists(self) -> None: """ Need to overwrite because this `Base` is different from parent's `Base`. """ - mf = ModelFactory() - mf.Base.metadata.create_all(self._engine) + if self.BaseModel is None: + raise RuntimeError("Storage models not initialized") + self.BaseModel.metadata.create_all(self._engine) def drop_tables(self) -> None: """ From 800ace60224db450f3e18afb7c28f86afebfea44 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 18 Jan 2024 03:06:53 +0100 Subject: [PATCH 17/28] CrateDB: Refactor to `langchain_community` --- .../chat_message_histories/__init__.py | 5 + .../document_loaders/__init__.py | 5 + .../document_loaders/cratedb.py | 5 + .../document_loaders/sqlalchemy.py | 4 +- .../vectorstores/__init__.py | 5 + .../vectorstores/cratedb/__init__.py | 0 .../vectorstores/cratedb/base.py | 8 +- .../vectorstores/cratedb/extended.py | 5 +- .../vectorstores/cratedb/model.py | 2 +- .../vectorstores/cratedb/sqlalchemy_type.py | 0 .../tests/examples/mlb_teams_2012.sql | 1 + .../test_sqlalchemy_cratedb.py | 2 +- .../test_sqlalchemy_postgresql.py | 4 +- .../test_sqlalchemy_sqlite.py | 2 +- .../vectorstores/test_cratedb.py | 18 +-- .../langchain/document_loaders/cratedb.py | 24 +++- .../memory/chat_message_histories/cratedb.py | 127 +++--------------- .../langchain/vectorstores/cratedb.py | 30 +++++ .../examples/mlb_teams_2012.csv | 32 ----- .../examples/mlb_teams_2012.sql | 41 ------ .../integration_tests/memory/test_cratedb.py | 6 +- 21 files changed, 122 insertions(+), 204 deletions(-) create mode 100644 libs/community/langchain_community/document_loaders/cratedb.py rename libs/{langchain/langchain => community/langchain_community}/document_loaders/sqlalchemy.py (97%) rename libs/{langchain/langchain => community/langchain_community}/vectorstores/cratedb/__init__.py (100%) rename libs/{langchain/langchain => community/langchain_community}/vectorstores/cratedb/base.py (99%) rename libs/{langchain/langchain => community/langchain_community}/vectorstores/cratedb/extended.py (95%) rename libs/{langchain/langchain => community/langchain_community}/vectorstores/cratedb/model.py (97%) rename libs/{langchain/langchain => 
community/langchain_community}/vectorstores/cratedb/sqlalchemy_type.py (100%) create mode 100644 libs/langchain/langchain/vectorstores/cratedb.py delete mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv delete mode 100644 libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql diff --git a/libs/community/langchain_community/chat_message_histories/__init__.py b/libs/community/langchain_community/chat_message_histories/__init__.py index fc20cacacceab..4f631a57c81c5 100644 --- a/libs/community/langchain_community/chat_message_histories/__init__.py +++ b/libs/community/langchain_community/chat_message_histories/__init__.py @@ -28,6 +28,9 @@ from langchain_community.chat_message_histories.cosmos_db import ( CosmosDBChatMessageHistory, ) + from langchain_community.chat_message_histories.cratedb import ( + CrateDBChatMessageHistory, + ) from langchain_community.chat_message_histories.dynamodb import ( DynamoDBChatMessageHistory, ) @@ -94,6 +97,7 @@ "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", + "CrateDBChatMessageHistory", "DynamoDBChatMessageHistory", "ElasticsearchChatMessageHistory", "FileChatMessageHistory", @@ -120,6 +124,7 @@ "CassandraChatMessageHistory": "langchain_community.chat_message_histories.cassandra", # noqa: E501 "ChatMessageHistory": "langchain_community.chat_message_histories.in_memory", "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories.cosmos_db", # noqa: E501 + "CrateDBChatMessageHistory": "langchain_community.chat_message_histories.cratedb", # noqa: E501 "DynamoDBChatMessageHistory": "langchain_community.chat_message_histories.dynamodb", "ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories.elasticsearch", # noqa: E501 "FileChatMessageHistory": "langchain_community.chat_message_histories.file", diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py index 2576093d3d48b..76493d827752b 100644 --- a/libs/community/langchain_community/document_loaders/__init__.py +++ b/libs/community/langchain_community/document_loaders/__init__.py @@ -129,6 +129,9 @@ from langchain_community.document_loaders.couchbase import ( CouchbaseLoader, ) + from langchain_community.document_loaders.cratedb import ( + CrateDBLoader, + ) from langchain_community.document_loaders.csv_loader import ( CSVLoader, UnstructuredCSVLoader, @@ -576,6 +579,7 @@ "ConcurrentLoader": "langchain_community.document_loaders.concurrent", "ConfluenceLoader": "langchain_community.document_loaders.confluence", "CouchbaseLoader": "langchain_community.document_loaders.couchbase", + "CrateDBLoader": "langchain_community.document_loaders.cratedb", "CubeSemanticLoader": "langchain_community.document_loaders.cube_semantic", "DataFrameLoader": "langchain_community.document_loaders.dataframe", "DatadogLogsLoader": "langchain_community.document_loaders.datadog_logs", @@ -782,6 +786,7 @@ def __getattr__(name: str) -> Any: "ConcurrentLoader", "ConfluenceLoader", "CouchbaseLoader", + "CrateDBLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", diff --git a/libs/community/langchain_community/document_loaders/cratedb.py b/libs/community/langchain_community/document_loaders/cratedb.py new file mode 100644 index 0000000000000..8aa304e0c7762 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/cratedb.py @@ -0,0 +1,5 @@ +from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader 
+ + +class CrateDBLoader(SQLAlchemyLoader): + pass diff --git a/libs/langchain/langchain/document_loaders/sqlalchemy.py b/libs/community/langchain_community/document_loaders/sqlalchemy.py similarity index 97% rename from libs/langchain/langchain/document_loaders/sqlalchemy.py rename to libs/community/langchain_community/document_loaders/sqlalchemy.py index 787c9f339b686..856d4fb105a36 100644 --- a/libs/langchain/langchain/document_loaders/sqlalchemy.py +++ b/libs/community/langchain_community/document_loaders/sqlalchemy.py @@ -2,8 +2,8 @@ import sqlalchemy as sa -from langchain.docstore.document import Document -from langchain.document_loaders.base import BaseLoader +from langchain_community.docstore.document import Document +from langchain_community.document_loaders.base import BaseLoader class SQLAlchemyLoader(BaseLoader): diff --git a/libs/community/langchain_community/vectorstores/__init__.py b/libs/community/langchain_community/vectorstores/__init__.py index c38beea0ed6d2..5741fd7a644b7 100644 --- a/libs/community/langchain_community/vectorstores/__init__.py +++ b/libs/community/langchain_community/vectorstores/__init__.py @@ -92,6 +92,9 @@ from langchain_community.vectorstores.couchbase import ( CouchbaseVectorStore, ) + from langchain_community.vectorstores.cratedb import ( + CrateDBVectorSearch, + ) from langchain_community.vectorstores.dashvector import ( DashVector, ) @@ -334,6 +337,7 @@ "Clickhouse", "ClickhouseSettings", "CouchbaseVectorStore", + "CrateDBVectorSearch", "DashVector", "DatabricksVectorSearch", "DeepLake", @@ -438,6 +442,7 @@ "Clickhouse": "langchain_community.vectorstores.clickhouse", "ClickhouseSettings": "langchain_community.vectorstores.clickhouse", "CouchbaseVectorStore": "langchain_community.vectorstores.couchbase", + "CrateDBVectorSearch": "langchain_community.vectorstores.cratedb", "DashVector": "langchain_community.vectorstores.dashvector", "DatabricksVectorSearch": "langchain_community.vectorstores.databricks_vector_search", # noqa: E501 "DeepLake": "langchain_community.vectorstores.deeplake", diff --git a/libs/langchain/langchain/vectorstores/cratedb/__init__.py b/libs/community/langchain_community/vectorstores/cratedb/__init__.py similarity index 100% rename from libs/langchain/langchain/vectorstores/cratedb/__init__.py rename to libs/community/langchain_community/vectorstores/cratedb/__init__.py diff --git a/libs/langchain/langchain/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py similarity index 99% rename from libs/langchain/langchain/vectorstores/cratedb/base.py rename to libs/community/langchain_community/vectorstores/cratedb/base.py index f2f0f29c47757..8cf1065b88e1b 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/base.py +++ b/libs/community/langchain_community/vectorstores/cratedb/base.py @@ -19,13 +19,13 @@ polyfill_refresh_after_dml, refresh_table, ) -from sqlalchemy.orm import sessionmaker - from langchain.docstore.document import Document from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.cratedb.model import ModelFactory from langchain.vectorstores.pgvector import PGVector +from sqlalchemy.orm import sessionmaker + +from langchain_community.vectorstores.cratedb.model import ModelFactory class DistanceStrategy(str, enum.Enum): @@ -93,7 +93,7 @@ def __post_init__( # FIXME: Could be a bug in CrateDB SQLAlchemy dialect. 
patch_inspector() - self._engine = self.create_engine() + self._engine = self._create_engine() self.Session = sessionmaker(self._engine) # TODO: See what can be improved here. diff --git a/libs/langchain/langchain/vectorstores/cratedb/extended.py b/libs/community/langchain_community/vectorstores/cratedb/extended.py similarity index 95% rename from libs/langchain/langchain/vectorstores/cratedb/extended.py rename to libs/community/langchain_community/vectorstores/cratedb/extended.py index 9266438787368..db6a1213faf7a 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/extended.py +++ b/libs/community/langchain_community/vectorstores/cratedb/extended.py @@ -11,12 +11,13 @@ from sqlalchemy.orm import sessionmaker from langchain.schema.embeddings import Embeddings -from langchain.vectorstores.cratedb.base import ( + +from langchain_community.vectorstores.cratedb.base import ( DEFAULT_DISTANCE_STRATEGY, CrateDBVectorSearch, DistanceStrategy, ) -from langchain.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME +from langchain_community.vectorstores.pgvector import _LANGCHAIN_DEFAULT_COLLECTION_NAME class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch): diff --git a/libs/langchain/langchain/vectorstores/cratedb/model.py b/libs/community/langchain_community/vectorstores/cratedb/model.py similarity index 97% rename from libs/langchain/langchain/vectorstores/cratedb/model.py rename to libs/community/langchain_community/vectorstores/cratedb/model.py index 0daea1ad44b5d..f9dae6566d7c0 100644 --- a/libs/langchain/langchain/vectorstores/cratedb/model.py +++ b/libs/community/langchain_community/vectorstores/cratedb/model.py @@ -5,7 +5,7 @@ from crate.client.sqlalchemy.types import ObjectType from sqlalchemy.orm import Session, declarative_base, relationship -from langchain.vectorstores.cratedb.sqlalchemy_type import FloatVector +from langchain_community.vectorstores.cratedb.sqlalchemy_type import FloatVector def generate_uuid() -> str: diff --git a/libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py similarity index 100% rename from libs/langchain/langchain/vectorstores/cratedb/sqlalchemy_type.py rename to libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py diff --git a/libs/community/tests/examples/mlb_teams_2012.sql b/libs/community/tests/examples/mlb_teams_2012.sql index 33cb765a38ebe..9df72ef19954a 100644 --- a/libs/community/tests/examples/mlb_teams_2012.sql +++ b/libs/community/tests/examples/mlb_teams_2012.sql @@ -1,6 +1,7 @@ -- Provisioning table "mlb_teams_2012". 
-- -- psql postgresql://postgres@localhost < mlb_teams_2012.sql +-- crash < mlb_teams_2012.sql DROP TABLE IF EXISTS mlb_teams_2012; CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py index eec3a428a74e8..6cd42439024e0 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py @@ -12,7 +12,7 @@ import sqlalchemy as sa import sqlparse -from langchain.document_loaders import CrateDBLoader +from langchain_community.document_loaders import CrateDBLoader from tests.data import MLB_TEAMS_2012_SQL logging.basicConfig(level=logging.DEBUG) diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py index 29f52cb9f7a33..38dd6cfd62f09 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py @@ -12,7 +12,7 @@ import sqlalchemy as sa import sqlparse -from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader from tests.data import MLB_TEAMS_2012_SQL logging.basicConfig(level=logging.DEBUG) @@ -37,6 +37,8 @@ def engine() -> sa.Engine: """ Return an SQLAlchemy engine object. """ + if not psycopg2_installed: + raise pytest.skip("psycopg2 not installed") return sa.create_engine(CONNECTION_STRING, echo=False) diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py index f1fac2cecbc00..05759e5dea9ea 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py +++ b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py @@ -9,7 +9,7 @@ import sqlparse from _pytest.tmpdir import TempPathFactory -from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader from tests.data import MLB_TEAMS_2012_SQL logging.basicConfig(level=logging.DEBUG) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index 0b1e44ab31aa1..65cd2c885aec4 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -11,13 +11,15 @@ import pytest import sqlalchemy as sa import sqlalchemy.orm +from langchain.docstore.document import Document from sqlalchemy.exc import ProgrammingError from sqlalchemy.orm import Session -from langchain.docstore.document import Document -from langchain.vectorstores.cratedb import CrateDBVectorSearch -from langchain.vectorstores.cratedb.extended import CrateDBVectorSearchMultiCollection -from langchain.vectorstores.cratedb.model import ModelFactory +from langchain_community.vectorstores.cratedb import CrateDBVectorSearch +from langchain_community.vectorstores.cratedb.extended import ( + CrateDBVectorSearchMultiCollection, +) +from langchain_community.vectorstores.cratedb.model 
import ModelFactory from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, @@ -366,7 +368,7 @@ def test_cratedb_collection_no_embedding_dimension() -> None: embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, ) - session = Session(cratedb_vector.connect()) + session = cratedb_vector.Session() with pytest.raises(RuntimeError) as ex: cratedb_vector.get_collection(session) assert ex.match( @@ -578,7 +580,7 @@ def test_cratedb_multicollection_search_success() -> None: output = store_1.similarity_search("Hotzenplotz", k=1) assert Document(page_content="Hotzenplotz") in output[:2] output = store_1.similarity_search("John Doe", k=1) - assert Document(page_content="Räuber") in output[:2] + assert Document(page_content="Hotzenplotz") in output[:2] # Probe the multi-store. multisearch = CrateDBVectorSearchMultiCollection( @@ -589,7 +591,7 @@ def test_cratedb_multicollection_search_success() -> None: output = multisearch.similarity_search("Räuber Hotzenplotz", k=2) assert Document(page_content="Räuber") in output[:2] output = multisearch.similarity_search("John Doe", k=2) - assert Document(page_content="John") in output[:2] + assert Document(page_content="Doe") in output[:2] def test_cratedb_multicollection_fail_indexing_not_permitted() -> None: @@ -656,7 +658,7 @@ def test_cratedb_multicollection_no_embedding_dimension() -> None: embedding_function=None, # type: ignore[arg-type] connection_string=CONNECTION_STRING, ) - session = Session(store.connect()) + session = store.Session() with pytest.raises(RuntimeError) as ex: store.get_collection(session) assert ex.match( diff --git a/libs/langchain/langchain/document_loaders/cratedb.py b/libs/langchain/langchain/document_loaders/cratedb.py index 9e34b4d0cb9ec..fb75aceca4de4 100644 --- a/libs/langchain/langchain/document_loaders/cratedb.py +++ b/libs/langchain/langchain/document_loaders/cratedb.py @@ -1,5 +1,23 @@ -from langchain.document_loaders.sqlalchemy import SQLAlchemyLoader +from typing import TYPE_CHECKING, Any +from langchain._api import create_importer -class CrateDBLoader(SQLAlchemyLoader): - pass +if TYPE_CHECKING: + from langchain_community.document_loaders.cratedb import CrateDBLoader + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = {"CrateDBLoader": "langchain_community.document_loaders"} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = [ + "CrateDBLoader", +] diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py index 19007176cb193..376cacd8985ad 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py @@ -1,113 +1,30 @@ -import json -import typing as t +from typing import TYPE_CHECKING, Any -import sqlalchemy as sa -from cratedb_toolkit.sqlalchemy import ( - patch_inspector, - polyfill_refresh_after_dml, - refresh_table, -) +from langchain._api import create_importer -from langchain.memory.chat_message_histories.sql import ( - BaseMessageConverter, - SQLChatMessageHistory, -) -from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict +if TYPE_CHECKING: + from langchain_community.chat_message_histories.cratedb import ( + CrateDBChatMessageHistory, + CrateDBMessageConverter, + ) +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. +DEPRECATED_LOOKUP = { + "CrateDBChatMessageHistory": "langchain_community.chat_message_histories", + "CrateDBMessageConverter": "langchain_community.chat_message_histories" +} -def create_message_model(table_name, DynamicBase): # type: ignore - """ - Create a message model for a given table name. +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) - This is a specialized version for CrateDB for generating integer-based primary keys. - TODO: Find a way to converge CrateDB's generate_random_uuid() into a variant - returning its integer value. - Args: - table_name: The name of the table to use. - DynamicBase: The base class to use for the model. +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) - Returns: - The model class. - """ - # Model is declared inside a function to be able to use a dynamic table name. - class Message(DynamicBase): - __tablename__ = table_name - id = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now()) - session_id = sa.Column(sa.Text) - message = sa.Column(sa.Text) - - return Message - - -class CrateDBMessageConverter(BaseMessageConverter): - """ - The default message converter for CrateDBMessageConverter. - - It is the same as the generic `SQLChatMessageHistory` converter, - but swaps in a different `create_message_model` function. - """ - - def __init__(self, table_name: str): - self.model_class = create_message_model(table_name, sa.orm.declarative_base()) - - def from_sql_model(self, sql_message: t.Any) -> BaseMessage: - return messages_from_dict([json.loads(sql_message.message)])[0] - - def to_sql_model(self, message: BaseMessage, session_id: str) -> t.Any: - return self.model_class( - session_id=session_id, message=json.dumps(_message_to_dict(message)) - ) - - def get_sql_model_class(self) -> t.Any: - return self.model_class - - -class CrateDBChatMessageHistory(SQLChatMessageHistory): - """ - It is the same as the generic `SQLChatMessageHistory` implementation, - but swaps in a different message converter by default. 
- """ - - DEFAULT_MESSAGE_CONVERTER: t.Type[BaseMessageConverter] = CrateDBMessageConverter - - def __init__( - self, - session_id: str, - connection_string: str, - table_name: str = "message_store", - session_id_field_name: str = "session_id", - custom_message_converter: t.Optional[BaseMessageConverter] = None, - ): - # FIXME: Refactor elsewhere. - patch_inspector() - - super().__init__( - session_id, - connection_string, - table_name=table_name, - session_id_field_name=session_id_field_name, - custom_message_converter=custom_message_converter, - ) - - # TODO: Check how this can be improved. - polyfill_refresh_after_dml(self.Session) - - def _messages_query(self) -> sa.Select: - """ - Construct an SQLAlchemy selectable to query for messages. - For CrateDB, add an `ORDER BY` clause on the primary key. - """ - selectable = super()._messages_query() - selectable = selectable.order_by(self.sql_model_class.id) - return selectable - - def clear(self) -> None: - """ - Needed for CrateDB to synchronize data because `on_flush` does not catch it. - """ - outcome = super().clear() - with self.Session() as session: - refresh_table(session, self.sql_model_class) - return outcome +__all__ = [ + "CrateDBChatMessageHistory", + "CrateDBMessageConverter", +] diff --git a/libs/langchain/langchain/vectorstores/cratedb.py b/libs/langchain/langchain/vectorstores/cratedb.py new file mode 100644 index 0000000000000..65876a820a6d4 --- /dev/null +++ b/libs/langchain/langchain/vectorstores/cratedb.py @@ -0,0 +1,30 @@ +from typing import TYPE_CHECKING, Any + +from langchain._api import create_importer + +if TYPE_CHECKING: + from langchain_community.vectorstores.cratedb.base import CrateDBVectorSearch + from langchain_community.vectorstores.cratedb.extended import ( + CrateDBVectorSearchMultiCollection, + ) + +# Create a way to dynamically look up deprecated imports. +# Used to consolidate logic for raising deprecation warnings and +# handling optional imports. 
+DEPRECATED_LOOKUP = { + "CrateDBVectorSearch": "langchain_community.vectorstores", + "CrateDBVectorSearchMultiCollection": "langchain_community.vectorstores", +} + +_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) + + +def __getattr__(name: str) -> Any: + """Look up attributes dynamically.""" + return _import_attribute(name) + + +__all__ = [ + "CrateDBVectorSearch", + "CrateDBVectorSearchMultiCollection", +] diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv deleted file mode 100644 index b22ae961a1331..0000000000000 --- a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.csv +++ /dev/null @@ -1,32 +0,0 @@ -"Team", "Payroll (millions)", "Wins" -"Nationals", 81.34, 98 -"Reds", 82.20, 97 -"Yankees", 197.96, 95 -"Giants", 117.62, 94 -"Braves", 83.31, 94 -"Athletics", 55.37, 94 -"Rangers", 120.51, 93 -"Orioles", 81.43, 93 -"Rays", 64.17, 90 -"Angels", 154.49, 89 -"Tigers", 132.30, 88 -"Cardinals", 110.30, 88 -"Dodgers", 95.14, 86 -"White Sox", 96.92, 85 -"Brewers", 97.65, 83 -"Phillies", 174.54, 81 -"Diamondbacks", 74.28, 81 -"Pirates", 63.43, 79 -"Padres", 55.24, 76 -"Mariners", 81.97, 75 -"Mets", 93.35, 74 -"Blue Jays", 75.48, 73 -"Royals", 60.91, 72 -"Marlins", 118.07, 69 -"Red Sox", 173.18, 69 -"Indians", 78.43, 68 -"Twins", 94.08, 66 -"Rockies", 78.06, 64 -"Cubs", 88.19, 61 -"Astros", 60.65, 55 - diff --git a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql b/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql deleted file mode 100644 index 9df72ef19954a..0000000000000 --- a/libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql +++ /dev/null @@ -1,41 +0,0 @@ --- Provisioning table "mlb_teams_2012". 
--- --- psql postgresql://postgres@localhost < mlb_teams_2012.sql --- crash < mlb_teams_2012.sql - -DROP TABLE IF EXISTS mlb_teams_2012; -CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); -INSERT INTO mlb_teams_2012 - ("Team", "Payroll (millions)", "Wins") -VALUES - ('Nationals', 81.34, 98), - ('Reds', 82.20, 97), - ('Yankees', 197.96, 95), - ('Giants', 117.62, 94), - ('Braves', 83.31, 94), - ('Athletics', 55.37, 94), - ('Rangers', 120.51, 93), - ('Orioles', 81.43, 93), - ('Rays', 64.17, 90), - ('Angels', 154.49, 89), - ('Tigers', 132.30, 88), - ('Cardinals', 110.30, 88), - ('Dodgers', 95.14, 86), - ('White Sox', 96.92, 85), - ('Brewers', 97.65, 83), - ('Phillies', 174.54, 81), - ('Diamondbacks', 74.28, 81), - ('Pirates', 63.43, 79), - ('Padres', 55.24, 76), - ('Mariners', 81.97, 75), - ('Mets', 93.35, 74), - ('Blue Jays', 75.48, 73), - ('Royals', 60.91, 72), - ('Marlins', 118.07, 69), - ('Red Sox', 173.18, 69), - ('Indians', 78.43, 68), - ('Twins', 94.08, 66), - ('Rockies', 78.06, 64), - ('Cubs', 88.19, 61), - ('Astros', 60.65, 55) -; diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/langchain/tests/integration_tests/memory/test_cratedb.py index 2c00b5d2b200b..b8f6f51397c64 100644 --- a/libs/langchain/tests/integration_tests/memory/test_cratedb.py +++ b/libs/langchain/tests/integration_tests/memory/test_cratedb.py @@ -59,7 +59,7 @@ def sql_histories( def test_add_messages( - sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], ) -> None: history1, _ = sql_histories history1.add_user_message("Hello!") @@ -74,7 +74,7 @@ def test_add_messages( def test_multiple_sessions( - sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], ) -> None: history1, history2 = sql_histories @@ -104,7 +104,7 @@ def test_multiple_sessions( def test_clear_messages( - sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory] + sql_histories: Tuple[CrateDBChatMessageHistory, CrateDBChatMessageHistory], ) -> None: sql_history, other_history = sql_histories sql_history.add_user_message("Hello!") From b40c24f3dafe4844b3e7d394c0cbde3ac7fe5472 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 18 Jan 2024 22:41:58 +0100 Subject: [PATCH 18/28] CrateDB vector: Adjustments for updates to pgvector adapter --- .../vectorstores/cratedb/base.py | 8 ++--- .../vectorstores/cratedb/extended.py | 7 ++--- .../vectorstores/cratedb/sqlalchemy_patch.py | 29 +++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) create mode 100644 libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py diff --git a/libs/community/langchain_community/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py index 8cf1065b88e1b..f4abac3ab7a2d 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/base.py +++ b/libs/community/langchain_community/vectorstores/cratedb/base.py @@ -16,7 +16,6 @@ import sqlalchemy from cratedb_toolkit.sqlalchemy.patch import patch_inspector from cratedb_toolkit.sqlalchemy.polyfill import ( - polyfill_refresh_after_dml, refresh_table, ) from langchain.docstore.document import Document @@ -26,6 +25,7 @@ from sqlalchemy.orm import sessionmaker from langchain_community.vectorstores.cratedb.model import ModelFactory +from 
langchain_community.vectorstores.cratedb.sqlalchemy_patch import polyfill_refresh_after_dml_engine class DistanceStrategy(str, enum.Enum): @@ -93,11 +93,11 @@ def __post_init__( # FIXME: Could be a bug in CrateDB SQLAlchemy dialect. patch_inspector() - self._engine = self._create_engine() + self._engine = self._bind self.Session = sessionmaker(self._engine) - # TODO: See what can be improved here. - polyfill_refresh_after_dml(self.Session) + # TODO: Pull in from a future `sqlalchemy-cratedb`. + polyfill_refresh_after_dml_engine(self._engine) # Need to defer initialization, because dimension size # can only be figured out at runtime. diff --git a/libs/community/langchain_community/vectorstores/cratedb/extended.py b/libs/community/langchain_community/vectorstores/cratedb/extended.py index db6a1213faf7a..324eca0861d95 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/extended.py +++ b/libs/community/langchain_community/vectorstores/cratedb/extended.py @@ -8,7 +8,6 @@ ) import sqlalchemy -from sqlalchemy.orm import sessionmaker from langchain.schema.embeddings import Embeddings @@ -58,10 +57,10 @@ def __init__( self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn self.engine_args = engine_args or {} + # Create a connection if not provided, otherwise use the provided connection - self._engine = self.create_engine() - self.Session = sessionmaker(self._engine) - self._conn = connection if connection else self.connect() + self._bind = connection if connection else self._create_engine() + self.__post_init__() @classmethod diff --git a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py new file mode 100644 index 0000000000000..fd50e55119be5 --- /dev/null +++ b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py @@ -0,0 +1,29 @@ +import sqlalchemy as sa + + +def polyfill_refresh_after_dml_engine(engine: sa.engine.Engine): + def receive_after_execute( + conn: sa.engine.Connection, + clauseelement, + multiparams, + params, + execution_options, + result, + ): + """ + Run a `REFRESH TABLE ...` command after each DML operation (INSERT, UPDATE, + DELETE). This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` + adapters. + + TODO: Pull in from a future `sqlalchemy-cratedb`. + """ + if isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)): + if not isinstance(clauseelement.table, sa.sql.Join): + full_table_name = f'"{clauseelement.table.name}"' + if clauseelement.table.schema is not None: + full_table_name = ( + f'"{clauseelement.table.schema}".' 
+ full_table_name + ) + conn.execute(sa.text(f"REFRESH TABLE {full_table_name};")) + + sa.event.listen(engine, "after_execute", receive_after_execute) From cb06a662c185f97690c84ec2186635b23a72318b Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Fri, 19 Jan 2024 01:10:08 +0100 Subject: [PATCH 19/28] CrateDB vector: Relax test constraint --- .../tests/integration_tests/vectorstores/test_cratedb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index 65cd2c885aec4..3f92717c2a138 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -250,7 +250,7 @@ def test_cratedb_with_filter_match() -> None: # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.1)) + (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.3)) ] From fa28b24d03dd11047080774f7b40e30a266a8cf6 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Wed, 5 Jun 2024 22:15:43 +0200 Subject: [PATCH 20/28] CrateDB loader: SQLAlchemyLoader has been superseded by SQLDatabaseLoader --- docs/docs/.gitignore | 3 +- .../{sqlalchemy.mdx => sql_database.mdx} | 32 +- .../document_loaders/sql_database.ipynb | 360 ++++++++++++++++++ .../document_loaders/sqlalchemy.ipynb | 237 ------------ .../document_loaders/cratedb.py | 4 +- .../document_loaders/sqlalchemy.py | 112 ------ .../document_loaders/test_sql_database.py | 23 ++ .../test_sqlalchemy_cratedb.py | 146 ------- .../test_sqlalchemy_postgresql.py | 179 --------- .../test_sqlalchemy_sqlite.py | 181 --------- .../tests/unit_tests/test_sql_database.py | 9 +- .../tests/unit_tests/test_sqlalchemy.py | 7 - .../langchain/document_loaders/__init__.py | 1 - 13 files changed, 415 insertions(+), 879 deletions(-) rename docs/docs/how_to/{sqlalchemy.mdx => sql_database.mdx} (84%) create mode 100644 docs/docs/integrations/document_loaders/sql_database.ipynb delete mode 100644 docs/docs/integrations/document_loaders/sqlalchemy.ipynb delete mode 100644 libs/community/langchain_community/document_loaders/sqlalchemy.py delete mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py delete mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py delete mode 100644 libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py delete mode 100644 libs/community/tests/unit_tests/test_sqlalchemy.py diff --git a/docs/docs/.gitignore b/docs/docs/.gitignore index 25a6e30a4b775..e586a74dfb131 100644 --- a/docs/docs/.gitignore +++ b/docs/docs/.gitignore @@ -4,4 +4,5 @@ node_modules/ .docusaurus .cache-loader -docs/api \ No newline at end of file +docs/api +example.sqlite diff --git a/docs/docs/how_to/sqlalchemy.mdx b/docs/docs/how_to/sql_database.mdx similarity index 84% rename from docs/docs/how_to/sqlalchemy.mdx rename to docs/docs/how_to/sql_database.mdx index 9f7e663db075e..1ecdeda75b307 100644 --- a/docs/docs/how_to/sqlalchemy.mdx +++ b/docs/docs/how_to/sql_database.mdx @@ -1,10 +1,11 @@ -# SQLAlchemy +# SQLDatabaseLoader ## About -The [SQLAlchemy] document loader loads records from any supported database, -see [SQLAlchemy 
dialects] for all supported SQL databases and dialects. +The `SQLDatabaseLoader` loads records from any database supported by +[SQLAlchemy], see [SQLAlchemy dialects] for the whole list of supported +SQL databases and dialects. You can either use plain SQL for querying, or use an SQLAlchemy `Select` statement object, if you are using SQLAlchemy-Core or -ORM. @@ -31,11 +32,11 @@ psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integrat ### Basic loading ```python -from langchain.document_loaders import SQLAlchemyLoader +from langchain_community.document_loaders.sql_database import SQLDatabaseLoader from pprint import pprint -loader = SQLAlchemyLoader( +loader = SQLDatabaseLoader( query="SELECT * FROM mlb_teams_2012 LIMIT 3;", url="postgresql+psycopg2://postgres@localhost:5432/testdrive", ) @@ -66,7 +67,7 @@ Having the `query` within metadata is useful when using documents loaded from database tables for chains that answer questions using their origin queries. ```python -loader = SQLAlchemyLoader( +loader = SQLDatabaseLoader( query="SELECT * FROM mlb_teams_2012 LIMIT 3;", url="postgresql+psycopg2://postgres@localhost:5432/testdrive", include_rownum_into_metadata=True, @@ -97,11 +98,20 @@ the `metadata` dictionary with corresponding information. When `page_content_col is empty, all columns will be used. ```python -loader = SQLAlchemyLoader( +import functools + +row_to_content = functools.partial( + SQLDatabaseLoader.page_content_default_mapper, column_names=["Payroll (millions)", "Wins"] +) +row_to_metadata = functools.partial( + SQLDatabaseLoader.metadata_default_mapper, column_names=["Team"] +) + +loader = SQLDatabaseLoader( query="SELECT * FROM mlb_teams_2012 LIMIT 3;", url="postgresql+psycopg2://postgres@localhost:5432/testdrive", - page_content_columns=["Payroll (millions)", "Wins"], - metadata_columns=["Team"], + page_content_mapper=row_to_content, + metadata_mapper=row_to_metadata, ) docs = loader.load() ``` @@ -128,10 +138,10 @@ document created from each row. This is useful for identifying documents through their metadata. Typically, you may use the primary key column(s) for that purpose. 
```python -loader = SQLAlchemyLoader( +loader = SQLDatabaseLoader( query="SELECT * FROM mlb_teams_2012 LIMIT 3;", url="postgresql+psycopg2://postgres@localhost:5432/testdrive", - source_columns="Team", + source_columns=["Team"], ) docs = loader.load() ``` diff --git a/docs/docs/integrations/document_loaders/sql_database.ipynb b/docs/docs/integrations/document_loaders/sql_database.ipynb new file mode 100644 index 0000000000000..9b3fe41df43fa --- /dev/null +++ b/docs/docs/integrations/document_loaders/sql_database.ipynb @@ -0,0 +1,360 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQL Database\n", + "\n", + "## About\n", + "\n", + "The `SQLDatabaseLoader` loads records from any database supported by\n", + "[SQLAlchemy], see [SQLAlchemy dialects] for the whole list of supported\n", + "SQL databases and dialects.\n", + "\n", + "For talking to the database, the document loader uses the [SQLDatabase]\n", + "utility from the LangChain integration toolkit.\n", + "\n", + "You can either use plain SQL for querying, or use an SQLAlchemy `Select`\n", + "statement object, if you are using SQLAlchemy-Core or -ORM.\n", + "\n", + "You can select which columns to place into the document, which columns\n", + "to place into its metadata, which columns to use as a `source` attribute\n", + "in metadata, and whether to include the result row number and/or the SQL\n", + "query expression into the metadata.\n", + "\n", + "## What's inside\n", + "\n", + "This notebook covers how to load documents from an [SQLite] database,\n", + "using the [SQLAlchemy] document loader.\n", + "\n", + "It loads the result of a database query with one document per row.\n", + "\n", + "[SQLAlchemy]: https://www.sqlalchemy.org/\n", + "[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/latest/dialects/\n", + "[SQLDatabase]: https://python.langchain.com/docs/integrations/toolkits/sql_database\n", + "[SQLite]: https://sqlite.org/\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install langchain langchain-community sqlalchemy termsql" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "Populate SQLite database with example input data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nationals|81.34|98\r\n", + "Reds|82.2|97\r\n", + "Yankees|197.96|95\r\n", + "Giants|117.62|94\r\n", + "Braves|83.31|94\r\n", + "Athletics|55.37|94\r\n", + "Rangers|120.51|93\r\n", + "Orioles|81.43|93\r\n", + "Rays|64.17|90\r\n", + "Angels|154.49|89\r\n", + "Tigers|132.3|88\r\n", + "Cardinals|110.3|88\r\n", + "Dodgers|95.14|86\r\n", + "White Sox|96.92|85\r\n", + "Brewers|97.65|83\r\n", + "Phillies|174.54|81\r\n", + "Diamondbacks|74.28|81\r\n", + "Pirates|63.43|79\r\n", + "Padres|55.24|76\r\n", + "Mariners|81.97|75\r\n", + "Mets|93.35|74\r\n", + "Blue Jays|75.48|73\r\n", + "Royals|60.91|72\r\n", + "Marlins|118.07|69\r\n", + "Red Sox|173.18|69\r\n", + "Indians|78.43|68\r\n", + "Twins|94.08|66\r\n", + "Rockies|78.06|64\r\n", + "Cubs|88.19|61\r\n", + "Astros|60.65|55\r\n", + "||\r\n" + ] + } + ], + "source": [ + "!termsql --infile=./example_data/mlb_teams_2012.csv --head --csv --outfile=example.sqlite --table=payroll" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Basic usage" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "from langchain_community.document_loaders import SQLDatabaseLoader\n", + "\n", + "loader = SQLDatabaseLoader(\n", + " \"SELECT * FROM payroll LIMIT 2\",\n", + " url=\"sqlite:///example.sqlite\",\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98'),\n", + " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97')]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify which columns are content vs. metadata\n", + "\n", + "Use the `page_content_mapper` keyword argument to optionally customize how to derive\n", + "a page content string from an input database record / row. By default, all columns\n", + "will be used.\n", + "\n", + "Use the `metadata_mapper` keyword argument to optionally customize how to derive\n", + "a document metadata dictionary from an input database record / row. By default,\n", + "document metadata will be empty." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "# Configure built-in page content mapper to include only specified columns.\n", + "row_to_content = functools.partial(\n", + " SQLDatabaseLoader.page_content_default_mapper, column_names=[\"Team\", \"Wins\"]\n", + ")\n", + "\n", + "# Configure built-in metadata dictionary mapper to include specified columns.\n", + "row_to_metadata = functools.partial(\n", + " SQLDatabaseLoader.metadata_default_mapper, column_names=[\"Payroll (millions)\"]\n", + ")\n", + "\n", + "loader = SQLDatabaseLoader(\n", + " \"SELECT * FROM payroll LIMIT 2\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " page_content_mapper=row_to_content,\n", + " metadata_mapper=row_to_metadata,\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nWins: 98', metadata={'Payroll (millions)': 81.34}),\n", + " Document(page_content='Team: Reds\\nWins: 97', metadata={'Payroll (millions)': 82.2})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Those examples demonstrate how to use custom functions to define arbitrary\n", + "mapping rules by using Python code.\n", + "```python\n", + "def page_content_mapper(row: sa.RowMapping, column_names: Optional[List[str]] = None) -> str:\n", + " return f\"Team: {row['Team']}\"\n", + "```\n", + "```python\n", + "def metadata_default_mapper(row: sa.RowMapping, column_names: Optional[List[str]] = None) -> Dict[str, Any]:\n", + " return {\"team\": row['Team']}\n", + "```" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specify column(s) to identify the document source\n", + "\n", + "Use the `source_columns` option to specify the columns to use as a \"source\" for the\n", + "document created from each row. This is useful for identifying documents through\n", + "their metadata. Typically, you may use the primary key column(s) for that purpose." + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "loader = SQLDatabaseLoader(\n", + " \"SELECT * FROM payroll LIMIT 2\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " source_columns=[\"Team\"],\n", + ")\n", + "documents = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98', metadata={'source': 'Nationals'}),\n", + " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97', metadata={'source': 'Reds'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Enrich metadata with row number and/or original SQL query\n", + "\n", + "Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to\n", + "optionally populate the `metadata` dictionary with corresponding information.\n", + "\n", + "Having the `query` within metadata is useful when using documents loaded from\n", + "database tables for chains that answer questions using their origin queries." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 49, + "outputs": [], + "source": [ + "loader = SQLDatabaseLoader(\n", + " \"SELECT * FROM payroll LIMIT 2\",\n", + " url=\"sqlite:///example.sqlite\",\n", + " include_rownum_into_metadata=True,\n", + " include_query_into_metadata=True,\n", + ")\n", + "documents = loader.load()" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 50, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM payroll LIMIT 2'}),\n", + " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM payroll LIMIT 2'})]\n" + ] + } + ], + "source": [ + "pprint(documents)" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb b/docs/docs/integrations/document_loaders/sqlalchemy.ipynb deleted file mode 100644 index 5d603d7263c53..0000000000000 --- a/docs/docs/integrations/document_loaders/sqlalchemy.ipynb +++ /dev/null @@ -1,237 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SQLAlchemy\n", - "\n", - "This notebook demonstrates how to load documents from an [SQLite] database,\n", - "using the [SQLAlchemy] document loader.\n", - "\n", - "It loads the result of a database query with one document per row.\n", - "\n", - "[SQLAlchemy]: https://www.sqlalchemy.org/\n", - "[SQLite]: https://sqlite.org/" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Prerequisites" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "#!pip install langchain termsql" - ] - }, - { - "cell_type": "markdown", - "source": [ - "Provide input data as SQLite database." 
- ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Overwriting example.csv\n" - ] - } - ], - "source": [ - "%%file example.csv\n", - "Team,Payroll\n", - "Nationals,81.34\n", - "Reds,82.20" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Nationals|81.34\r\n", - "Reds|82.2\r\n" - ] - } - ], - "source": [ - "!termsql --infile=example.csv --head --delimiter=\",\" --outfile=example.sqlite --table=payroll" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "markdown", - "source": [ - "## Usage" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.document_loaders import SQLAlchemyLoader\n", - "from pprint import pprint\n", - "\n", - "loader = SQLAlchemyLoader(\n", - " \"SELECT * FROM payroll\",\n", - " url=\"sqlite:///example.sqlite\",\n", - ")\n", - "documents = loader.load()" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={}),\n", - " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={})]\n" - ] - } - ], - "source": [ - "pprint(documents)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Specifying Which Columns are Content vs Metadata" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "loader = SQLAlchemyLoader(\n", - " \"SELECT * FROM payroll\",\n", - " url=\"sqlite:///example.sqlite\",\n", - " page_content_columns=[\"Team\"],\n", - " metadata_columns=[\"Payroll\"],\n", - ")\n", - "documents = loader.load()" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Document(page_content='Team: Nationals', metadata={'Payroll': 81.34}),\n", - " Document(page_content='Team: Reds', metadata={'Payroll': 82.2})]\n" - ] - } - ], - "source": [ - "pprint(documents)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Adding Source to Metadata" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "loader = SQLAlchemyLoader(\n", - " \"SELECT * FROM payroll\",\n", - " url=\"sqlite:///example.sqlite\",\n", - " source_columns=[\"Team\"],\n", - ")\n", - "documents = loader.load()" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Document(page_content='Team: Nationals\\nPayroll: 81.34', metadata={'source': 'Nationals'}),\n", - " Document(page_content='Team: Reds\\nPayroll: 82.2', metadata={'source': 'Reds'})]\n" - ] - } - ], - "source": [ - "pprint(documents)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": 
"ipython3", - "version": "3.10.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/libs/community/langchain_community/document_loaders/cratedb.py b/libs/community/langchain_community/document_loaders/cratedb.py index 8aa304e0c7762..a97b4dde8f354 100644 --- a/libs/community/langchain_community/document_loaders/cratedb.py +++ b/libs/community/langchain_community/document_loaders/cratedb.py @@ -1,5 +1,5 @@ -from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader +from langchain_community.document_loaders.sql_database import SQLDatabaseLoader -class CrateDBLoader(SQLAlchemyLoader): +class CrateDBLoader(SQLDatabaseLoader): pass diff --git a/libs/community/langchain_community/document_loaders/sqlalchemy.py b/libs/community/langchain_community/document_loaders/sqlalchemy.py deleted file mode 100644 index 856d4fb105a36..0000000000000 --- a/libs/community/langchain_community/document_loaders/sqlalchemy.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Dict, List, Optional, Union - -import sqlalchemy as sa - -from langchain_community.docstore.document import Document -from langchain_community.document_loaders.base import BaseLoader - - -class SQLAlchemyLoader(BaseLoader): - """ - Load documents by querying database tables supported by SQLAlchemy. - Each document represents one row of the result. - """ - - def __init__( - self, - query: Union[str, sa.Select], - url: str, - page_content_columns: Optional[List[str]] = None, - metadata_columns: Optional[List[str]] = None, - source_columns: Optional[List[str]] = None, - include_rownum_into_metadata: bool = False, - include_query_into_metadata: bool = False, - sqlalchemy_kwargs: Optional[Dict] = None, - ): - """ - - Args: - query: The query to execute. - url: The SQLAlchemy connection string of the database to connect to. - page_content_columns: The columns to write into the `page_content` - of the document. Optional. - metadata_columns: The columns to write into the `metadata` of the document. - Optional. - source_columns: The names of the columns to use as the `source` within the - metadata dictionary. Optional. - include_rownum_into_metadata: Whether to include the row number into the - metadata dictionary. Optional. Default: False. - include_query_into_metadata: Whether to include the query expression into - the metadata dictionary. Optional. Default: False. - sqlalchemy_kwargs: More keyword arguments for SQLAlchemy's `create_engine`. - """ - self.query = query - self.url = url - self.page_content_columns = page_content_columns - self.metadata_columns = metadata_columns - self.source_columns = source_columns - self.include_rownum_into_metadata = include_rownum_into_metadata - self.include_query_into_metadata = include_query_into_metadata - self.sqlalchemy_kwargs = sqlalchemy_kwargs or {} - - def load(self) -> List[Document]: - try: - import sqlalchemy as sa - except ImportError: - raise ImportError( - "Could not import sqlalchemy python package. " - "Please install it with `pip install sqlalchemy`." 
-            )
-
-        engine = sa.create_engine(self.url, **self.sqlalchemy_kwargs)
-
-        docs = []
-        with engine.connect() as conn:
-            if isinstance(self.query, sa.Select):
-                result = conn.execute(self.query)
-                query_sql = str(self.query.compile(bind=engine))
-            elif isinstance(self.query, str):
-                result = conn.execute(sa.text(self.query))
-                query_sql = self.query
-            else:
-                raise TypeError(
-                    f"Unable to process query of unknown type: {self.query}"
-                )
-            field_names = list(result.mappings().keys())
-
-            if self.page_content_columns is None:
-                page_content_columns = field_names
-            else:
-                page_content_columns = self.page_content_columns
-
-            if self.metadata_columns is None:
-                metadata_columns = []
-            else:
-                metadata_columns = self.metadata_columns
-
-            for i, row in enumerate(result.mappings()):
-                page_content = "\n".join(
-                    f"{column}: {value}"
-                    for column, value in row.items()
-                    if column in page_content_columns
-                )
-
-                metadata: Dict[str, Union[str, int]] = {}
-                if self.include_rownum_into_metadata:
-                    metadata["row"] = i
-                if self.include_query_into_metadata:
-                    metadata["query"] = query_sql
-
-                source_values = []
-                for column, value in row.items():
-                    if column in metadata_columns:
-                        metadata[column] = value
-                    if self.source_columns and column in self.source_columns:
-                        source_values.append(value)
-                if source_values:
-                    metadata["source"] = ",".join(source_values)
-
-                doc = Document(page_content=page_content, metadata=metadata)
-                docs.append(doc)
-
-        return docs
diff --git a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
index 121948075a316..81e939b4ab2da 100644
--- a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
@@ -47,6 +47,14 @@
     warnings.warn("psycopg2 not installed, skipping corresponding tests", UserWarning)
     psycopg2_installed = False
 
+try:
+    import crate.client.sqlalchemy  # noqa: F401
+
+    cratedb_installed = True
+except ImportError:
+    warnings.warn("cratedb not installed, skipping corresponding tests", UserWarning)
+    cratedb_installed = False
+
 
 @pytest.fixture()
 def engine(db_uri: str) -> sa.Engine:
@@ -75,6 +83,9 @@ def provision_database(engine: sa.Engine) -> None:
                 continue
             connection.execute(sa.text(statement))
         connection.commit()
+        if engine.dialect.name.startswith("crate"):
+            connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;"))
+            connection.commit()
 
 
 tmpdir = TemporaryDirectory()
@@ -103,6 +114,18 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
             "postgresql+psycopg2://langchain:langchain@localhost:6023/langchain"
         )
        ids.append("postgresql")
+    if cratedb_installed:
+        # The easiest way to spin up a CrateDB instance is to use the
+        # docker compose file located at
+        # tests/integration_tests/document_loaders/docker-compose/cratedb.yml:
+        # use `docker-compose -f cratedb.yml up` to start the instance.
+        # It will listen on the appropriate ports and have the
+        # appropriate credentials set up, so the connection string
+        # below works out of the box.
+ urls.append( + "crate://crate@localhost/?schema=testdrive" + ) + ids.append("cratedb") metafunc.parametrize("db_uri", urls, ids=ids) diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py deleted file mode 100644 index 6cd42439024e0..0000000000000 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_cratedb.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -Test SQLAlchemy/CrateDB document loader functionality. - -cd tests/integration_tests/document_loaders/docker-compose -docker-compose -f cratedb.yml up -""" -import logging -import os -import unittest - -import pytest -import sqlalchemy as sa -import sqlparse - -from langchain_community.document_loaders import CrateDBLoader -from tests.data import MLB_TEAMS_2012_SQL - -logging.basicConfig(level=logging.DEBUG) - -try: - import crate.client.sqlalchemy # noqa: F401 - - crate_client_installed = True -except ImportError: - crate_client_installed = False - - -CONNECTION_STRING = os.environ.get( - "TEST_CRATEDB_CONNECTION_STRING", "crate://crate@localhost/?schema=testdrive" -) - - -@pytest.fixture -def engine() -> sa.Engine: - """ - Return an SQLAlchemy engine object. - """ - return sa.create_engine(CONNECTION_STRING, echo=False) - - -@pytest.fixture() -def provision_database(engine: sa.Engine) -> None: - """ - Provision database with table schema and data. - """ - sql_statements = MLB_TEAMS_2012_SQL.read_text() - with engine.connect() as connection: - connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) - for statement in sqlparse.split(sql_statements): - connection.execute(sa.text(statement)) - connection.execute(sa.text("REFRESH TABLE mlb_teams_2012;")) - connection.commit() - - -@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") -def test_cratedb_loader_no_options() -> None: - """Test SQLAlchemy loader with CrateDB.""" - - loader = CrateDBLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {} - - -@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") -def test_cratedb_loader_page_content_columns() -> None: - """Test SQLAlchemy loader with CrateDB.""" - - loader = CrateDBLoader( - "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", - url=CONNECTION_STRING, - page_content_columns=["a"], - ) - docs = loader.load() - - assert len(docs) == 2 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {} - - assert docs[1].page_content == "a: 3" - assert docs[1].metadata == {} - - -@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") -def test_cratedb_loader_metadata_columns() -> None: - """Test SQLAlchemy loader with CrateDB.""" - - loader = CrateDBLoader( - "SELECT 1 AS a, 2 AS b", - url=CONNECTION_STRING, - page_content_columns=["a"], - metadata_columns=["b"], - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {"b": 2} - - -@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") -def test_cratedb_loader_real_data_with_sql(provision_database: None) -> None: - """Test SQLAlchemy loader with CrateDB.""" - - loader = CrateDBLoader( - query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', - url=CONNECTION_STRING, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content 
== "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == {} - - -@unittest.skipIf(not crate_client_installed, "CrateDB client not installed") -def test_cratedb_loader_real_data_with_selectable(provision_database: None) -> None: - """Test SQLAlchemy loader with CrateDB.""" - - # Define an SQLAlchemy table. - mlb_teams_2012 = sa.Table( - "mlb_teams_2012", - sa.MetaData(), - sa.Column("Team", sa.VARCHAR), - sa.Column("Payroll (millions)", sa.FLOAT), - sa.Column("Wins", sa.BIGINT), - ) - - # Query the database table using an SQLAlchemy selectable. - select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) - loader = CrateDBLoader( - query=select, - url=CONNECTION_STRING, - include_query_into_metadata=True, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == { - "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' - 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' - 'ORDER BY mlb_teams_2012."Team"' - } diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py deleted file mode 100644 index 38dd6cfd62f09..0000000000000 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_postgresql.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Test SQLAlchemy/PostgreSQL document loader functionality. - -cd tests/integration_tests/document_loaders/docker-compose -docker-compose -f postgresql.yml up -""" -import logging -import os -import unittest - -import pytest -import sqlalchemy as sa -import sqlparse - -from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader -from tests.data import MLB_TEAMS_2012_SQL - -logging.basicConfig(level=logging.DEBUG) - - -try: - import psycopg2 # noqa: F401 - - psycopg2_installed = True -except ImportError: - psycopg2_installed = False - - -CONNECTION_STRING = os.environ.get( - "TEST_POSTGRESQL_CONNECTION_STRING", - "postgresql+psycopg2://postgres@localhost:5432/", -) - - -@pytest.fixture -def engine() -> sa.Engine: - """ - Return an SQLAlchemy engine object. - """ - if not psycopg2_installed: - raise pytest.skip("psycopg2 not installed") - return sa.create_engine(CONNECTION_STRING, echo=False) - - -@pytest.fixture() -def provision_database(engine: sa.Engine) -> None: - """ - Provision database with table schema and data. 
- """ - sql_statements = MLB_TEAMS_2012_SQL.read_text() - with engine.connect() as connection: - connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) - for statement in sqlparse.split(sql_statements): - connection.execute(sa.text(statement)) - connection.commit() - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_no_options() -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_include_rownum_into_metadata() -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", - url=CONNECTION_STRING, - include_rownum_into_metadata=True, - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {"row": 0} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_include_query_into_metadata() -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", url=CONNECTION_STRING, include_query_into_metadata=True - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_page_content_columns() -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", - url=CONNECTION_STRING, - page_content_columns=["a"], - ) - docs = loader.load() - - assert len(docs) == 2 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {} - - assert docs[1].page_content == "a: 3" - assert docs[1].metadata == {} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_metadata_columns() -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", - url=CONNECTION_STRING, - page_content_columns=["a"], - metadata_columns=["b"], - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {"b": 2} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_real_data_with_sql(provision_database: None) -> None: - """Test SQLAlchemy loader with psycopg2.""" - - loader = SQLAlchemyLoader( - query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', - url=CONNECTION_STRING, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == {} - - -@unittest.skipIf(not psycopg2_installed, "psycopg2 not installed") -def test_postgresql_loader_real_data_with_selectable(provision_database: None) -> None: - """Test SQLAlchemy loader with psycopg2.""" - - # Define an SQLAlchemy table. - mlb_teams_2012 = sa.Table( - "mlb_teams_2012", - sa.MetaData(), - sa.Column("Team", sa.VARCHAR), - sa.Column("Payroll (millions)", sa.FLOAT), - sa.Column("Wins", sa.BIGINT), - ) - - # Query the database table using an SQLAlchemy selectable. 
- select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) - loader = SQLAlchemyLoader( - query=select, - url=CONNECTION_STRING, - include_query_into_metadata=True, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == { - "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' - 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' - 'ORDER BY mlb_teams_2012."Team"' - } diff --git a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py b/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py deleted file mode 100644 index 05759e5dea9ea..0000000000000 --- a/libs/community/tests/integration_tests/document_loaders/test_sqlalchemy_sqlite.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Test SQLAlchemy/SQLite document loader functionality. -""" -import logging -import unittest - -import pytest -import sqlalchemy as sa -import sqlparse -from _pytest.tmpdir import TempPathFactory - -from langchain_community.document_loaders.sqlalchemy import SQLAlchemyLoader -from tests.data import MLB_TEAMS_2012_SQL - -logging.basicConfig(level=logging.DEBUG) - - -try: - import sqlite3 # noqa: F401 - - sqlite3_installed = True -except ImportError: - sqlite3_installed = False - - -@pytest.fixture(scope="module") -def db_uri(tmp_path_factory: TempPathFactory) -> str: - """ - Return an SQLAlchemy URI for a temporary SQLite database. - """ - db_path = tmp_path_factory.getbasetemp().joinpath("testdrive.sqlite") - return f"sqlite:///{db_path}" - - -@pytest.fixture(scope="module") -def engine(db_uri: str) -> sa.Engine: - """ - Return an SQLAlchemy engine object. - """ - return sa.create_engine(db_uri, echo=False) - - -@pytest.fixture() -def provision_database(engine: sa.Engine) -> None: - """ - Provision database with table schema and data. 
- """ - sql_statements = MLB_TEAMS_2012_SQL.read_text() - with engine.connect() as connection: - connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;")) - for statement in sqlparse.split(sql_statements): - connection.execute(sa.text(statement)) - connection.commit() - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_no_options(db_uri: str) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader("SELECT 1 AS a, 2 AS b", url=db_uri) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_include_rownum_into_metadata(db_uri: str) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", - url=db_uri, - include_rownum_into_metadata=True, - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {"row": 0} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_include_query_into_metadata(db_uri: str) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", url=db_uri, include_query_into_metadata=True - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1\nb: 2" - assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_page_content_columns(db_uri: str) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b", - url=db_uri, - page_content_columns=["a"], - ) - docs = loader.load() - - assert len(docs) == 2 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {} - - assert docs[1].page_content == "a: 3" - assert docs[1].metadata == {} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_metadata_columns(db_uri: str) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader( - "SELECT 1 AS a, 2 AS b", - url=db_uri, - page_content_columns=["a"], - metadata_columns=["b"], - ) - docs = loader.load() - - assert len(docs) == 1 - assert docs[0].page_content == "a: 1" - assert docs[0].metadata == {"b": 2} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_real_data_with_sql( - db_uri: str, provision_database: None -) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - loader = SQLAlchemyLoader( - query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";', - url=db_uri, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == {} - - -@unittest.skipIf(not sqlite3_installed, "sqlite3 not installed") -def test_sqlite_loader_real_data_with_selectable( - db_uri: str, provision_database: None -) -> None: - """Test SQLAlchemy loader with sqlite3.""" - - # Define an SQLAlchemy table. - mlb_teams_2012 = sa.Table( - "mlb_teams_2012", - sa.MetaData(), - sa.Column("Team", sa.VARCHAR), - sa.Column("Payroll (millions)", sa.FLOAT), - sa.Column("Wins", sa.BIGINT), - ) - - # Query the database table using an SQLAlchemy selectable. 
- select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team) - loader = SQLAlchemyLoader( - query=select, - url=db_uri, - include_query_into_metadata=True, - ) - docs = loader.load() - - assert len(docs) == 30 - assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89" - assert docs[0].metadata == { - "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", ' - 'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 ' - 'ORDER BY mlb_teams_2012."Team"' - } diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py index 6acb734a54309..f7171041a2425 100644 --- a/libs/community/tests/unit_tests/test_sql_database.py +++ b/libs/community/tests/unit_tests/test_sql_database.py @@ -3,6 +3,8 @@ import pytest import sqlalchemy as sa +import sqlalchemy.orm +from langchain_community.utilities.sql_database import SQLDatabase, truncate_word from packaging import version from sqlalchemy import ( Column, @@ -16,8 +18,6 @@ ) from sqlalchemy.engine import Engine, Result -from langchain_community.utilities.sql_database import SQLDatabase, truncate_word - is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1 metadata_obj = MetaData() @@ -56,6 +56,11 @@ def db_lazy_reflection(engine: Engine) -> SQLDatabase: @pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues") +def test_configure_mappers() -> None: + """Test that configuring table mappers works.""" + sqlalchemy.orm.configure_mappers() + + def test_table_info(db: SQLDatabase) -> None: """Test that table info is constructed properly.""" output = db.table_info diff --git a/libs/community/tests/unit_tests/test_sqlalchemy.py b/libs/community/tests/unit_tests/test_sqlalchemy.py deleted file mode 100644 index a434620155e00..0000000000000 --- a/libs/community/tests/unit_tests/test_sqlalchemy.py +++ /dev/null @@ -1,7 +0,0 @@ -import sqlalchemy.orm - -import langchain_community # noqa: F401 - - -def test_configure_mappers() -> None: - sqlalchemy.orm.configure_mappers() diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 058993541ab5b..db2f7b7a466fe 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -512,7 +512,6 @@ def __getattr__(name: str) -> Any: "SlackDirectoryLoader", "SnowflakeLoader", "SpreedlyLoader", - "SQLAlchemyLoader", "StripeLoader", "TelegramChatApiLoader", "TelegramChatFileLoader", From 41ccacf1650bd3f0e0ec0c556c1a06bfaa0c76ce Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 10 Jun 2024 23:09:01 +0200 Subject: [PATCH 21/28] CrateDB: Migrate from `crate[sqlalchemy]` to `sqlalchemy-cratedb` The CrateDB SQLAlchemy dialect needs more love, so it was separated from the DBAPI HTTP driver. 
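As a transition aid for downstream code, an import shim along the lines of
the following sketch (mirroring the fallback this series applies to
`model.py`) keeps both driver generations working; the table name and
connection URL below are illustrative only:

```python
import sqlalchemy as sa

# Prefer the new standalone dialect package; fall back to the legacy
# `crate[sqlalchemy]` driver bundle when it is not installed.
try:
    from sqlalchemy_cratedb import ObjectType
except ImportError:
    from crate.client.sqlalchemy.types import ObjectType

metadata = sa.MetaData()

# `ObjectType` resolves to the same CrateDB column type either way.
example = sa.Table("example", metadata, sa.Column("payload", ObjectType))

# The SQLAlchemy connection URL format is unchanged by the migration.
engine = sa.create_engine("crate://crate@localhost/?schema=testdrive")
```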
--- .../integrations/document_loaders/cratedb.ipynb | 2 +- .../memory/cratedb_chat_message_history.ipynb | 9 ++++----- docs/docs/integrations/providers/cratedb.mdx | 16 ++++++++-------- .../docs/integrations/vectorstores/cratedb.ipynb | 10 +++++----- libs/community/extended_testing_deps.txt | 5 +++-- .../chat_message_histories/cratedb.py | 4 ---- .../vectorstores/cratedb/base.py | 4 ---- .../vectorstores/cratedb/extended.py | 2 +- .../vectorstores/cratedb/model.py | 5 ++++- .../vectorstores/cratedb/sqlalchemy_type.py | 4 ++-- .../document_loaders/test_sql_database.py | 2 +- 11 files changed, 29 insertions(+), 34 deletions(-) diff --git a/docs/docs/integrations/document_loaders/cratedb.ipynb b/docs/docs/integrations/document_loaders/cratedb.ipynb index 78a0e19138703..5dff3cc1955cb 100644 --- a/docs/docs/integrations/document_loaders/cratedb.ipynb +++ b/docs/docs/integrations/document_loaders/cratedb.ipynb @@ -32,7 +32,7 @@ }, "outputs": [], "source": [ - "#!pip install crash 'langchain[cratedb]'" + "#!pip install crash langchain sqlalchemy-cratedb" ] }, { diff --git a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb index f51f5f1d63fca..244b21d43d13f 100644 --- a/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb +++ b/docs/docs/integrations/memory/cratedb_chat_message_history.ipynb @@ -27,7 +27,7 @@ "execution_count": null, "outputs": [], "source": [ - "!#pip install 'langchain[cratedb]'" + "!#pip install langchain sqlalchemy-cratedb" ], "metadata": { "collapsed": false @@ -145,7 +145,6 @@ "from datetime import datetime\n", "from typing import Any\n", "\n", - "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier\n", "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n", "\n", @@ -159,7 +158,7 @@ "class CustomMessage(Base):\n", "\t__tablename__ = \"custom_message_store\"\n", "\n", - "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())\n", "\tsession_id = sa.Column(sa.Text)\n", "\ttype = sa.Column(sa.Text)\n", "\tcontent = sa.Column(sa.Text)\n", @@ -272,7 +271,7 @@ "import json\n", "import typing as t\n", "\n", - "from langchain.memory.chat_message_histories.cratedb import generate_autoincrement_identifier, CrateDBMessageConverter\n", + "from langchain.memory.chat_message_histories.cratedb import CrateDBMessageConverter\n", "from langchain.schema import _message_to_dict\n", "\n", "\n", @@ -280,7 +279,7 @@ "\n", "class MessageWithDifferentSessionIdColumn(Base):\n", "\t__tablename__ = \"message_store_different_session_id\"\n", - "\tid = sa.Column(sa.BigInteger, primary_key=True, default=generate_autoincrement_identifier)\n", + "\tid = sa.Column(sa.BigInteger, primary_key=True, server_default=sa.func.now())\n", "\tcustom_session_id = sa.Column(sa.Text)\n", "\tmessage = sa.Column(sa.Text)\n", "\n", diff --git a/docs/docs/integrations/providers/cratedb.mdx b/docs/docs/integrations/providers/cratedb.mdx index 220c35b86fd1c..dde1adf25e983 100644 --- a/docs/docs/integrations/providers/cratedb.mdx +++ b/docs/docs/integrations/providers/cratedb.mdx @@ -84,7 +84,7 @@ docker run --rm -it --name=cratedb --publish=4200:4200 --publish=5432:5432 \ ### Install Client ```bash -pip install 'crate[sqlalchemy]' 'langchain[openai]' 'crash' 
+pip install crash langchain langchain-openai sqlalchemy-cratedb ``` @@ -188,16 +188,16 @@ if __name__ == "__main__": [CrateDB]: https://github.com/crate/crate -[CrateDB Cloud]: https://crate.io/product +[CrateDB Cloud]: https://cratedb.com/product [CrateDB Cloud Console]: https://console.cratedb.cloud/ [CrateDB Cloud CRFREE]: https://community.crate.io/t/new-cratedb-cloud-edge-feature-cratedb-cloud-free-tier/1402 -[CrateDB SQLAlchemy dialect]: https://crate.io/docs/python/en/latest/sqlalchemy.html -[download CrateDB]: https://crate.io/download +[CrateDB SQLAlchemy dialect]: https://cratedb.com/docs/sqlalchemy-cratedb/ +[download CrateDB]: https://cratedb.com/download [Elastisearch]: https://github.com/elastic/elasticsearch -[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector -[free trial]: https://crate.io/lp-crfree?utm_source=langchain -[ISO 27001]: https://crate.io/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification -[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match +[`FLOAT_VECTOR`]: https://cratedb.com/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector +[free trial]: https://cratedb.com/lp-crfree?utm_source=langchain +[ISO 27001]: https://cratedb.com/blog/cratedb-elevates-its-security-standards-and-achieves-iso-27001-certification +[`KNN_MATCH`]: https://cratedb.com/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match [Lucene]: https://github.com/apache/lucene [OCI image]: https://hub.docker.com/_/crate [sign up]: https://console.cratedb.cloud/ diff --git a/docs/docs/integrations/vectorstores/cratedb.ipynb b/docs/docs/integrations/vectorstores/cratedb.ipynb index 06430e6355ae9..ee6bfc78321e4 100644 --- a/docs/docs/integrations/vectorstores/cratedb.ipynb +++ b/docs/docs/integrations/vectorstores/cratedb.ipynb @@ -27,11 +27,11 @@ "\n", "[CrateDB]: https://github.com/crate/crate\n", "[Elasticsearch]: https://github.com/elastic/elasticsearch\n", - "[`FLOAT_VECTOR`]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector\n", - "[`KNN_MATCH`]: https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match\n", + "[`FLOAT_VECTOR`]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector\n", + "[`KNN_MATCH`]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match\n", "[LangChain with CrateDB]: /docs/extras/integrations/providers/cratedb.html\n", "[Lucene]: https://github.com/apache/lucene\n", - "[Python client driver for CrateDB]: https://crate.io/docs/python/" + "[Python client driver for CrateDB]: https://cratedb.com/docs/python/" ] }, { @@ -55,7 +55,7 @@ "outputs": [], "source": [ "# Install required packages: LangChain, OpenAI SDK, and the CrateDB Python driver.\n", - "!pip install 'langchain[cratedb,openai]'" + "!pip install langchain langchain-openai sqlalchemy-cratedb" ] }, { @@ -100,7 +100,7 @@ "running on [CrateDB Cloud].\n", "\n", "[CrateDB Cloud]: https://console.cratedb.cloud/\n", - "[CrateDB using Docker]: https://crate.io/docs/crate/tutorials/en/latest/basic/index.html#docker" + "[CrateDB using Docker]: https://cratedb.com/docs/guide/install/container/" ], "metadata": { "collapsed": false diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 
5bc78eaddb021..6093ec7ad584d 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -14,8 +14,8 @@ chardet>=5.1.0,<6 cloudpathlib>=0.18,<0.19 cloudpickle>=2.0.0 cohere>=4,<6 -crate>=0.34.0,<1 -cratedb-toolkit==0.0.12 +crate==1.0.0dev0 +cratedb-toolkit>=0.0.13,<0.1 databricks-vectorsearch>=0.21,<0.22 datasets>=2.15.0,<3 dgml-utils>=0.3.0,<0.4 @@ -78,6 +78,7 @@ requests-toolbelt>=1.0.0,<2 rspace_client>=2.5.0,<3 scikit-learn>=1.2.2,<2 simsimd>=5.0.0,<6 +sqlalchemy-cratedb>=0.37,<1 sqlite-vss>=0.1.2,<0.2 sqlite-vec>=0.1.0,<0.2 sseclient-py>=1.8.0,<2 diff --git a/libs/community/langchain_community/chat_message_histories/cratedb.py b/libs/community/langchain_community/chat_message_histories/cratedb.py index 45e287ec1f344..98c257e31b348 100644 --- a/libs/community/langchain_community/chat_message_histories/cratedb.py +++ b/libs/community/langchain_community/chat_message_histories/cratedb.py @@ -3,7 +3,6 @@ import sqlalchemy as sa from cratedb_toolkit.sqlalchemy import ( - patch_inspector, polyfill_refresh_after_dml, refresh_table, ) @@ -80,9 +79,6 @@ def __init__( session_id_field_name: str = "session_id", custom_message_converter: t.Optional[BaseMessageConverter] = None, ): - # FIXME: Refactor elsewhere. - patch_inspector() - super().__init__( session_id, connection_string, diff --git a/libs/community/langchain_community/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py index f4abac3ab7a2d..e6dbbacf04943 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/base.py +++ b/libs/community/langchain_community/vectorstores/cratedb/base.py @@ -14,7 +14,6 @@ ) import sqlalchemy -from cratedb_toolkit.sqlalchemy.patch import patch_inspector from cratedb_toolkit.sqlalchemy.polyfill import ( refresh_table, ) @@ -90,9 +89,6 @@ def __post_init__( Initialize the store. """ - # FIXME: Could be a bug in CrateDB SQLAlchemy dialect. - patch_inspector() - self._engine = self._bind self.Session = sessionmaker(self._engine) diff --git a/libs/community/langchain_community/vectorstores/cratedb/extended.py b/libs/community/langchain_community/vectorstores/cratedb/extended.py index 324eca0861d95..bec8437286859 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/extended.py +++ b/libs/community/langchain_community/vectorstores/cratedb/extended.py @@ -24,7 +24,7 @@ class CrateDBVectorSearchMultiCollection(CrateDBVectorSearch): Provide functionality for searching multiple collections. It can not be used for indexing documents. - To use it, you should have the ``crate[sqlalchemy]`` Python package installed. + To use it, you should have the ``sqlalchemy-cratedb`` Python package installed. 
Synopsis:: diff --git a/libs/community/langchain_community/vectorstores/cratedb/model.py b/libs/community/langchain_community/vectorstores/cratedb/model.py index f9dae6566d7c0..e9846cc452875 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/model.py +++ b/libs/community/langchain_community/vectorstores/cratedb/model.py @@ -2,7 +2,10 @@ from typing import Any, List, Optional, Tuple import sqlalchemy -from crate.client.sqlalchemy.types import ObjectType +try: + from sqlalchemy_cratedb import ObjectType +except ImportError: + from crate.client.sqlalchemy.types import ObjectType from sqlalchemy.orm import Session, declarative_base, relationship from langchain_community.vectorstores.cratedb.sqlalchemy_type import FloatVector diff --git a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py index e784c3013a3d9..624284a7065c6 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py +++ b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py @@ -42,8 +42,8 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: class FloatVector(UserDefinedType): """ - https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector - https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match + https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector + https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match """ cache_ok = True diff --git a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py index 81e939b4ab2da..2a4cdd030f851 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py +++ b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py @@ -48,7 +48,7 @@ psycopg2_installed = False try: - import crate.client.sqlalchemy # noqa: F401 + import sqlalchemy_cratedb # noqa: F401 cratedb_installed = True except ImportError: From 3bc63a83155c5945cb26f189443e677a2496204c Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 18 Jun 2024 23:49:13 +0200 Subject: [PATCH 22/28] CrateDB: Stop using CrateDB Toolkit --- libs/community/extended_testing_deps.txt | 3 +- .../chat_message_histories/cratedb.py | 9 ++---- .../vectorstores/cratedb/base.py | 9 ++---- .../vectorstores/cratedb/sqlalchemy_patch.py | 29 ------------------- 4 files changed, 7 insertions(+), 43 deletions(-) delete mode 100644 libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 6093ec7ad584d..2553cd68f67c1 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -15,7 +15,6 @@ cloudpathlib>=0.18,<0.19 cloudpickle>=2.0.0 cohere>=4,<6 crate==1.0.0dev0 -cratedb-toolkit>=0.0.13,<0.1 databricks-vectorsearch>=0.21,<0.22 datasets>=2.15.0,<3 dgml-utils>=0.3.0,<0.4 @@ -78,7 +77,7 @@ requests-toolbelt>=1.0.0,<2 rspace_client>=2.5.0,<3 scikit-learn>=1.2.2,<2 simsimd>=5.0.0,<6 -sqlalchemy-cratedb>=0.37,<1 +sqlalchemy-cratedb>=0.38.0,<1 sqlite-vss>=0.1.2,<0.2 sqlite-vec>=0.1.0,<0.2 sseclient-py>=1.8.0,<2 diff --git a/libs/community/langchain_community/chat_message_histories/cratedb.py 
b/libs/community/langchain_community/chat_message_histories/cratedb.py index 98c257e31b348..2acd0a34e12e3 100644 --- a/libs/community/langchain_community/chat_message_histories/cratedb.py +++ b/libs/community/langchain_community/chat_message_histories/cratedb.py @@ -2,10 +2,7 @@ import typing as t import sqlalchemy as sa -from cratedb_toolkit.sqlalchemy import ( - polyfill_refresh_after_dml, - refresh_table, -) +from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict from langchain_community.chat_message_histories.sql import ( @@ -87,8 +84,8 @@ def __init__( custom_message_converter=custom_message_converter, ) - # TODO: Check how this can be improved. - polyfill_refresh_after_dml(self.Session) + # Patch dialect to invoke `REFRESH TABLE` after each DML operation. + refresh_after_dml(self.Session) def _messages_query(self) -> sa.Select: """ diff --git a/libs/community/langchain_community/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py index e6dbbacf04943..9fc5a5119a963 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/base.py +++ b/libs/community/langchain_community/vectorstores/cratedb/base.py @@ -14,9 +14,6 @@ ) import sqlalchemy -from cratedb_toolkit.sqlalchemy.polyfill import ( - refresh_table, -) from langchain.docstore.document import Document from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env @@ -24,7 +21,7 @@ from sqlalchemy.orm import sessionmaker from langchain_community.vectorstores.cratedb.model import ModelFactory -from langchain_community.vectorstores.cratedb.sqlalchemy_patch import polyfill_refresh_after_dml_engine +from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table class DistanceStrategy(str, enum.Enum): @@ -92,8 +89,8 @@ def __post_init__( self._engine = self._bind self.Session = sessionmaker(self._engine) - # TODO: Pull in from a future `sqlalchemy-cratedb`. - polyfill_refresh_after_dml_engine(self._engine) + # Patch dialect to invoke `REFRESH TABLE` after each DML operation. + refresh_after_dml(self._engine) # Need to defer initialization, because dimension size # can only be figured out at runtime. diff --git a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py deleted file mode 100644 index fd50e55119be5..0000000000000 --- a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_patch.py +++ /dev/null @@ -1,29 +0,0 @@ -import sqlalchemy as sa - - -def polyfill_refresh_after_dml_engine(engine: sa.engine.Engine): - def receive_after_execute( - conn: sa.engine.Connection, - clauseelement, - multiparams, - params, - execution_options, - result, - ): - """ - Run a `REFRESH TABLE ...` command after each DML operation (INSERT, UPDATE, - DELETE). This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` - adapters. - - TODO: Pull in from a future `sqlalchemy-cratedb`. - """ - if isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)): - if not isinstance(clauseelement.table, sa.sql.Join): - full_table_name = f'"{clauseelement.table.name}"' - if clauseelement.table.schema is not None: - full_table_name = ( - f'"{clauseelement.table.schema}".' 
+ full_table_name - ) - conn.execute(sa.text(f"REFRESH TABLE {full_table_name};")) - - sa.event.listen(engine, "after_execute", receive_after_execute) From c561a95ae4d469dac969ba97c6e2cf8862390182 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 25 Jun 2024 19:14:44 +0200 Subject: [PATCH 23/28] CrateDB: Stop using local `FloatVector` implementation --- .../vectorstores/cratedb/model.py | 7 +- .../vectorstores/cratedb/sqlalchemy_type.py | 84 ------------------- 2 files changed, 1 insertion(+), 90 deletions(-) delete mode 100644 libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py diff --git a/libs/community/langchain_community/vectorstores/cratedb/model.py b/libs/community/langchain_community/vectorstores/cratedb/model.py index e9846cc452875..7499e9080e93a 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/model.py +++ b/libs/community/langchain_community/vectorstores/cratedb/model.py @@ -2,14 +2,9 @@ from typing import Any, List, Optional, Tuple import sqlalchemy -try: - from sqlalchemy_cratedb import ObjectType -except ImportError: - from crate.client.sqlalchemy.types import ObjectType +from sqlalchemy_cratedb import ObjectType, FloatVector from sqlalchemy.orm import Session, declarative_base, relationship -from langchain_community.vectorstores.cratedb.sqlalchemy_type import FloatVector - def generate_uuid() -> str: return str(uuid.uuid4()) diff --git a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py b/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py deleted file mode 100644 index 624284a7065c6..0000000000000 --- a/libs/community/langchain_community/vectorstores/cratedb/sqlalchemy_type.py +++ /dev/null @@ -1,84 +0,0 @@ -# TODO: Refactor to CrateDB SQLAlchemy dialect. 
-import typing as t - -import numpy as np -import numpy.typing as npt -import sqlalchemy as sa -from sqlalchemy.types import UserDefinedType - -__all__ = ["FloatVector"] - - -def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: - # from `pgvector.utils` - # could be ndarray if already cast by lower-level driver - if value is None or isinstance(value, np.ndarray): - return value - - return np.array(value, dtype=np.float32) - - -def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: - # from `pgvector.utils` - if value is None: - return value - - if isinstance(value, np.ndarray): - if value.ndim != 1: - raise ValueError("expected ndim to be 1") - - if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype( - value.dtype, np.floating - ): - raise ValueError("dtype must be numeric") - - value = value.tolist() - - if dim is not None and len(value) != dim: - raise ValueError("expected %d dimensions, not %d" % (dim, len(value))) - - return value - - -class FloatVector(UserDefinedType): - """ - https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector - https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match - """ - - cache_ok = True - - def __init__(self, dim: t.Optional[int] = None) -> None: - super(UserDefinedType, self).__init__() - self.dim = dim - - def get_col_spec(self, **kw: t.Any) -> str: - if self.dim is None: - return "FLOAT_VECTOR" - return "FLOAT_VECTOR(%d)" % self.dim - - def bind_processor(self, dialect: sa.Dialect) -> t.Callable: - def process(value: t.Iterable) -> t.Optional[t.List]: - return to_db(value, self.dim) - - return process - - def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: - def process(value: t.Any) -> t.Optional[npt.ArrayLike]: - return from_db(value) - - return process - - """ - CrateDB currently only supports similarity function `VectorSimilarityFunction.EUCLIDEAN`. - -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55 - - On the other hand, pgvector use a comparator to apply different similarity functions as operators, - see `pgvector.sqlalchemy.Vector.comparator_factory`. - - <->: l2/euclidean_distance - <#>: max_inner_product - <=>: cosine_distance - - TODO: Discuss. - """ # noqa: E501 From 8b278a8a7551f8baea085a331587525ca499bc6f Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 24 Oct 2024 20:29:06 +0200 Subject: [PATCH 24/28] CrateDB: Format code. Satisfy linter and type checker. 
From 8b278a8a7551f8baea085a331587525ca499bc6f Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Thu, 24 Oct 2024 20:29:06 +0200
Subject: [PATCH 24/28] CrateDB: Format code. Satisfy linter and type checker.

ruff + mypy
---
 .../langchain_community/chat_message_histories/cratedb.py | 2 +-
 .../langchain_community/vectorstores/cratedb/base.py      | 7 ++++---
 .../langchain_community/vectorstores/cratedb/extended.py  | 1 -
 .../langchain_community/vectorstores/cratedb/model.py     | 8 ++++----
 .../document_loaders/test_sql_database.py                 | 4 +---
 .../tests/integration_tests/vectorstores/test_cratedb.py  | 1 +
 libs/community/tests/unit_tests/test_sql_database.py      | 3 ++-
 .../langchain/memory/chat_message_histories/cratedb.py    | 2 +-
 8 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/libs/community/langchain_community/chat_message_histories/cratedb.py b/libs/community/langchain_community/chat_message_histories/cratedb.py
index 2acd0a34e12e3..20870b39fdb84 100644
--- a/libs/community/langchain_community/chat_message_histories/cratedb.py
+++ b/libs/community/langchain_community/chat_message_histories/cratedb.py
@@ -2,8 +2,8 @@
 import typing as t
 
 import sqlalchemy as sa
-from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table
 from langchain.schema import BaseMessage, _message_to_dict, messages_from_dict
+from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table
 
 from langchain_community.chat_message_histories.sql import (
     BaseMessageConverter,
diff --git a/libs/community/langchain_community/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py
index 9fc5a5119a963..7173f92ff68be 100644
--- a/libs/community/langchain_community/vectorstores/cratedb/base.py
+++ b/libs/community/langchain_community/vectorstores/cratedb/base.py
@@ -19,9 +19,9 @@
 from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.pgvector import PGVector
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table
 
 from langchain_community.vectorstores.cratedb.model import ModelFactory
-from sqlalchemy_cratedb.support import refresh_after_dml, refresh_table
 
 
 class DistanceStrategy(str, enum.Enum):
@@ -87,7 +87,7 @@ def __post_init__(
         """
         self._engine = self._bind
-        self.Session = sessionmaker(self._engine)
+        self.Session = sessionmaker(bind=self._engine)  # type: ignore[call-overload]
 
         # Patch dialect to invoke `REFRESH TABLE` after each DML operation.
         refresh_after_dml(self._engine)
@@ -199,6 +199,7 @@ def drop_tables(self) -> None:
     def delete(
         self,
         ids: Optional[List[str]] = None,
+        collection_only: bool = False,
         **kwargs: Any,
     ) -> None:
         """
@@ -214,7 +215,7 @@
         patch, listening to `after_delete` events seems not be able to
         catch it.
         """
-        super().delete(ids=ids, **kwargs)
+        super().delete(ids=ids, collection_only=collection_only, **kwargs)
 
         # CrateDB: Synchronize data because `on_flush` does not catch it.
         with self.Session() as session:
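Because CrateDB makes writes visible to search only after a `REFRESH TABLE`, the `refresh_after_dml` helper used above patches the engine so the refresh happens automatically after each DML operation. A minimal sketch of both the automatic and the manual variant (the table name is illustrative):

```python
import sqlalchemy as sa
from sqlalchemy_cratedb.support import refresh_after_dml

engine = sa.create_engine("crate://crate@localhost:4200")

# Automatic: emit `REFRESH TABLE` after each INSERT/UPDATE/DELETE.
refresh_after_dml(engine)

# Manual: refresh a single table explicitly when needed.
with engine.connect() as connection:
    connection.execute(sa.text("REFRESH TABLE embedding_demo"))
```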
diff --git a/libs/community/langchain_community/vectorstores/cratedb/extended.py b/libs/community/langchain_community/vectorstores/cratedb/extended.py
index bec8437286859..4d5ae8e88f3fd 100644
--- a/libs/community/langchain_community/vectorstores/cratedb/extended.py
+++ b/libs/community/langchain_community/vectorstores/cratedb/extended.py
@@ -8,7 +8,6 @@
 )
 
 import sqlalchemy
-
 from langchain.schema.embeddings import Embeddings
 
 from langchain_community.vectorstores.cratedb.base import (
diff --git a/libs/community/langchain_community/vectorstores/cratedb/model.py b/libs/community/langchain_community/vectorstores/cratedb/model.py
index 7499e9080e93a..31153eda30ce4 100644
--- a/libs/community/langchain_community/vectorstores/cratedb/model.py
+++ b/libs/community/langchain_community/vectorstores/cratedb/model.py
@@ -2,8 +2,8 @@
 from typing import Any, List, Optional, Tuple
 
 import sqlalchemy
-from sqlalchemy_cratedb import ObjectType, FloatVector
 from sqlalchemy.orm import Session, declarative_base, relationship
+from sqlalchemy_cratedb import FloatVector, ObjectType
 
 
 def generate_uuid() -> str:
@@ -48,7 +48,7 @@ class CollectionStore(BaseModel):
     )
 
     @classmethod
-    def get_by_name(cls, session: Session, name: str) -> "CollectionStore":
+    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
         return session.query(cls).filter(cls.name == name).first()  # type: ignore[attr-defined]
 
     @classmethod
@@ -95,8 +95,8 @@ class EmbeddingStore(BaseModel):
         )
         collection = relationship("CollectionStore", back_populates="embeddings")
-        embedding = sqlalchemy.Column(FloatVector(self.dimensions))
-        document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+        embedding: sqlalchemy.Column = sqlalchemy.Column(FloatVector(self.dimensions))
+        document: sqlalchemy.Column = sqlalchemy.Column(sqlalchemy.String, nullable=True)
         cmetadata: sqlalchemy.Column = sqlalchemy.Column(ObjectType, nullable=True)
 
         # custom_id : any user defined id
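With the `Optional` return type above, callers of `CollectionStore.get_by_name` must handle a missing collection explicitly, which is what the type checker now enforces. A short sketch, assuming the `model_factory` helper from the earlier revision of `model.py` (later patches refactor it into `ModelFactory`):

```python
import sqlalchemy as sa
from sqlalchemy.orm import Session

from langchain_community.vectorstores.cratedb.model import model_factory

engine = sa.create_engine("crate://crate@localhost:4200")
CollectionStore, EmbeddingStore = model_factory(dimensions=1024)

with Session(engine) as session:
    collection = CollectionStore.get_by_name(session, "testdrive")
    if collection is None:
        # mypy now forces callers to handle the "collection missing" case.
        raise ValueError("Collection 'testdrive' does not exist")
    print(collection.name)
```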
diff --git a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
index 2a4cdd030f851..a687ab7c0491e 100644
--- a/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
+++ b/libs/community/tests/integration_tests/document_loaders/test_sql_database.py
@@ -122,9 +122,7 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
         # use `docker compose up postgres` to start the instance
         # it will have the appropriate credentials set up including
         # being exposed on the appropriate port.
-        urls.append(
-            "crate://crate@localhost/?schema=testdrive"
-        )
+        urls.append("crate://crate@localhost/?schema=testdrive")
         ids.append("cratedb")
 
     metafunc.parametrize("db_uri", urls, ids=ids)
diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
index 3f92717c2a138..e258a42177f70 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py
@@ -4,6 +4,7 @@
 cd tests/integration_tests/vectorstores/docker-compose
 docker-compose -f cratedb.yml up
 """
+
 import os
 import re
 from typing import Dict, Generator, List
diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py
index f7171041a2425..0a17f15aab3a9 100644
--- a/libs/community/tests/unit_tests/test_sql_database.py
+++ b/libs/community/tests/unit_tests/test_sql_database.py
@@ -4,7 +4,6 @@
 import pytest
 import sqlalchemy as sa
 import sqlalchemy.orm
-from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
 from packaging import version
 from sqlalchemy import (
     Column,
@@ -18,6 +17,8 @@
 )
 from sqlalchemy.engine import Engine, Result
 
+from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
+
 is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
 
 metadata_obj = MetaData()
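The compacted URL in the test matrix above doubles as a reference for connecting to CrateDB with SQLAlchemy: the `schema` query parameter selects the schema, `testdrive` here. A quick connectivity check (a sketch, assuming the CrateDB instance from `cratedb.yml` is running):

```python
import sqlalchemy as sa

engine = sa.create_engine("crate://crate@localhost/?schema=testdrive")
with engine.connect() as connection:
    # Should print `1` against a reachable CrateDB server.
    print(connection.execute(sa.text("SELECT 1")).scalar_one())
```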
diff --git a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py
index 376cacd8985ad..3514521c108ed 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/cratedb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/cratedb.py
@@ -13,7 +13,7 @@
 # handling optional imports.
 DEPRECATED_LOOKUP = {
     "CrateDBChatMessageHistory": "langchain_community.chat_message_histories",
-    "CrateDBMessageConverter": "langchain_community.chat_message_histories"
+    "CrateDBMessageConverter": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)

From 41f646270bcd2627e2413cf82bb47cb75bd661f5 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Mon, 28 Oct 2024 21:52:38 +0100
Subject: [PATCH 25/28] CrateDB: Remove adjustment to ConsistentFakeEmbeddings
 in langchain-core

---
 .../tests/integration_tests/cache/fake_embeddings.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
index 1241e47e71e83..63394e78cbe84 100644
--- a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
+++ b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
@@ -53,11 +53,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
     def embed_query(self, text: str) -> List[float]:
         """Return consistent embeddings for the text, if seen before,
         or a constant one if the text is unknown."""
-        if text not in self.known_texts:
-            return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
-        return [float(1.0)] * (self.dimensionality - 1) + [
-            float(self.known_texts.index(text))
-        ]
+        return self.embed_documents([text])[0]
 
 
 class AngularTwoDimensionalEmbeddings(Embeddings):

From 19a09ab821d4ee2169f9db4a9444e406939d20fa Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Mon, 28 Oct 2024 23:26:49 +0100
Subject: [PATCH 26/28] CrateDB: Refactor leftovers from langchain-core to
 langchain-community

---
 .../tests/integration_tests/memory/test_cratedb.py | 0
 libs/langchain/tests/data.py                       | 4 ----
 2 files changed, 4 deletions(-)
 rename libs/{langchain => community}/tests/integration_tests/memory/test_cratedb.py (100%)

diff --git a/libs/langchain/tests/integration_tests/memory/test_cratedb.py b/libs/community/tests/integration_tests/memory/test_cratedb.py
similarity index 100%
rename from libs/langchain/tests/integration_tests/memory/test_cratedb.py
rename to libs/community/tests/integration_tests/memory/test_cratedb.py
diff --git a/libs/langchain/tests/data.py b/libs/langchain/tests/data.py
index c1206fb2ccbbe..b4f53baf356b4 100644
--- a/libs/langchain/tests/data.py
+++ b/libs/langchain/tests/data.py
@@ -10,7 +10,3 @@
 HELLO_PDF = _EXAMPLES_DIR / "hello.pdf"
 LAYOUT_PARSER_PAPER_PDF = _EXAMPLES_DIR / "layout-parser-paper.pdf"
 DUPLICATE_CHARS = _EXAMPLES_DIR / "duplicate-chars.pdf"
-
-# Paths to data files
-MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv"
-MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql"
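After the simplification in PATCH 25, the whole fixture reduces to roughly the following shape (a standalone sketch; the real class lives in `fake_embeddings.py` and derives from `Embeddings`):

```python
from typing import List


class ConsistentFakeEmbeddings:
    """Deterministic fake embeddings: the same text always maps to the same vector."""

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            # Last component encodes the text's registration index.
            vectors.append(
                [1.0] * (self.dimensionality - 1)
                + [float(self.known_texts.index(text))]
            )
        return vectors

    def embed_query(self, text: str) -> List[float]:
        # After the patch, the query path simply reuses the document path,
        # so queries and documents can no longer drift apart.
        return self.embed_documents([text])[0]
```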
From 91da7703d9982c2129ed0ec4c40502e3227ab581 Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 29 Oct 2024 00:19:27 +0100
Subject: [PATCH 27/28] CrateDB: Remove documentation about SQLDatabaseLoader

Those pages have been submitted to LangChain already.
---
 docs/docs/how_to/sql_database.mdx             | 165 --------
 .../document_loaders/sql_database.ipynb       | 360 ------------------
 2 files changed, 525 deletions(-)
 delete mode 100644 docs/docs/how_to/sql_database.mdx
 delete mode 100644 docs/docs/integrations/document_loaders/sql_database.ipynb

diff --git a/docs/docs/how_to/sql_database.mdx b/docs/docs/how_to/sql_database.mdx
deleted file mode 100644
index 1ecdeda75b307..0000000000000
--- a/docs/docs/how_to/sql_database.mdx
+++ /dev/null
@@ -1,165 +0,0 @@
-# SQLDatabaseLoader
-
-
-## About
-
-The `SQLDatabaseLoader` loads records from any database supported by
-[SQLAlchemy], see [SQLAlchemy dialects] for the whole list of supported
-SQL databases and dialects.
-
-You can either use plain SQL for querying, or use an SQLAlchemy `Select`
-statement object, if you are using SQLAlchemy-Core or -ORM.
-
-You can select which columns to place into the document, which columns
-to place into its metadata, which columns to use as a `source` attribute
-in metadata, and whether to include the result row number and/or the SQL
-query expression into the metadata.
-
-
-## Example
-
-This example uses PostgreSQL, and the `psycopg2` driver.
-
-
-### Prerequisites
-
-```shell
-psql postgresql://postgres@localhost/ --command "CREATE DATABASE testdrive;"
-psql postgresql://postgres@localhost/testdrive < ./libs/langchain/tests/integration_tests/examples/mlb_teams_2012.sql
-```
-
-
-### Basic loading
-
-```python
-from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
-from pprint import pprint
-
-
-loader = SQLDatabaseLoader(
-    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
-    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
-)
-docs = loader.load()
-```
-
-```python
-pprint(docs)
-```
-
-```
-[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={}),
- Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={}),
- Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={})]
-```
-
-
-## Enriching metadata
-
-Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to
-optionally populate the `metadata` dictionary with corresponding information.
-
-Having the `query` within metadata is useful when using documents loaded from
-database tables for chains that answer questions using their origin queries.
-
-```python
-loader = SQLDatabaseLoader(
-    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
-    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
-    include_rownum_into_metadata=True,
-    include_query_into_metadata=True,
-)
-docs = loader.load()
-```
-
-```python
-pprint(docs)
-```
-
-```
-[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
- Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'}),
- Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'row': 2, 'query': 'SELECT * FROM mlb_teams_2012 LIMIT 3;'})]
-```
-
-
-## Customizing metadata
-
-Use the `page_content_columns`, and `metadata_columns` options to optionally populate
-the `metadata` dictionary with corresponding information. When `page_content_columns`
-is empty, all columns will be used.
-
-```python
-import functools
-
-row_to_content = functools.partial(
-    SQLDatabaseLoader.page_content_default_mapper, column_names=["Payroll (millions)", "Wins"]
-)
-row_to_metadata = functools.partial(
-    SQLDatabaseLoader.metadata_default_mapper, column_names=["Team"]
-)
-
-loader = SQLDatabaseLoader(
-    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
-    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
-    page_content_mapper=row_to_content,
-    metadata_mapper=row_to_metadata,
-)
-docs = loader.load()
-```
-
-```python
-pprint(docs)
-```
-
-```
-[Document(page_content='Payroll (millions): 81.34\nWins: 98', metadata={'Team': 'Nationals'}),
- Document(page_content='Payroll (millions): 82.2\nWins: 97', metadata={'Team': 'Reds'}),
- Document(page_content='Payroll (millions): 197.96\nWins: 95', metadata={'Team': 'Yankees'})]
-```
-
-
-## Specify column(s) to identify the document source
-
-Use the `source_columns` option to specify the columns to use as a "source" for the
-document created from each row. This is useful for identifying documents through
-their metadata. Typically, you may use the primary key column(s) for that purpose.
-
-```python
-loader = SQLDatabaseLoader(
-    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
-    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
-    source_columns=["Team"],
-)
-docs = loader.load()
-```
-
-```python
-pprint(docs)
-```
-
-```
-[Document(page_content='Team: Nationals\nPayroll (millions): 81.34\nWins: 98', metadata={'source': 'Nationals'}),
- Document(page_content='Team: Reds\nPayroll (millions): 82.2\nWins: 97', metadata={'source': 'Reds'}),
- Document(page_content='Team: Yankees\nPayroll (millions): 197.96\nWins: 95', metadata={'source': 'Yankees'})]
-```
-
-
-[SQLAlchemy]: https://www.sqlalchemy.org/
-[SQLAlchemy dialects]: https://docs.sqlalchemy.org/en/20/dialects/
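Both removed pages document the same API; for readers of this patch, the gist of `SQLDatabaseLoader` fits in a few lines, using only parameters shown in the pages being deleted here:

```python
from pprint import pprint

from langchain_community.document_loaders import SQLDatabaseLoader

loader = SQLDatabaseLoader(
    query="SELECT * FROM mlb_teams_2012 LIMIT 3;",
    url="postgresql+psycopg2://postgres@localhost:5432/testdrive",
    source_columns=["Team"],            # becomes metadata["source"]
    include_rownum_into_metadata=True,  # adds metadata["row"]
)
pprint(loader.load())
```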
"[SQLite]: https://sqlite.org/\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## Prerequisites" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "#!pip install langchain langchain-community sqlalchemy termsql" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "Populate SQLite database with example input data." - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Nationals|81.34|98\r\n", - "Reds|82.2|97\r\n", - "Yankees|197.96|95\r\n", - "Giants|117.62|94\r\n", - "Braves|83.31|94\r\n", - "Athletics|55.37|94\r\n", - "Rangers|120.51|93\r\n", - "Orioles|81.43|93\r\n", - "Rays|64.17|90\r\n", - "Angels|154.49|89\r\n", - "Tigers|132.3|88\r\n", - "Cardinals|110.3|88\r\n", - "Dodgers|95.14|86\r\n", - "White Sox|96.92|85\r\n", - "Brewers|97.65|83\r\n", - "Phillies|174.54|81\r\n", - "Diamondbacks|74.28|81\r\n", - "Pirates|63.43|79\r\n", - "Padres|55.24|76\r\n", - "Mariners|81.97|75\r\n", - "Mets|93.35|74\r\n", - "Blue Jays|75.48|73\r\n", - "Royals|60.91|72\r\n", - "Marlins|118.07|69\r\n", - "Red Sox|173.18|69\r\n", - "Indians|78.43|68\r\n", - "Twins|94.08|66\r\n", - "Rockies|78.06|64\r\n", - "Cubs|88.19|61\r\n", - "Astros|60.65|55\r\n", - "||\r\n" - ] - } - ], - "source": [ - "!termsql --infile=./example_data/mlb_teams_2012.csv --head --csv --outfile=example.sqlite --table=payroll" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## Basic usage" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from pprint import pprint\n", - "\n", - "from langchain_community.document_loaders import SQLDatabaseLoader\n", - "\n", - "loader = SQLDatabaseLoader(\n", - " \"SELECT * FROM payroll LIMIT 2\",\n", - " url=\"sqlite:///example.sqlite\",\n", - ")\n", - "documents = loader.load()" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98'),\n", - " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97')]\n" - ] - } - ], - "source": [ - "pprint(documents)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Specify which columns are content vs. metadata\n", - "\n", - "Use the `page_content_mapper` keyword argument to optionally customize how to derive\n", - "a page content string from an input database record / row. By default, all columns\n", - "will be used.\n", - "\n", - "Use the `metadata_mapper` keyword argument to optionally customize how to derive\n", - "a document metadata dictionary from an input database record / row. By default,\n", - "document metadata will be empty." 
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import functools\n",
-    "\n",
-    "# Configure built-in page content mapper to include only specified columns.\n",
-    "row_to_content = functools.partial(\n",
-    "    SQLDatabaseLoader.page_content_default_mapper, column_names=[\"Team\", \"Wins\"]\n",
-    ")\n",
-    "\n",
-    "# Configure built-in metadata dictionary mapper to include specified columns.\n",
-    "row_to_metadata = functools.partial(\n",
-    "    SQLDatabaseLoader.metadata_default_mapper, column_names=[\"Payroll (millions)\"]\n",
-    ")\n",
-    "\n",
-    "loader = SQLDatabaseLoader(\n",
-    "    \"SELECT * FROM payroll LIMIT 2\",\n",
-    "    url=\"sqlite:///example.sqlite\",\n",
-    "    page_content_mapper=row_to_content,\n",
-    "    metadata_mapper=row_to_metadata,\n",
-    ")\n",
-    "documents = loader.load()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 46,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[Document(page_content='Team: Nationals\\nWins: 98', metadata={'Payroll (millions)': 81.34}),\n",
-      " Document(page_content='Team: Reds\\nWins: 97', metadata={'Payroll (millions)': 82.2})]\n"
-     ]
-    }
-   ],
-   "source": [
-    "pprint(documents)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "source": [
-    "Those examples demonstrate how to use custom functions to define arbitrary\n",
-    "mapping rules by using Python code.\n",
-    "```python\n",
-    "def page_content_mapper(row: sa.RowMapping, column_names: Optional[List[str]] = None) -> str:\n",
-    "    return f\"Team: {row['Team']}\"\n",
-    "```\n",
-    "```python\n",
-    "def metadata_default_mapper(row: sa.RowMapping, column_names: Optional[List[str]] = None) -> Dict[str, Any]:\n",
-    "    return {\"team\": row['Team']}\n",
-    "```"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Specify column(s) to identify the document source\n",
-    "\n",
-    "Use the `source_columns` option to specify the columns to use as a \"source\" for the\n",
-    "document created from each row. This is useful for identifying documents through\n",
-    "their metadata. Typically, you may use the primary key column(s) for that purpose."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 47,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "loader = SQLDatabaseLoader(\n",
-    "    \"SELECT * FROM payroll LIMIT 2\",\n",
-    "    url=\"sqlite:///example.sqlite\",\n",
-    "    source_columns=[\"Team\"],\n",
-    ")\n",
-    "documents = loader.load()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 48,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98', metadata={'source': 'Nationals'}),\n",
-      " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97', metadata={'source': 'Reds'})]\n"
-     ]
-    }
-   ],
-   "source": [
-    "pprint(documents)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "source": [
-    "## Enrich metadata with row number and/or original SQL query\n",
-    "\n",
-    "Use the `include_rownum_into_metadata` and `include_query_into_metadata` options to\n",
-    "optionally populate the `metadata` dictionary with corresponding information.\n",
-    "\n",
-    "Having the `query` within metadata is useful when using documents loaded from\n",
-    "database tables for chains that answer questions using their origin queries."
-   ],
-   "metadata": {
-    "collapsed": false
-   }
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 49,
-   "outputs": [],
-   "source": [
-    "loader = SQLDatabaseLoader(\n",
-    "    \"SELECT * FROM payroll LIMIT 2\",\n",
-    "    url=\"sqlite:///example.sqlite\",\n",
-    "    include_rownum_into_metadata=True,\n",
-    "    include_query_into_metadata=True,\n",
-    ")\n",
-    "documents = loader.load()"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 50,
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[Document(page_content='Team: Nationals\\nPayroll (millions): 81.34\\nWins: 98', metadata={'row': 0, 'query': 'SELECT * FROM payroll LIMIT 2'}),\n",
-      " Document(page_content='Team: Reds\\nPayroll (millions): 82.2\\nWins: 97', metadata={'row': 1, 'query': 'SELECT * FROM payroll LIMIT 2'})]\n"
-     ]
-    }
-   ],
-   "source": [
-    "pprint(documents)"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

From 1faedfecf590c5ae34af7b1c2490d5329710491d Mon Sep 17 00:00:00 2001
From: Andreas Motl
Date: Tue, 29 Oct 2024 00:19:37 +0100
Subject: [PATCH 28/28] CrateDB: Remove leftovers in langchain-core

---
 libs/community/tests/unit_tests/test_sql_database.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py
index 0a17f15aab3a9..6acb734a54309 100644
--- a/libs/community/tests/unit_tests/test_sql_database.py
+++ b/libs/community/tests/unit_tests/test_sql_database.py
@@ -3,7 +3,6 @@
 
 import pytest
 import sqlalchemy as sa
-import sqlalchemy.orm
 from packaging import version
 from sqlalchemy import (
     Column,
@@ -57,11 +56,6 @@ def db_lazy_reflection(engine: Engine) -> SQLDatabase:
 
 
 @pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
-def test_configure_mappers() -> None:
-    """Test that configuring table mappers works."""
-    sqlalchemy.orm.configure_mappers()
-
-
 def test_table_info(db: SQLDatabase) -> None:
     """Test that table info is constructed properly."""
     output = db.table_info
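As a closing pointer, the `table_info` property exercised by the remaining test renders the reflected schema (CREATE TABLE statements, optionally with sample rows) for downstream prompts. A minimal sketch of standalone usage, against a throwaway SQLite database:

```python
import sqlalchemy as sa

from langchain_community.utilities.sql_database import SQLDatabase

engine = sa.create_engine("sqlite:///:memory:")
with engine.begin() as connection:
    connection.execute(
        sa.text("CREATE TABLE user (user_id INTEGER PRIMARY KEY, user_name TEXT)")
    )

db = SQLDatabase(engine)
# Prints the dialect-rendered schema, e.g. the CREATE TABLE statement above.
print(db.table_info)
```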