From 8204723542dbe29d4f99198ab669bba34f989e56 Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Sat, 7 Jun 2025 19:45:19 +0530 Subject: [PATCH] feat: natural language search --- ...1f_add_recording_embedding_and_summary_.py | 46 ++ openadapt/app/tray.py | 407 ++++++++++++------ openadapt/config.defaults.json | 1 + openadapt/config.py | 3 + openadapt/db/crud.py | 62 +++ openadapt/embed.py | 69 +++ openadapt/models.py | 39 ++ openadapt/scripts/backfill_embeddings.py | 95 ++++ openadapt/similarity_search.py | 131 ++++++ pyproject.toml | 4 +- 10 files changed, 728 insertions(+), 129 deletions(-) create mode 100644 openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py create mode 100644 openadapt/embed.py create mode 100644 openadapt/scripts/backfill_embeddings.py create mode 100644 openadapt/similarity_search.py diff --git a/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py b/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py new file mode 100644 index 000000000..a4d466556 --- /dev/null +++ b/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py @@ -0,0 +1,46 @@ +"""add_recording_embedding_and_summary_tables + +Revision ID: bd9917da991f +Revises: 98505a067995 +Create Date: 2025-06-01 19:10:09.277603 + +""" +from alembic import op +import sqlalchemy as sa +import openadapt + +# revision identifiers, used by Alembic. +revision = 'bd9917da991f' +down_revision = '98505a067995' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('recording_embedding', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('recording_id', sa.Integer(), nullable=True), + sa.Column('embedding', sa.JSON(), nullable=True), + sa.Column('model_name', sa.String(), nullable=True), + sa.Column('timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True), + sa.ForeignKeyConstraint(['recording_id'], ['recording.id'], name=op.f('fk_recording_embedding_recording_id_recording')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_recording_embedding')) + ) + op.create_table('recording_summary', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('recording_id', sa.Integer(), nullable=True), + sa.Column('summary_text', sa.Text(), nullable=True), + sa.Column('summary_level', sa.String(), nullable=True), + sa.Column('timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True), + sa.ForeignKeyConstraint(['recording_id'], ['recording.id'], name=op.f('fk_recording_summary_recording_id_recording')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_recording_summary')) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('recording_summary') + op.drop_table('recording_embedding') + # ### end Alembic commands ### diff --git a/openadapt/app/tray.py b/openadapt/app/tray.py index 805a39b60..ba2f2a92f 100644 --- a/openadapt/app/tray.py +++ b/openadapt/app/tray.py @@ -21,6 +21,7 @@ QComboBox, QDialog, QDialogButtonBox, + QGroupBox, QHBoxLayout, QInputDialog, QLabel, @@ -44,6 +45,7 @@ from openadapt.strategies.base import BaseReplayStrategy from openadapt.utils import WrapStdout, get_posthog_instance from openadapt.visualize import main as visualize +from openadapt.similarity_search import get_enhanced_similarity_search # ensure all strategies are registered import openadapt.strategies # noqa: F401 @@ -144,6 +146,12 @@ def __init__(self) -> None: self.record_action.triggered.connect(self._record) self.menu.addAction(self.record_action) + self.search_action = TrackedQAction("💬 What do you want help with today?") + self.search_action.triggered.connect(self._search_workflows) + self.menu.addAction(self.search_action) + + self.menu.addSeparator() + self.visualize_menu = self.menu.addMenu("Visualize") self.replay_menu = self.menu.addMenu("Replay") self.delete_menu = self.menu.addMenu("Delete") @@ -256,33 +264,117 @@ def stop_recording(self) -> None: """Stop recording.""" Thread(target=stop_record).start() - def _visualize(self, recording: Recording) -> None: - """Visualize a recording. - - Args: - recording (Recording): The recording to visualize. 
- """ - self.show_toast("Starting visualization...") + def _search_workflows(self) -> None: + """Handle natural language workflow search with multiple results.""" + search_query, ok = QInputDialog.getText( + None, + "💬 OpenAdapt Assistant", + "What do you want help with today?\n(e.g., 'calculate taxes', 'organize files', 'send email')", + text="" + ) + + if not ok or not search_query.strip(): + return + + logger.info(f"User wants help with: {search_query}") + self.show_toast("Finding relevant workflows...") + try: - if self.visualize_proc is not None: - self.visualize_proc.kill() - self.visualize_proc = multiprocessing.Process( - target=WrapStdout(visualize), args=(recording,) - ) - self.visualize_proc.start() + with crud.get_new_session(read_only=True) as session: + results = get_enhanced_similarity_search(session, search_query, top_n=5) + + if not results: + self.show_toast( + "No matching workflows found. Try recording a workflow first!", + duration=5000 + ) + return + + filtered_results = [(r, c) for r, c in results if c >= 0.1] + + if not filtered_results: + self.show_toast( + f"No good matches found for '{search_query}'. Try different keywords.", + duration=5000 + ) + return + + self._show_workflow_selection(filtered_results, search_query) except Exception as e: - logger.error(e) - self.show_toast("Visualization failed.") + logger.error(f"Error searching workflows: {e}") + self.show_toast("Search failed. 
Please try again.", duration=5000) + + def _show_workflow_selection(self, results: list, search_query: str) -> None: + """Show a dialog with multiple workflow options for selection.""" + dialog = QDialog() + dialog.setWindowTitle("Select a Workflow") + dialog.setMinimumWidth(700) + layout = QVBoxLayout(dialog) + layout.setSpacing(15) + layout.setContentsMargins(20, 20, 20, 20) + + header_label = QLabel(f"Found workflows for: \"{search_query}\"") + header_label.setFont(QFont("Segoe UI", 14, QFont.Weight.Bold)) + layout.addWidget(header_label) + + workflow_group = QGroupBox("Select a workflow to replay:") + workflow_group.setFont(QFont("Segoe UI", 10)) + workflow_layout = QVBoxLayout(workflow_group) + workflow_layout.setSpacing(10) + + self.workflow_buttons = [] + for recording, confidence in results: + confidence_pct = confidence * 100 + if confidence_pct >= 80: confidence_color = "#28a745" + elif confidence_pct >= 60: confidence_color = "#ffc107" + else: confidence_color = "#dc3545" + + btn = QPushButton(recording.task_description) + btn.setMinimumHeight(50) + btn.setFont(QFont("Segoe UI", 11, QFont.Weight.Normal)) + btn.setStyleSheet(f"border-left: 5px solid {confidence_color}; text-align: left; padding: 10px;") + + timestamp_str = datetime.fromtimestamp(recording.timestamp).strftime('%b %d, %Y %H:%M') + metadata_text = f"Confidence: {confidence_pct:.1f}% • Recorded: {timestamp_str}" + metadata_label = QLabel(metadata_text) + metadata_label.setFont(QFont("Segoe UI", 9)) + metadata_label.setStyleSheet(f"color: {confidence_color}; padding: 0 0 0 5px;") + + item_layout = QVBoxLayout() + item_layout.setSpacing(2) + item_layout.addWidget(btn) + item_layout.addWidget(metadata_label) + workflow_layout.addLayout(item_layout) + + btn.clicked.connect(lambda checked, r=recording: self._handle_workflow_selection(r, search_query)) + self.workflow_buttons.append(btn) + + layout.addWidget(workflow_group) + + button_box = QDialogButtonBox(QDialogButtonBox.Cancel) + 
button_box.rejected.connect(dialog.reject) + layout.addWidget(button_box, 0, Qt.AlignRight) + + dialog.exec() + + def _handle_workflow_selection(self, recording: Recording, search_query: str) -> None: + """Handle selection from the workflow selection dialog.""" + logger.info(f"User selected workflow: {recording.task_description}") + + for widget in QApplication.topLevelWidgets(): + if isinstance(widget, QDialog): + widget.close() + + self.show_toast(f"Selected: {recording.task_description}") + self._replay_from_search(recording, search_query) def _replay(self, recording: Recording) -> None: """Dynamically select and configure a replay strategy.""" - # TODO: refactor into class, like ConfirmDeleteDialog dialog = QDialog() dialog.setWindowTitle("Configure Replay Strategy") layout = QVBoxLayout(dialog) - # Strategy selection label = QLabel("Select Replay Strategy:") combo_box = QComboBox() strategies = { @@ -291,30 +383,22 @@ def _replay(self, recording: Recording) -> None: if not cls.__name__.endswith("Mixin") and cls.__name__ != "DemoReplayStrategy" } - strategy_names = list(strategies.keys()) - logger.info(f"{strategy_names=}") - combo_box.addItems(strategy_names) - - # Set default strategy - default_strategy = "VisualReplayStrategy" - default_index = combo_box.findText(default_strategy) - if default_index != -1: # Ensure the strategy is found in the list - combo_box.setCurrentIndex(default_index) - else: - logger.warning(f"{default_strategy=} not found") + combo_box.addItems(strategies.keys()) + if "VisualReplayStrategy" in strategies: + combo_box.setCurrentText("VisualReplayStrategy") strategy_label = QLabel() + strategy_label.setWordWrap(True) + layout.addWidget(label) layout.addWidget(combo_box) layout.addWidget(strategy_label) - # Container for argument widgets args_container = QWidget() args_layout = QVBoxLayout(args_container) args_container.setLayout(args_layout) layout.addWidget(args_container) - # Buttons button_box = QDialogButtonBox(QDialogButtonBox.Ok 
| QDialogButtonBox.Cancel) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) @@ -322,98 +406,187 @@ def _replay(self, recording: Recording) -> None: def update_args_inputs() -> None: """Update argument inputs.""" - # Clear existing widgets while args_layout.count(): - widget_to_remove = args_layout.takeAt(0).widget() - if widget_to_remove is not None: - widget_to_remove.setParent(None) - widget_to_remove.deleteLater() + args_layout.takeAt(0).widget().deleteLater() strategy_class = strategies[combo_box.currentText()] - - strategy_label.setText(strategy_class.__doc__) + strategy_label.setText(strategy_class.__doc__ or "No description available.") sig = inspect.signature(strategy_class.__init__) for param in sig.parameters.values(): - if param.name != "self" and param.name != "recording": - arg_label = QLabel(f"{param.name.replace('_', ' ').capitalize()}:") - - # Determine if the parameter is a boolean - if param.annotation is bool: - # Create a combobox for boolean values - arg_input = QComboBox() - arg_input.addItems(["True", "False"]) - # Set default value if exists - if param.default is not inspect.Parameter.empty: - default_index = 0 if param.default else 1 - arg_input.setCurrentIndex(default_index) - else: - # Create a line edit for non-boolean values - arg_input = QLineEdit() - annotation_str = self.format_annotation(param.annotation) - arg_input.setPlaceholderText(annotation_str or "str") - # Set default text if there is a default value - if param.default is not inspect.Parameter.empty: - arg_input.setText(str(param.default)) - - args_layout.addWidget(arg_label) - args_layout.addWidget(arg_input) - - args_container.adjustSize() - dialog.adjustSize() - dialog.setMinimumSize(0, 0) # Reset the minimum size to allow shrinking + if param.name in ("self", "recording"): + continue + + arg_label = QLabel(f"{param.name.replace('_', ' ').title()}:") + if param.annotation is bool: + arg_input = QComboBox() + 
arg_input.addItems(["False", "True"]) + if param.default: + arg_input.setCurrentIndex(int(param.default)) + else: + arg_input = QLineEdit() + if param.default is not inspect.Parameter.empty: + arg_input.setText(str(param.default)) + + args_layout.addWidget(arg_label) + args_layout.addWidget(arg_input) combo_box.currentIndexChanged.connect(update_args_inputs) - update_args_inputs() # Initial update + update_args_inputs() - # Show dialog and process the result if dialog.exec() == QDialog.Accepted: selected_strategy = strategies[combo_box.currentText()] sig = inspect.signature(selected_strategy.__init__) kwargs = {} - index = 0 - for param_name, param in sig.parameters.items(): - if param_name in ["self", "recording"]: + widget_idx = 0 + for param in sig.parameters.values(): + if param.name in ("self", "recording"): continue - widget = args_layout.itemAt(index * 2 + 1).widget() + arg_widget = args_layout.itemAt(widget_idx * 2 + 1).widget() + value = arg_widget.currentText() == "True" if isinstance(arg_widget, QComboBox) else arg_widget.text() + try: + kwargs[param.name] = param.annotation(value) if param.annotation != inspect.Parameter.empty else value + except (ValueError, TypeError): + kwargs[param.name] = value + widget_idx += 1 + + self.child_conn.send({"type": "replay.starting"}) + replay_proc = multiprocessing.Process( + target=WrapStdout(replay), + args=(selected_strategy.__name__, False, None, recording, self.child_conn), + kwargs=kwargs, + daemon=True, + ) + replay_proc.start() + + def _replay_from_search(self, recording: Recording, search_query: str) -> None: + """Replay a recording with search query context.""" + dialog = QDialog() + dialog.setWindowTitle("Configure Replay") + layout = QVBoxLayout(dialog) + layout.setSpacing(15) + layout.setContentsMargins(20, 20, 20, 20) + + header_text = f"
Replay: {recording.task_description}
" + header_text += f"
From search: \"{search_query}\"
" + header_label = QLabel(header_text) + layout.addWidget(header_label) + + strategy_group = QGroupBox("Replay Strategy") + strategy_group.setFont(QFont("Segoe UI", 10)) + strategy_layout = QVBoxLayout(strategy_group) + + combo_box = QComboBox() + strategies = { + cls.__name__: cls for cls in BaseReplayStrategy.__subclasses__() + if not cls.__name__.endswith("Mixin") and cls.__name__ != "DemoReplayStrategy" + } + combo_box.addItems(strategies.keys()) + if "VisualReplayStrategy" in strategies: + combo_box.setCurrentText("VisualReplayStrategy") + + strategy_label = QLabel() + strategy_label.setWordWrap(True) + strategy_label.setFont(QFont("Segoe UI", 9)) + + strategy_layout.addWidget(combo_box) + strategy_layout.addWidget(strategy_label) + layout.addWidget(strategy_group) + + args_container = QWidget() + args_layout = QVBoxLayout(args_container) + args_layout.setContentsMargins(0,0,0,0) + args_layout.setSpacing(10) + layout.addWidget(args_container) + + def update_args_inputs() -> None: + """Update argument inputs, hiding instructions.""" + while args_layout.count(): + args_layout.takeAt(0).widget().deleteLater() + + strategy_class = strategies[combo_box.currentText()] + strategy_label.setText(strategy_class.__doc__ or "No description available.") + + sig = inspect.signature(strategy_class.__init__) + has_visible_params = False + for param in sig.parameters.values(): + if param.name in ("self", "recording", "instructions", "str_input"): + continue + has_visible_params = True + + param_label = QLabel(f"{param.name.replace('_', ' ').title()}:") + param_label.setFont(QFont("Segoe UI", 10)) + if param.annotation is bool: - # For boolean, get True/False from the combobox selection - value = widget.currentText() == "True" + arg_input = QComboBox() + arg_input.addItems(["False", "True"]) + if param.default: + arg_input.setCurrentIndex(int(param.default)) else: - # Convert the text to the annotated type if possible - text = widget.text() + arg_input = 
QLineEdit(str(param.default) if param.default is not inspect.Parameter.empty else "") + + args_layout.addWidget(param_label) + args_layout.addWidget(arg_input) + args_container.setVisible(has_visible_params) + + combo_box.currentIndexChanged.connect(update_args_inputs) + update_args_inputs() + + button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + button_box.accepted.connect(dialog.accept) + button_box.rejected.connect(dialog.reject) + layout.addWidget(button_box) + + if dialog.exec() == QDialog.Accepted: + selected_strategy = strategies[combo_box.currentText()] + sig = inspect.signature(selected_strategy.__init__) + kwargs = {} + widget_idx = 0 + for param in sig.parameters.values(): + if param.name in ("instructions", "str_input"): + kwargs[param.name] = search_query + continue + if param.name in ("self", "recording"): + continue + + if args_container.isVisible(): + arg_widget = args_layout.itemAt(widget_idx * 2 + 1).widget() + value = arg_widget.currentText() == "True" if isinstance(arg_widget, QComboBox) else arg_widget.text() try: - # Cast text to the parameter's annotated type - value = ( - param.annotation(text) - if param.annotation != inspect.Parameter.empty - else text - ) - except ValueError as exc: - logger.warning(f"{exc=}") - value = text - kwargs[param_name] = value - index += 1 - logger.info(f"kwargs=\n{pformat(kwargs)}") + kwargs[param.name] = param.annotation(value) if param.annotation != inspect.Parameter.empty else value + except (ValueError, TypeError): + kwargs[param.name] = value + widget_idx += 1 + logger.info(f"Replaying with kwargs: {pformat(kwargs)}") self.child_conn.send({"type": "replay.starting"}) - record_replay = False - recording_timestamp = None - strategy_name = selected_strategy.__name__ replay_proc = multiprocessing.Process( target=WrapStdout(replay), - args=( - strategy_name, - record_replay, - recording_timestamp, - recording, - self.child_conn, - ), + args=(selected_strategy.__name__, False, None, 
recording, self.child_conn), kwargs=kwargs, daemon=True, ) replay_proc.start() + def _visualize(self, recording: Recording) -> None: + """Visualize a recording. + + Args: + recording (Recording): The recording to visualize. + """ + self.show_toast("Starting visualization...") + try: + if self.visualize_proc is not None: + self.visualize_proc.kill() + self.visualize_proc = multiprocessing.Process( + target=WrapStdout(visualize), args=(recording,) + ) + self.visualize_proc.start() + + except Exception as e: + logger.error(e) + self.show_toast("Visualization failed.") + def _delete(self, recording: Recording) -> None: """Delete a recording after confirmation. @@ -705,52 +878,30 @@ class ConfirmDeleteDialog(QDialog): """Dialog window to confirm recording deletion.""" def __init__(self, recording_description: str) -> None: - """Initialize. - - Args: - recording_description (str): The Recording's description. - """ + """Initialize.""" super().__init__() - self.setWindowTitle("Confirm Delete") + self.setWindowTitle("Confirm Deletion") self.build_ui(recording_description) def build_ui(self, recording_description: str) -> None: - """Build the dialog window. - - Args: - recording_description (str): The recording description. - """ - # Setup layout + """Build the dialog window.""" layout = QVBoxLayout(self) - # Add description text label = QLabel( f"Are you sure you want to delete the recording '{recording_description}'?" 
) label.setWordWrap(True) layout.addWidget(label) - # Add buttons - button_layout = QHBoxLayout() - yes_button = QPushButton("Yes") - no_button = QPushButton("No") - button_layout.addWidget(yes_button) - button_layout.addWidget(no_button) - layout.addLayout(button_layout) - - # Connect buttons - yes_button.clicked.connect(self.accept) - no_button.clicked.connect(self.reject) + button_box = QDialogButtonBox(QDialogButtonBox.Yes | QDialogButtonBox.No) + button_box.button(QDialogButtonBox.Yes).setStyleSheet("color: red;") + button_box.accepted.connect(self.accept) + button_box.rejected.connect(self.reject) + layout.addWidget(button_box) def exec_(self) -> bool: - """Show the dialog window and return the user input. - - Returns: - bool: The user's input. - """ - if super().exec_() == QDialog.Accepted: - return True - return False + """Show the dialog window and return the user input.""" + return super().exec_() == QDialog.Accepted def _run() -> None: diff --git a/openadapt/config.defaults.json b/openadapt/config.defaults.json index 1f935ca3f..21c7bb125 100644 --- a/openadapt/config.defaults.json +++ b/openadapt/config.defaults.json @@ -82,6 +82,7 @@ "VISUALIZE_MAX_TABLE_CHILDREN": 10, "SAVE_SCREENSHOT_DIFF": false, "SPACY_MODEL_NAME": "en_core_web_trf", + "EMBEDDING_MODEL_NAME": "sentence-transformers/all-MiniLM-L6-v2", "DASHBOARD_CLIENT_PORT": 5173, "DASHBOARD_SERVER_PORT": 8080, "BROWSER_WEBSOCKET_PORT": 8765, diff --git a/openadapt/config.py b/openadapt/config.py index ca3c01801..82366769a 100644 --- a/openadapt/config.py +++ b/openadapt/config.py @@ -227,6 +227,9 @@ def validate_scrub_fill_color(cls, v: Union[str, int]) -> int: # noqa: ANN102 # Spacy configurations SPACY_MODEL_NAME: str = "en_core_web_trf" + # Embedding configurations + EMBEDDING_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2" + # Dashboard configurations DASHBOARD_CLIENT_PORT: int = 3000 DASHBOARD_SERVER_PORT: int = 8000 diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 
39a4f7677..62291415c 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -29,8 +29,10 @@ ScrubbedRecording, WindowEvent, copy_sa_instance, + RecordingEmbedding, ) from openadapt.privacy.base import ScrubbingProvider +from openadapt.embed import get_embedding, get_configured_model_name BATCH_SIZE = 1 @@ -284,6 +286,11 @@ def insert_recording(session: SaSession, recording_data: dict) -> Recording: session.add(db_obj) session.commit() session.refresh(db_obj) + + # Add embedding + if db_obj.task_description: + add_or_update_embedding(session, db_obj, db_obj.task_description) + return db_obj @@ -966,3 +973,58 @@ def release_db_lock(raise_exception: bool = True) -> None: logger.error("Database lock file not found.") raise logger.info("Database lock released.") + + +def add_or_update_embedding(session: SaSession, recording: Recording, text: str) -> None: + """Adds or updates the embedding for a given recording and text. + + Args: + session: The database session. + recording: The Recording object. + text: The text to embed (e.g., task_description). 
+ """ + current_model_name = get_configured_model_name() + embedding_vector = get_embedding(text, model_name=current_model_name) + if embedding_vector: + # Check if an embedding already exists for this recording and model + existing_embedding = ( + session.query(RecordingEmbedding) + .filter_by(recording_id=recording.id, model_name=current_model_name) + .first() + ) + if existing_embedding: + existing_embedding.embedding = embedding_vector + existing_embedding.timestamp = time.time() + logger.info(f"Updated embedding for recording_id={recording.id} using model='{current_model_name}'") + else: + new_embedding = RecordingEmbedding( + recording_id=recording.id, + embedding=embedding_vector, + model_name=current_model_name, + ) + session.add(new_embedding) + logger.info(f"Added new embedding for recording_id={recording.id} using model='{current_model_name}'") + session.commit() + else: + logger.warning( + f"Could not generate or save embedding for recording_id={recording.id}" + ) + + +def get_similar_recordings_by_embedding( + session: SaSession, query_embedding: list[float], top_n: int = 5 +) -> list[Recording]: + """Finds recordings with the most similar embeddings to the query_embedding + using pure Python similarity search with NumPy and SciPy. + + Args: + session: The database session. + query_embedding: The embedding vector of the user's query. + top_n: The number of similar recordings to return. + + Returns: + A list of Recording objects, ordered by similarity (most similar first). 
+ """ + from openadapt.similarity_search import get_similar_recordings_by_embedding_legacy + + return get_similar_recordings_by_embedding_legacy(session, query_embedding, top_n) diff --git a/openadapt/embed.py b/openadapt/embed.py new file mode 100644 index 000000000..983a8e20f --- /dev/null +++ b/openadapt/embed.py @@ -0,0 +1,69 @@ +"""This module handles the generation of text embeddings.""" + +from sentence_transformers import SentenceTransformer +from openadapt.custom_logger import logger +from openadapt.config import config # Import config + +# Global variable to cache the loaded model +_model_cache = {} + +def get_configured_model_name() -> str: + """Returns the configured embedding model name.""" + return config.EMBEDDING_MODEL_NAME + +def _load_model(model_name: str) -> SentenceTransformer | None: + """Loads a sentence transformer model and caches it.""" + if model_name in _model_cache: + return _model_cache[model_name] + + try: + model = SentenceTransformer(model_name) + logger.info(f"Loaded sentence transformer model: {model_name}") + _model_cache[model_name] = model + return model + except Exception as e: + logger.error(f"Failed to load sentence transformer model: {model_name}, {e}") + _model_cache[model_name] = None # Cache None to avoid retrying failed loads repeatedly + return None + +def get_embedding(text: str, model_name: str | None = None) -> list[float] | None: + """Generates an embedding for the given text using the specified or configured model. + + Args: + text: The text to embed. + model_name: Optional. The name of the sentence transformer model to use. + If None, uses the model from the global config. + + Returns: + The embedding as a list of floats, or None if an error occurs. + """ + current_model_name = model_name or get_configured_model_name() + model = _load_model(current_model_name) + + if not model: + logger.error(f"Sentence transformer model '{current_model_name}' not loaded. 
Cannot generate embedding.") + return None + if not text or not isinstance(text, str): + logger.warning(f"Invalid text input for embedding: {text}") + return None + try: + embedding_output = model.encode(text, convert_to_tensor=False, convert_to_numpy=False) + # Ensure it's a list of Python floats + if hasattr(embedding_output, 'tolist'): # Handles numpy array or torch tensor + embedding = embedding_output.tolist() + elif isinstance(embedding_output, list): + embedding = embedding_output + else: + # Fallback if it's a single tensor/numpy scalar, though unlikely for sentence embeddings + embedding = [float(embedding_output)] + + # Further ensure all elements are floats if it's a list of lists (some models might do this) + if embedding and isinstance(embedding[0], list): + embedding = [float(item) for sublist in embedding for item in sublist] # Flatten and convert + elif embedding: + embedding = [float(item) for item in embedding] + + return embedding + except Exception as e: + logger.error(f"Error generating embedding for text='{text[:100]}...': {e}") + return None \ No newline at end of file diff --git a/openadapt/models.py b/openadapt/models.py index 03b60329e..59175e75d 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -8,6 +8,7 @@ import io import sys import textwrap +import time from bs4 import BeautifulSoup from pynput import keyboard @@ -100,6 +101,16 @@ class Recording(db.Base): audio_info = sa.orm.relationship( "AudioInfo", back_populates="recording", cascade="all, delete-orphan" ) + recording_embeddings = sa.orm.relationship( + "RecordingEmbedding", + back_populates="recording", + cascade="all, delete-orphan", + ) + recording_summaries = sa.orm.relationship( + "RecordingSummary", + back_populates="recording", + cascade="all, delete-orphan", + ) _processed_action_events = None @@ -1187,6 +1198,34 @@ class Replay(db.Base): git_hash = sa.Column(sa.String) +class RecordingEmbedding(db.Base): + """Class representing a recording embedding in the 
database.""" + + __tablename__ = "recording_embedding" + + id = sa.Column(sa.Integer, primary_key=True) + recording_id = sa.Column(sa.ForeignKey("recording.id")) + embedding = sa.Column(sa.JSON) # Assuming sqlite-vss can handle JSON + model_name = sa.Column(sa.String) # To store the name of the embedding model used + timestamp = sa.Column(ForceFloat, default=time.time) + + recording = sa.orm.relationship("Recording", back_populates="recording_embeddings") + + +class RecordingSummary(db.Base): + """Class representing a hierarchical summary of a recording.""" + + __tablename__ = "recording_summary" + + id = sa.Column(sa.Integer, primary_key=True) + recording_id = sa.Column(sa.ForeignKey("recording.id")) + summary_text = sa.Column(sa.Text) + summary_level = sa.Column(sa.String) # e.g., "high", "mid", "low" + timestamp = sa.Column(ForceFloat, default=time.time) + + recording = sa.orm.relationship("Recording", back_populates="recording_summaries") + + def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base: """Copy a SQLAlchemy instance. diff --git a/openadapt/scripts/backfill_embeddings.py b/openadapt/scripts/backfill_embeddings.py new file mode 100644 index 000000000..49ce4d208 --- /dev/null +++ b/openadapt/scripts/backfill_embeddings.py @@ -0,0 +1,95 @@ +"""This script backfills embeddings for existing recordings.""" + +import argparse + +from sqlalchemy.orm import Session + +from openadapt.db import crud +from openadapt.models import Recording, RecordingEmbedding +from openadapt.custom_logger import logger +from openadapt.embed import get_configured_model_name # Updated import + +def backfill_embeddings(session: Session, force_update: bool = False) -> None: + """Backfills embeddings for recordings that don't have one or if force_update is True. + + Args: + session: The database session. + force_update: If True, will re-generate embeddings even if they exist. 
+ """ + logger.info("Starting embedding backfill process...") + recordings = session.query(Recording).all() + logger.info(f"Found {len(recordings)} recordings to process.") + + processed_count = 0 + skipped_count = 0 + error_count = 0 + + for recording in recordings: + if not recording.task_description: + # logger.info(f"Skipping recording_id={recording.id}, no task_description.") + skipped_count += 1 + continue + + if not force_update: + # Check if an embedding with the current MODEL_NAME already exists + current_model_name = get_configured_model_name() + existing_embedding = ( + session.query(RecordingEmbedding) + .filter_by(recording_id=recording.id, model_name=current_model_name) + .first() + ) + if existing_embedding: + # logger.info( + # f"Skipping recording_id={recording.id}, embedding with model='{current_model_name}' already exists." + # ) + skipped_count += 1 + continue + + logger.info(f"Processing recording_id={recording.id}: '{recording.task_description[:50]}...' ") + try: + # add_or_update_embedding will handle both creation and update logic + # It uses the MODEL_NAME from openadapt.embed + crud.add_or_update_embedding(session, recording, recording.task_description) + processed_count += 1 + except Exception as e: + logger.error(f"Error processing recording_id={recording.id}: {e}") + error_count += 1 + # Optionally, rollback session for this specific error to not affect others + # session.rollback() + + if processed_count > 0 or error_count > 0 : # only commit if there were changes or errors to log + try: + session.commit() + logger.info("Committed changes to the database.") + except Exception as e: + logger.error(f"Error committing changes to database: {e}") + session.rollback() + + logger.info(f"Embedding backfill complete. 
Processed: {processed_count}, Skipped: {skipped_count}, Errors: {error_count}") + +def main() -> None: + """Main function to run the backfill script.""" + parser = argparse.ArgumentParser(description="Backfill embeddings for existing recordings.") + parser.add_argument( + "--force", + action="store_true", + help="Force update embeddings even if they already exist for the current model.", + ) + args = parser.parse_args() + db_session = None + try: + # Use read_and_write=True for the session + db_session = crud.get_new_session(read_and_write=True) + backfill_embeddings(db_session, force_update=args.force) + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + if db_session: + db_session.rollback() + finally: + if db_session: + db_session.close() + # crud.release_db_lock() + logger.info("Backfill script finished.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/openadapt/similarity_search.py b/openadapt/similarity_search.py new file mode 100644 index 000000000..a9e2ba1a2 --- /dev/null +++ b/openadapt/similarity_search.py @@ -0,0 +1,131 @@ +"""Vector similarity search functionality using pure Python.""" + +import json +import re +import time +from collections import Counter +from typing import List, Tuple + +import numpy as np +from loguru import logger +from scipy.spatial.distance import cosine + +from openadapt.db.crud import SaSession +from openadapt.embed import get_embedding, get_configured_model_name +from openadapt.models import Recording, RecordingEmbedding + +# A small set of common English stop words to avoid external dependencies or downloads. 
+STOP_WORDS = {
+    "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he",
+    "in", "is", "it", "its", "of", "on", "that", "the", "to", "was", "were",
+    "will", "with", "i", "you", "your", "d", "s", "t", "ve",
+}
+
+
+def _extract_context(recording: Recording) -> str:
+    """Extract a general context from a recording to improve search relevance."""
+    context_parts = []  # free-text fragments; deduplicated and joined at the end
+
+    if recording.task_description:
+        context_parts.append(recording.task_description)
+
+    # Extract keywords from window titles
+    if recording.window_events:
+        all_titles = " ".join(
+            [event.title for event in recording.window_events if event.title]
+        )
+        words = re.findall(r"\w+", all_titles.lower())
+        meaningful_words = [
+            word for word in words if word not in STOP_WORDS and not word.isdigit()
+        ]
+
+        if meaningful_words:
+            # Add top 5 most common words as context
+            word_counts = Counter(meaningful_words)
+            for word, _ in word_counts.most_common(5):
+                if len(word) > 2:  # filter out very short words
+                    context_parts.append(word)
+
+    # Extract context from action events
+    if recording.action_events:
+        # Get a summary of action types
+        action_types = set(ae.name for ae in recording.action_events if ae.name)  # distinct event names
+        if action_types:
+            context_parts.append("actions:" + ",".join(sorted(list(action_types))))
+
+        # Detect significant text input
+        typed_text = "".join(
+            ae.key_char
+            for ae in recording.action_events
+            if ae.name == "keypress" and ae.key_char and ae.key_char.isalnum()  # NOTE(review): assumes the action name is "keypress" — confirm against models.ActionEvent
+        )
+        if len(typed_text) > 15:  # Heuristic for meaningful text input
+            context_parts.append("text_input")
+
+    # Add transcribed text from audio if available
+    if hasattr(recording, "audio_info") and recording.audio_info:
+        for audio in recording.audio_info:
+            if audio.transcribed_text:
+                context_parts.append("transcription:" + audio.transcribed_text)
+
+    # Deduplicate and join
+    unique_context = sorted(list(set(context_parts)))
+    return " ".join(unique_context)
+
+
+def get_enhanced_similarity_search(
+    session: 
SaSession, query_text: str, top_n: int = 5
+) -> List[Tuple[Recording, float]]:
+    """Perform an enhanced search using application and action context.
+
+    Args:
+        session: The database session.
+        query_text: The user's search query.
+        top_n: The number of results to return.
+
+    Returns:
+        A list of (Recording, similarity_score) tuples.
+    """
+    model_name = get_configured_model_name()
+    query_embedding = get_embedding(query_text, model_name=model_name)
+
+    if not query_embedding:  # embedding backend returned None/empty — treat as failure
+        logger.error("Could not generate embedding for query text.")
+        return []
+
+    recordings_with_embeddings = (
+        session.query(Recording, RecordingEmbedding)
+        .join(RecordingEmbedding)
+        .filter(RecordingEmbedding.model_name == model_name)
+        .all()
+    )
+
+    results = []
+    for recording, embedding_record in recordings_with_embeddings:
+        enhanced_context = _extract_context(recording)
+        context_embedding = get_embedding(enhanced_context, model_name=model_name)  # NOTE(review): one model call per candidate recording — O(n) embedding cost per search
+
+        try:
+            stored_embedding = (
+                json.loads(embedding_record.embedding)
+                if isinstance(embedding_record.embedding, str)
+                else embedding_record.embedding
+            )
+            description_similarity = 1 - cosine(query_embedding, stored_embedding)
+            context_similarity = (
+                (1 - cosine(query_embedding, context_embedding))
+                if context_embedding
+                else 0.0
+            )
+
+            combined_similarity = (description_similarity * 0.6) + (
+                context_similarity * 0.4  # 60/40 blend: stored description vs. derived context
+            )
+            results.append((recording, combined_similarity))
+        except (TypeError, json.JSONDecodeError) as e:  # NOTE(review): ValueError from cosine() (e.g. length mismatch) is not caught here
+            logger.error(
+                f"Error processing embedding for recording {recording.id}: {e}"
+            )
+
+    results.sort(key=lambda x: x[1], reverse=True)
+    return results[:top_n]
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 7f1b705c1..87fda8cba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,10 @@ pytest = "7.1.3"
 rapidocr-onnxruntime = "1.2.3"
 scikit-learn = "1.2.2"
 scipy = "^1.11.0"
+numpy = "^1.24.0"
 torch = "^2.0.0"
 tqdm = "4.64.0"
-transformers = "4.29.2"
+transformers = 
"^4.32.0" python-dotenv = "1.0.0" pyinstaller = "6.11.0" setuptools-lint = "^0.6.0" @@ -99,6 +100,7 @@ pudb = "^2024.1" sounddevice = "^0.4.6" soundfile = "^0.12.1" posthog = "^3.5.0" +sentence-transformers = "^2.3.1" wheel = "^0.43.0" cython = "^3.0.10"