diff --git a/src/vdf_io/import_vdf/astradb_import.py b/src/vdf_io/import_vdf/astradb_import.py index aa12b15..5d20dc0 100644 --- a/src/vdf_io/import_vdf/astradb_import.py +++ b/src/vdf_io/import_vdf/astradb_import.py @@ -7,7 +7,6 @@ from astrapy.db import AstraDB from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider -from qdrant_client.http.models import Distance from vdf_io.constants import INT_MAX from vdf_io.names import DBNames diff --git a/src/vdf_io/import_vdf/lancedb_import.py b/src/vdf_io/import_vdf/lancedb_import.py index d1e033b..d9febbd 100644 --- a/src/vdf_io/import_vdf/lancedb_import.py +++ b/src/vdf_io/import_vdf/lancedb_import.py @@ -5,6 +5,7 @@ import pyarrow.parquet as pq import lancedb +from lancedb import create_index from vdf_io.constants import DEFAULT_BATCH_SIZE, INT_MAX from vdf_io.meta_types import NamespaceMeta @@ -23,6 +24,7 @@ class ImportLanceDB(ImportVDB): DB_NAME_SLUG = DBNames.LANCEDB + ID_COLUMN = "id" @classmethod def import_vdb(cls, args): @@ -103,6 +105,25 @@ def upsert_data(self): table = self.db.open_table(new_index_name) tqdm.write(f"Opened table {new_index_name}") + # Get the ID column from the parquet file schema + parquet_schema = pq.read_schema(parquet_files[0]) + id_column = None + for field in parquet_schema: + if field.name == self.ID_COLUMN: + id_column = field.name + break + + if id_column: + # Create index on the table + create_index(table, id_column) + tqdm.write( + f"Created index on {id_column} for table {new_index_name}" + ) + else: + tqdm.write( + f"Warning: No ID column {self.ID_COLUMN} found in schema. Skipping index creation for table {new_index_name}" + ) + for file in tqdm(parquet_files, desc="Iterating parquet files"): file_path = self.get_file_path(final_data_path, file) df = self.read_parquet_progress( @@ -117,7 +138,9 @@ def upsert_data(self): for col in df.columns: if col not in [field.name for field in table.schema]: col_type = df[col].dtype - tqdm.write(f"Adding column {col} of type {col_type} to {new_index_name}") + tqdm.write( + f"Adding column {col} of type {col_type} to {new_index_name}" + ) table.add_columns( { col: get_default_value(col_type),