diff --git a/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py b/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py
new file mode 100644
index 000000000..a4d466556
--- /dev/null
+++ b/openadapt/alembic/versions/bd9917da991f_add_recording_embedding_and_summary_.py
@@ -0,0 +1,46 @@
+"""add_recording_embedding_and_summary_tables
+
+Revision ID: bd9917da991f
+Revises: 98505a067995
+Create Date: 2025-06-01 19:10:09.277603
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import openadapt
+
+# revision identifiers, used by Alembic.
+revision = 'bd9917da991f'
+down_revision = '98505a067995'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('recording_embedding',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('recording_id', sa.Integer(), nullable=True),
+ sa.Column('embedding', sa.JSON(), nullable=True),
+ sa.Column('model_name', sa.String(), nullable=True),
+ sa.Column('timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
+ sa.ForeignKeyConstraint(['recording_id'], ['recording.id'], name=op.f('fk_recording_embedding_recording_id_recording')),
+ sa.PrimaryKeyConstraint('id', name=op.f('pk_recording_embedding'))
+ )
+ op.create_table('recording_summary',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('recording_id', sa.Integer(), nullable=True),
+ sa.Column('summary_text', sa.Text(), nullable=True),
+ sa.Column('summary_level', sa.String(), nullable=True),
+ sa.Column('timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
+ sa.ForeignKeyConstraint(['recording_id'], ['recording.id'], name=op.f('fk_recording_summary_recording_id_recording')),
+ sa.PrimaryKeyConstraint('id', name=op.f('pk_recording_summary'))
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('recording_summary')
+ op.drop_table('recording_embedding')
+ # ### end Alembic commands ###
diff --git a/openadapt/app/tray.py b/openadapt/app/tray.py
index 805a39b60..ba2f2a92f 100644
--- a/openadapt/app/tray.py
+++ b/openadapt/app/tray.py
@@ -21,6 +21,7 @@
QComboBox,
QDialog,
QDialogButtonBox,
+ QGroupBox,
QHBoxLayout,
QInputDialog,
QLabel,
@@ -44,6 +45,7 @@
from openadapt.strategies.base import BaseReplayStrategy
from openadapt.utils import WrapStdout, get_posthog_instance
from openadapt.visualize import main as visualize
+from openadapt.similarity_search import get_enhanced_similarity_search
# ensure all strategies are registered
import openadapt.strategies # noqa: F401
@@ -144,6 +146,12 @@ def __init__(self) -> None:
self.record_action.triggered.connect(self._record)
self.menu.addAction(self.record_action)
+ self.search_action = TrackedQAction("💬 What do you want help with today?")
+ self.search_action.triggered.connect(self._search_workflows)
+ self.menu.addAction(self.search_action)
+
+ self.menu.addSeparator()
+
self.visualize_menu = self.menu.addMenu("Visualize")
self.replay_menu = self.menu.addMenu("Replay")
self.delete_menu = self.menu.addMenu("Delete")
@@ -256,33 +264,117 @@ def stop_recording(self) -> None:
"""Stop recording."""
Thread(target=stop_record).start()
- def _visualize(self, recording: Recording) -> None:
- """Visualize a recording.
-
- Args:
- recording (Recording): The recording to visualize.
- """
- self.show_toast("Starting visualization...")
+ def _search_workflows(self) -> None:
+ """Handle natural language workflow search with multiple results."""
+ search_query, ok = QInputDialog.getText(
+ None,
+ "💬 OpenAdapt Assistant",
+ "What do you want help with today?\n(e.g., 'calculate taxes', 'organize files', 'send email')",
+ text=""
+ )
+
+ if not ok or not search_query.strip():
+ return
+
+ logger.info(f"User wants help with: {search_query}")
+ self.show_toast("Finding relevant workflows...")
+
try:
- if self.visualize_proc is not None:
- self.visualize_proc.kill()
- self.visualize_proc = multiprocessing.Process(
- target=WrapStdout(visualize), args=(recording,)
- )
- self.visualize_proc.start()
+ with crud.get_new_session(read_only=True) as session:
+ results = get_enhanced_similarity_search(session, search_query, top_n=5)
+
+ if not results:
+ self.show_toast(
+ "No matching workflows found. Try recording a workflow first!",
+ duration=5000
+ )
+ return
+
+ filtered_results = [(r, c) for r, c in results if c >= 0.1]
+
+ if not filtered_results:
+ self.show_toast(
+ f"No good matches found for '{search_query}'. Try different keywords.",
+ duration=5000
+ )
+ return
+
+ self._show_workflow_selection(filtered_results, search_query)
except Exception as e:
- logger.error(e)
- self.show_toast("Visualization failed.")
+ logger.error(f"Error searching workflows: {e}")
+ self.show_toast("Search failed. Please try again.", duration=5000)
+
+ def _show_workflow_selection(self, results: list, search_query: str) -> None:
+ """Show a dialog with multiple workflow options for selection."""
+ dialog = QDialog()
+ dialog.setWindowTitle("Select a Workflow")
+ dialog.setMinimumWidth(700)
+ layout = QVBoxLayout(dialog)
+ layout.setSpacing(15)
+ layout.setContentsMargins(20, 20, 20, 20)
+
+ header_label = QLabel(f"Found workflows for: \"{search_query}\"")
+ header_label.setFont(QFont("Segoe UI", 14, QFont.Weight.Bold))
+ layout.addWidget(header_label)
+
+ workflow_group = QGroupBox("Select a workflow to replay:")
+ workflow_group.setFont(QFont("Segoe UI", 10))
+ workflow_layout = QVBoxLayout(workflow_group)
+ workflow_layout.setSpacing(10)
+
+ self.workflow_buttons = []
+ for recording, confidence in results:
+ confidence_pct = confidence * 100
+ if confidence_pct >= 80: confidence_color = "#28a745"
+ elif confidence_pct >= 60: confidence_color = "#ffc107"
+ else: confidence_color = "#dc3545"
+
+ btn = QPushButton(recording.task_description)
+ btn.setMinimumHeight(50)
+ btn.setFont(QFont("Segoe UI", 11, QFont.Weight.Normal))
+ btn.setStyleSheet(f"border-left: 5px solid {confidence_color}; text-align: left; padding: 10px;")
+
+ timestamp_str = datetime.fromtimestamp(recording.timestamp).strftime('%b %d, %Y %H:%M')
+ metadata_text = f"Confidence: {confidence_pct:.1f}% • Recorded: {timestamp_str}"
+ metadata_label = QLabel(metadata_text)
+ metadata_label.setFont(QFont("Segoe UI", 9))
+ metadata_label.setStyleSheet(f"color: {confidence_color}; padding: 0 0 0 5px;")
+
+ item_layout = QVBoxLayout()
+ item_layout.setSpacing(2)
+ item_layout.addWidget(btn)
+ item_layout.addWidget(metadata_label)
+ workflow_layout.addLayout(item_layout)
+
+ btn.clicked.connect(lambda checked, r=recording: self._handle_workflow_selection(r, search_query))
+ self.workflow_buttons.append(btn)
+
+ layout.addWidget(workflow_group)
+
+ button_box = QDialogButtonBox(QDialogButtonBox.Cancel)
+ button_box.rejected.connect(dialog.reject)
+ layout.addWidget(button_box, 0, Qt.AlignRight)
+
+ dialog.exec()
+
+ def _handle_workflow_selection(self, recording: Recording, search_query: str) -> None:
+ """Handle selection from the workflow selection dialog."""
+ logger.info(f"User selected workflow: {recording.task_description}")
+
+ for widget in QApplication.topLevelWidgets():
+ if isinstance(widget, QDialog):
+ widget.close()
+
+ self.show_toast(f"Selected: {recording.task_description}")
+ self._replay_from_search(recording, search_query)
def _replay(self, recording: Recording) -> None:
"""Dynamically select and configure a replay strategy."""
- # TODO: refactor into class, like ConfirmDeleteDialog
dialog = QDialog()
dialog.setWindowTitle("Configure Replay Strategy")
layout = QVBoxLayout(dialog)
- # Strategy selection
label = QLabel("Select Replay Strategy:")
combo_box = QComboBox()
strategies = {
@@ -291,30 +383,22 @@ def _replay(self, recording: Recording) -> None:
if not cls.__name__.endswith("Mixin")
and cls.__name__ != "DemoReplayStrategy"
}
- strategy_names = list(strategies.keys())
- logger.info(f"{strategy_names=}")
- combo_box.addItems(strategy_names)
-
- # Set default strategy
- default_strategy = "VisualReplayStrategy"
- default_index = combo_box.findText(default_strategy)
- if default_index != -1: # Ensure the strategy is found in the list
- combo_box.setCurrentIndex(default_index)
- else:
- logger.warning(f"{default_strategy=} not found")
+ combo_box.addItems(strategies.keys())
+ if "VisualReplayStrategy" in strategies:
+ combo_box.setCurrentText("VisualReplayStrategy")
strategy_label = QLabel()
+ strategy_label.setWordWrap(True)
+
layout.addWidget(label)
layout.addWidget(combo_box)
layout.addWidget(strategy_label)
- # Container for argument widgets
args_container = QWidget()
args_layout = QVBoxLayout(args_container)
args_container.setLayout(args_layout)
layout.addWidget(args_container)
- # Buttons
button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
button_box.accepted.connect(dialog.accept)
button_box.rejected.connect(dialog.reject)
@@ -322,98 +406,187 @@ def _replay(self, recording: Recording) -> None:
def update_args_inputs() -> None:
"""Update argument inputs."""
- # Clear existing widgets
while args_layout.count():
- widget_to_remove = args_layout.takeAt(0).widget()
- if widget_to_remove is not None:
- widget_to_remove.setParent(None)
- widget_to_remove.deleteLater()
+ args_layout.takeAt(0).widget().deleteLater()
strategy_class = strategies[combo_box.currentText()]
-
- strategy_label.setText(strategy_class.__doc__)
+ strategy_label.setText(strategy_class.__doc__ or "No description available.")
sig = inspect.signature(strategy_class.__init__)
for param in sig.parameters.values():
- if param.name != "self" and param.name != "recording":
- arg_label = QLabel(f"{param.name.replace('_', ' ').capitalize()}:")
-
- # Determine if the parameter is a boolean
- if param.annotation is bool:
- # Create a combobox for boolean values
- arg_input = QComboBox()
- arg_input.addItems(["True", "False"])
- # Set default value if exists
- if param.default is not inspect.Parameter.empty:
- default_index = 0 if param.default else 1
- arg_input.setCurrentIndex(default_index)
- else:
- # Create a line edit for non-boolean values
- arg_input = QLineEdit()
- annotation_str = self.format_annotation(param.annotation)
- arg_input.setPlaceholderText(annotation_str or "str")
- # Set default text if there is a default value
- if param.default is not inspect.Parameter.empty:
- arg_input.setText(str(param.default))
-
- args_layout.addWidget(arg_label)
- args_layout.addWidget(arg_input)
-
- args_container.adjustSize()
- dialog.adjustSize()
- dialog.setMinimumSize(0, 0) # Reset the minimum size to allow shrinking
+ if param.name in ("self", "recording"):
+ continue
+
+ arg_label = QLabel(f"{param.name.replace('_', ' ').title()}:")
+ if param.annotation is bool:
+ arg_input = QComboBox()
+ arg_input.addItems(["False", "True"])
+ if param.default:
+ arg_input.setCurrentIndex(int(param.default))
+ else:
+ arg_input = QLineEdit()
+ if param.default is not inspect.Parameter.empty:
+ arg_input.setText(str(param.default))
+
+ args_layout.addWidget(arg_label)
+ args_layout.addWidget(arg_input)
combo_box.currentIndexChanged.connect(update_args_inputs)
- update_args_inputs() # Initial update
+ update_args_inputs()
- # Show dialog and process the result
if dialog.exec() == QDialog.Accepted:
selected_strategy = strategies[combo_box.currentText()]
sig = inspect.signature(selected_strategy.__init__)
kwargs = {}
- index = 0
- for param_name, param in sig.parameters.items():
- if param_name in ["self", "recording"]:
+ widget_idx = 0
+ for param in sig.parameters.values():
+ if param.name in ("self", "recording"):
continue
- widget = args_layout.itemAt(index * 2 + 1).widget()
+ arg_widget = args_layout.itemAt(widget_idx * 2 + 1).widget()
+ value = arg_widget.currentText() == "True" if isinstance(arg_widget, QComboBox) else arg_widget.text()
+ try:
+ kwargs[param.name] = param.annotation(value) if param.annotation != inspect.Parameter.empty else value
+ except (ValueError, TypeError):
+ kwargs[param.name] = value
+ widget_idx += 1
+
+ self.child_conn.send({"type": "replay.starting"})
+ replay_proc = multiprocessing.Process(
+ target=WrapStdout(replay),
+ args=(selected_strategy.__name__, False, None, recording, self.child_conn),
+ kwargs=kwargs,
+ daemon=True,
+ )
+ replay_proc.start()
+
+ def _replay_from_search(self, recording: Recording, search_query: str) -> None:
+ """Replay a recording with search query context."""
+ dialog = QDialog()
+ dialog.setWindowTitle("Configure Replay")
+ layout = QVBoxLayout(dialog)
+ layout.setSpacing(15)
+ layout.setContentsMargins(20, 20, 20, 20)
+
+        header_text = f"Replay: {recording.task_description}\n"
+        header_text += f'From search: "{search_query}"'
+ header_label = QLabel(header_text)
+ layout.addWidget(header_label)
+
+ strategy_group = QGroupBox("Replay Strategy")
+ strategy_group.setFont(QFont("Segoe UI", 10))
+ strategy_layout = QVBoxLayout(strategy_group)
+
+ combo_box = QComboBox()
+ strategies = {
+ cls.__name__: cls for cls in BaseReplayStrategy.__subclasses__()
+ if not cls.__name__.endswith("Mixin") and cls.__name__ != "DemoReplayStrategy"
+ }
+ combo_box.addItems(strategies.keys())
+ if "VisualReplayStrategy" in strategies:
+ combo_box.setCurrentText("VisualReplayStrategy")
+
+ strategy_label = QLabel()
+ strategy_label.setWordWrap(True)
+ strategy_label.setFont(QFont("Segoe UI", 9))
+
+ strategy_layout.addWidget(combo_box)
+ strategy_layout.addWidget(strategy_label)
+ layout.addWidget(strategy_group)
+
+ args_container = QWidget()
+ args_layout = QVBoxLayout(args_container)
+ args_layout.setContentsMargins(0,0,0,0)
+ args_layout.setSpacing(10)
+ layout.addWidget(args_container)
+
+ def update_args_inputs() -> None:
+ """Update argument inputs, hiding instructions."""
+ while args_layout.count():
+ args_layout.takeAt(0).widget().deleteLater()
+
+ strategy_class = strategies[combo_box.currentText()]
+ strategy_label.setText(strategy_class.__doc__ or "No description available.")
+
+ sig = inspect.signature(strategy_class.__init__)
+ has_visible_params = False
+ for param in sig.parameters.values():
+ if param.name in ("self", "recording", "instructions", "str_input"):
+ continue
+ has_visible_params = True
+
+ param_label = QLabel(f"{param.name.replace('_', ' ').title()}:")
+ param_label.setFont(QFont("Segoe UI", 10))
+
if param.annotation is bool:
- # For boolean, get True/False from the combobox selection
- value = widget.currentText() == "True"
+ arg_input = QComboBox()
+ arg_input.addItems(["False", "True"])
+ if param.default:
+ arg_input.setCurrentIndex(int(param.default))
else:
- # Convert the text to the annotated type if possible
- text = widget.text()
+ arg_input = QLineEdit(str(param.default) if param.default is not inspect.Parameter.empty else "")
+
+ args_layout.addWidget(param_label)
+ args_layout.addWidget(arg_input)
+ args_container.setVisible(has_visible_params)
+
+ combo_box.currentIndexChanged.connect(update_args_inputs)
+ update_args_inputs()
+
+ button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
+ button_box.accepted.connect(dialog.accept)
+ button_box.rejected.connect(dialog.reject)
+ layout.addWidget(button_box)
+
+ if dialog.exec() == QDialog.Accepted:
+ selected_strategy = strategies[combo_box.currentText()]
+ sig = inspect.signature(selected_strategy.__init__)
+ kwargs = {}
+ widget_idx = 0
+ for param in sig.parameters.values():
+ if param.name in ("instructions", "str_input"):
+ kwargs[param.name] = search_query
+ continue
+ if param.name in ("self", "recording"):
+ continue
+
+ if args_container.isVisible():
+ arg_widget = args_layout.itemAt(widget_idx * 2 + 1).widget()
+ value = arg_widget.currentText() == "True" if isinstance(arg_widget, QComboBox) else arg_widget.text()
try:
- # Cast text to the parameter's annotated type
- value = (
- param.annotation(text)
- if param.annotation != inspect.Parameter.empty
- else text
- )
- except ValueError as exc:
- logger.warning(f"{exc=}")
- value = text
- kwargs[param_name] = value
- index += 1
- logger.info(f"kwargs=\n{pformat(kwargs)}")
+ kwargs[param.name] = param.annotation(value) if param.annotation != inspect.Parameter.empty else value
+ except (ValueError, TypeError):
+ kwargs[param.name] = value
+ widget_idx += 1
+ logger.info(f"Replaying with kwargs: {pformat(kwargs)}")
self.child_conn.send({"type": "replay.starting"})
- record_replay = False
- recording_timestamp = None
- strategy_name = selected_strategy.__name__
replay_proc = multiprocessing.Process(
target=WrapStdout(replay),
- args=(
- strategy_name,
- record_replay,
- recording_timestamp,
- recording,
- self.child_conn,
- ),
+ args=(selected_strategy.__name__, False, None, recording, self.child_conn),
kwargs=kwargs,
daemon=True,
)
replay_proc.start()
+ def _visualize(self, recording: Recording) -> None:
+ """Visualize a recording.
+
+ Args:
+ recording (Recording): The recording to visualize.
+ """
+ self.show_toast("Starting visualization...")
+ try:
+ if self.visualize_proc is not None:
+ self.visualize_proc.kill()
+ self.visualize_proc = multiprocessing.Process(
+ target=WrapStdout(visualize), args=(recording,)
+ )
+ self.visualize_proc.start()
+
+ except Exception as e:
+ logger.error(e)
+ self.show_toast("Visualization failed.")
+
def _delete(self, recording: Recording) -> None:
"""Delete a recording after confirmation.
@@ -705,52 +878,30 @@ class ConfirmDeleteDialog(QDialog):
"""Dialog window to confirm recording deletion."""
def __init__(self, recording_description: str) -> None:
- """Initialize.
-
- Args:
- recording_description (str): The Recording's description.
- """
+ """Initialize."""
super().__init__()
- self.setWindowTitle("Confirm Delete")
+ self.setWindowTitle("Confirm Deletion")
self.build_ui(recording_description)
def build_ui(self, recording_description: str) -> None:
- """Build the dialog window.
-
- Args:
- recording_description (str): The recording description.
- """
- # Setup layout
+ """Build the dialog window."""
layout = QVBoxLayout(self)
- # Add description text
label = QLabel(
f"Are you sure you want to delete the recording '{recording_description}'?"
)
label.setWordWrap(True)
layout.addWidget(label)
- # Add buttons
- button_layout = QHBoxLayout()
- yes_button = QPushButton("Yes")
- no_button = QPushButton("No")
- button_layout.addWidget(yes_button)
- button_layout.addWidget(no_button)
- layout.addLayout(button_layout)
-
- # Connect buttons
- yes_button.clicked.connect(self.accept)
- no_button.clicked.connect(self.reject)
+ button_box = QDialogButtonBox(QDialogButtonBox.Yes | QDialogButtonBox.No)
+ button_box.button(QDialogButtonBox.Yes).setStyleSheet("color: red;")
+ button_box.accepted.connect(self.accept)
+ button_box.rejected.connect(self.reject)
+ layout.addWidget(button_box)
def exec_(self) -> bool:
- """Show the dialog window and return the user input.
-
- Returns:
- bool: The user's input.
- """
- if super().exec_() == QDialog.Accepted:
- return True
- return False
+ """Show the dialog window and return the user input."""
+ return super().exec_() == QDialog.Accepted
def _run() -> None:
diff --git a/openadapt/config.defaults.json b/openadapt/config.defaults.json
index 1f935ca3f..21c7bb125 100644
--- a/openadapt/config.defaults.json
+++ b/openadapt/config.defaults.json
@@ -82,6 +82,7 @@
"VISUALIZE_MAX_TABLE_CHILDREN": 10,
"SAVE_SCREENSHOT_DIFF": false,
"SPACY_MODEL_NAME": "en_core_web_trf",
+ "EMBEDDING_MODEL_NAME": "sentence-transformers/all-MiniLM-L6-v2",
"DASHBOARD_CLIENT_PORT": 5173,
"DASHBOARD_SERVER_PORT": 8080,
"BROWSER_WEBSOCKET_PORT": 8765,
diff --git a/openadapt/config.py b/openadapt/config.py
index ca3c01801..82366769a 100644
--- a/openadapt/config.py
+++ b/openadapt/config.py
@@ -227,6 +227,9 @@ def validate_scrub_fill_color(cls, v: Union[str, int]) -> int: # noqa: ANN102
# Spacy configurations
SPACY_MODEL_NAME: str = "en_core_web_trf"
+ # Embedding configurations
+ EMBEDDING_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
+
# Dashboard configurations
DASHBOARD_CLIENT_PORT: int = 3000
DASHBOARD_SERVER_PORT: int = 8000
diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py
index 39a4f7677..62291415c 100644
--- a/openadapt/db/crud.py
+++ b/openadapt/db/crud.py
@@ -29,8 +29,10 @@
ScrubbedRecording,
WindowEvent,
copy_sa_instance,
+ RecordingEmbedding,
)
from openadapt.privacy.base import ScrubbingProvider
+from openadapt.embed import get_embedding, get_configured_model_name
BATCH_SIZE = 1
@@ -284,6 +286,11 @@ def insert_recording(session: SaSession, recording_data: dict) -> Recording:
session.add(db_obj)
session.commit()
session.refresh(db_obj)
+
+ # Add embedding
+ if db_obj.task_description:
+ add_or_update_embedding(session, db_obj, db_obj.task_description)
+
return db_obj
@@ -966,3 +973,58 @@ def release_db_lock(raise_exception: bool = True) -> None:
logger.error("Database lock file not found.")
raise
logger.info("Database lock released.")
+
+
+def add_or_update_embedding(session: SaSession, recording: Recording, text: str) -> None:
+ """Adds or updates the embedding for a given recording and text.
+
+ Args:
+ session: The database session.
+ recording: The Recording object.
+ text: The text to embed (e.g., task_description).
+ """
+ current_model_name = get_configured_model_name()
+ embedding_vector = get_embedding(text, model_name=current_model_name)
+ if embedding_vector:
+ # Check if an embedding already exists for this recording and model
+ existing_embedding = (
+ session.query(RecordingEmbedding)
+ .filter_by(recording_id=recording.id, model_name=current_model_name)
+ .first()
+ )
+ if existing_embedding:
+ existing_embedding.embedding = embedding_vector
+ existing_embedding.timestamp = time.time()
+ logger.info(f"Updated embedding for recording_id={recording.id} using model='{current_model_name}'")
+ else:
+ new_embedding = RecordingEmbedding(
+ recording_id=recording.id,
+ embedding=embedding_vector,
+ model_name=current_model_name,
+ )
+ session.add(new_embedding)
+ logger.info(f"Added new embedding for recording_id={recording.id} using model='{current_model_name}'")
+ session.commit()
+ else:
+ logger.warning(
+ f"Could not generate or save embedding for recording_id={recording.id}"
+ )
+
+
+def get_similar_recordings_by_embedding(
+ session: SaSession, query_embedding: list[float], top_n: int = 5
+) -> list[Recording]:
+    """Finds recordings with the most similar embeddings to the query_embedding
+    using in-process cosine-similarity search (NumPy/SciPy).
+
+ Args:
+ session: The database session.
+ query_embedding: The embedding vector of the user's query.
+ top_n: The number of similar recordings to return.
+
+ Returns:
+ A list of Recording objects, ordered by similarity (most similar first).
+ """
+ from openadapt.similarity_search import get_similar_recordings_by_embedding_legacy
+
+ return get_similar_recordings_by_embedding_legacy(session, query_embedding, top_n)
diff --git a/openadapt/embed.py b/openadapt/embed.py
new file mode 100644
index 000000000..983a8e20f
--- /dev/null
+++ b/openadapt/embed.py
@@ -0,0 +1,69 @@
+"""This module handles the generation of text embeddings."""
+
+from sentence_transformers import SentenceTransformer
+from openadapt.custom_logger import logger
+from openadapt.config import config # Import config
+
+# Global variable to cache the loaded model
+_model_cache = {}
+
+def get_configured_model_name() -> str:
+ """Returns the configured embedding model name."""
+ return config.EMBEDDING_MODEL_NAME
+
+def _load_model(model_name: str) -> SentenceTransformer | None:
+ """Loads a sentence transformer model and caches it."""
+ if model_name in _model_cache:
+ return _model_cache[model_name]
+
+ try:
+ model = SentenceTransformer(model_name)
+ logger.info(f"Loaded sentence transformer model: {model_name}")
+ _model_cache[model_name] = model
+ return model
+ except Exception as e:
+ logger.error(f"Failed to load sentence transformer model: {model_name}, {e}")
+ _model_cache[model_name] = None # Cache None to avoid retrying failed loads repeatedly
+ return None
+
+def get_embedding(text: str, model_name: str | None = None) -> list[float] | None:
+ """Generates an embedding for the given text using the specified or configured model.
+
+ Args:
+ text: The text to embed.
+ model_name: Optional. The name of the sentence transformer model to use.
+ If None, uses the model from the global config.
+
+ Returns:
+ The embedding as a list of floats, or None if an error occurs.
+ """
+ current_model_name = model_name or get_configured_model_name()
+ model = _load_model(current_model_name)
+
+ if not model:
+ logger.error(f"Sentence transformer model '{current_model_name}' not loaded. Cannot generate embedding.")
+ return None
+ if not text or not isinstance(text, str):
+ logger.warning(f"Invalid text input for embedding: {text}")
+ return None
+ try:
+ embedding_output = model.encode(text, convert_to_tensor=False, convert_to_numpy=False)
+ # Ensure it's a list of Python floats
+ if hasattr(embedding_output, 'tolist'): # Handles numpy array or torch tensor
+ embedding = embedding_output.tolist()
+ elif isinstance(embedding_output, list):
+ embedding = embedding_output
+ else:
+ # Fallback if it's a single tensor/numpy scalar, though unlikely for sentence embeddings
+ embedding = [float(embedding_output)]
+
+ # Further ensure all elements are floats if it's a list of lists (some models might do this)
+ if embedding and isinstance(embedding[0], list):
+ embedding = [float(item) for sublist in embedding for item in sublist] # Flatten and convert
+ elif embedding:
+ embedding = [float(item) for item in embedding]
+
+ return embedding
+ except Exception as e:
+ logger.error(f"Error generating embedding for text='{text[:100]}...': {e}")
+ return None
\ No newline at end of file
diff --git a/openadapt/models.py b/openadapt/models.py
index 03b60329e..59175e75d 100644
--- a/openadapt/models.py
+++ b/openadapt/models.py
@@ -8,6 +8,7 @@
import io
import sys
import textwrap
+import time
from bs4 import BeautifulSoup
from pynput import keyboard
@@ -100,6 +101,16 @@ class Recording(db.Base):
audio_info = sa.orm.relationship(
"AudioInfo", back_populates="recording", cascade="all, delete-orphan"
)
+ recording_embeddings = sa.orm.relationship(
+ "RecordingEmbedding",
+ back_populates="recording",
+ cascade="all, delete-orphan",
+ )
+ recording_summaries = sa.orm.relationship(
+ "RecordingSummary",
+ back_populates="recording",
+ cascade="all, delete-orphan",
+ )
_processed_action_events = None
@@ -1187,6 +1198,34 @@ class Replay(db.Base):
git_hash = sa.Column(sa.String)
+class RecordingEmbedding(db.Base):
+ """Class representing a recording embedding in the database."""
+
+ __tablename__ = "recording_embedding"
+
+ id = sa.Column(sa.Integer, primary_key=True)
+ recording_id = sa.Column(sa.ForeignKey("recording.id"))
+    embedding = sa.Column(sa.JSON)  # embedding vector stored as a JSON list of floats
+ model_name = sa.Column(sa.String) # To store the name of the embedding model used
+ timestamp = sa.Column(ForceFloat, default=time.time)
+
+ recording = sa.orm.relationship("Recording", back_populates="recording_embeddings")
+
+
+class RecordingSummary(db.Base):
+ """Class representing a hierarchical summary of a recording."""
+
+ __tablename__ = "recording_summary"
+
+ id = sa.Column(sa.Integer, primary_key=True)
+ recording_id = sa.Column(sa.ForeignKey("recording.id"))
+ summary_text = sa.Column(sa.Text)
+ summary_level = sa.Column(sa.String) # e.g., "high", "mid", "low"
+ timestamp = sa.Column(ForceFloat, default=time.time)
+
+ recording = sa.orm.relationship("Recording", back_populates="recording_summaries")
+
+
def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base:
"""Copy a SQLAlchemy instance.
diff --git a/openadapt/scripts/backfill_embeddings.py b/openadapt/scripts/backfill_embeddings.py
new file mode 100644
index 000000000..49ce4d208
--- /dev/null
+++ b/openadapt/scripts/backfill_embeddings.py
@@ -0,0 +1,95 @@
+"""This script backfills embeddings for existing recordings."""
+
+import argparse
+
+from sqlalchemy.orm import Session
+
+from openadapt.db import crud
+from openadapt.models import Recording, RecordingEmbedding
+from openadapt.custom_logger import logger
+from openadapt.embed import get_configured_model_name # Updated import
+
+def backfill_embeddings(session: Session, force_update: bool = False) -> None:
+ """Backfills embeddings for recordings that don't have one or if force_update is True.
+
+ Args:
+ session: The database session.
+ force_update: If True, will re-generate embeddings even if they exist.
+ """
+ logger.info("Starting embedding backfill process...")
+ recordings = session.query(Recording).all()
+ logger.info(f"Found {len(recordings)} recordings to process.")
+
+ processed_count = 0
+ skipped_count = 0
+ error_count = 0
+
+ for recording in recordings:
+ if not recording.task_description:
+ # logger.info(f"Skipping recording_id={recording.id}, no task_description.")
+ skipped_count += 1
+ continue
+
+ if not force_update:
+ # Check if an embedding with the current MODEL_NAME already exists
+ current_model_name = get_configured_model_name()
+ existing_embedding = (
+ session.query(RecordingEmbedding)
+ .filter_by(recording_id=recording.id, model_name=current_model_name)
+ .first()
+ )
+ if existing_embedding:
+ # logger.info(
+ # f"Skipping recording_id={recording.id}, embedding with model='{current_model_name}' already exists."
+ # )
+ skipped_count += 1
+ continue
+
+ logger.info(f"Processing recording_id={recording.id}: '{recording.task_description[:50]}...' ")
+ try:
+ # add_or_update_embedding will handle both creation and update logic
+ # It uses the MODEL_NAME from openadapt.embed
+ crud.add_or_update_embedding(session, recording, recording.task_description)
+ processed_count += 1
+ except Exception as e:
+ logger.error(f"Error processing recording_id={recording.id}: {e}")
+ error_count += 1
+ # Optionally, rollback session for this specific error to not affect others
+ # session.rollback()
+
+ if processed_count > 0 or error_count > 0 : # only commit if there were changes or errors to log
+ try:
+ session.commit()
+ logger.info("Committed changes to the database.")
+ except Exception as e:
+ logger.error(f"Error committing changes to database: {e}")
+ session.rollback()
+
+ logger.info(f"Embedding backfill complete. Processed: {processed_count}, Skipped: {skipped_count}, Errors: {error_count}")
+
+def main() -> None:
+ """Main function to run the backfill script."""
+ parser = argparse.ArgumentParser(description="Backfill embeddings for existing recordings.")
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Force update embeddings even if they already exist for the current model.",
+ )
+ args = parser.parse_args()
+ db_session = None
+ try:
+ # Use read_and_write=True for the session
+ db_session = crud.get_new_session(read_and_write=True)
+ backfill_embeddings(db_session, force_update=args.force)
+ except Exception as e:
+ logger.error(f"An unexpected error occurred: {e}")
+ if db_session:
+ db_session.rollback()
+ finally:
+ if db_session:
+ db_session.close()
+ # crud.release_db_lock()
+ logger.info("Backfill script finished.")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/openadapt/similarity_search.py b/openadapt/similarity_search.py
new file mode 100644
index 000000000..a9e2ba1a2
--- /dev/null
+++ b/openadapt/similarity_search.py
@@ -0,0 +1,131 @@
+"""Vector similarity search functionality using pure Python."""
+
+import json
+import re
+import time
+from collections import Counter
+from typing import List, Tuple
+
+import numpy as np
+from loguru import logger
+from scipy.spatial.distance import cosine
+
+from openadapt.db.crud import SaSession
+from openadapt.embed import get_embedding, get_configured_model_name
+from openadapt.models import Recording, RecordingEmbedding
+
+# A small set of common English stop words to avoid external dependencies or downloads.
+STOP_WORDS = {
+ "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he",
+ "in", "is", "it", "its", "of", "on", "that", "the", "to", "was", "were",
+ "will", "with", "i", "you", "your", "d", "s", "t", "ve",
+}
+
+
+def _extract_context(recording: Recording) -> str:
+ """Extract a general context from a recording to improve search relevance."""
+ context_parts = []
+
+ if recording.task_description:
+ context_parts.append(recording.task_description)
+
+ # Extract keywords from window titles
+ if recording.window_events:
+ all_titles = " ".join(
+ [event.title for event in recording.window_events if event.title]
+ )
+ words = re.findall(r"\w+", all_titles.lower())
+ meaningful_words = [
+ word for word in words if word not in STOP_WORDS and not word.isdigit()
+ ]
+
+ if meaningful_words:
+ # Add top 5 most common words as context
+ word_counts = Counter(meaningful_words)
+ for word, _ in word_counts.most_common(5):
+ if len(word) > 2: # filter out very short words
+ context_parts.append(word)
+
+ # Extract context from action events
+ if recording.action_events:
+ # Get a summary of action types
+ action_types = set(ae.name for ae in recording.action_events if ae.name)
+ if action_types:
+ context_parts.append("actions:" + ",".join(sorted(list(action_types))))
+
+ # Detect significant text input
+ typed_text = "".join(
+ ae.key_char
+ for ae in recording.action_events
+ if ae.name == "keypress" and ae.key_char and ae.key_char.isalnum()
+ )
+ if len(typed_text) > 15: # Heuristic for meaningful text input
+ context_parts.append("text_input")
+
+ # Add transcribed text from audio if available
+ if hasattr(recording, "audio_info") and recording.audio_info:
+ for audio in recording.audio_info:
+ if audio.transcribed_text:
+ context_parts.append("transcription:" + audio.transcribed_text)
+
+ # Deduplicate and join
+ unique_context = sorted(list(set(context_parts)))
+ return " ".join(unique_context)
+
+
def get_enhanced_similarity_search(
    session: SaSession, query_text: str, top_n: int = 5
) -> List[Tuple[Recording, float]]:
    """Perform an enhanced search using application and action context.

    Each recording's score blends the similarity between the query and the
    stored task-description embedding (weight 0.6) with the similarity
    between the query and an embedding of the recording's extracted context
    (weight 0.4).

    Args:
        session: The database session.
        query_text: The user's search query.
        top_n: The number of results to return.

    Returns:
        A list of (Recording, similarity_score) tuples, best matches first.
    """
    model_name = get_configured_model_name()
    query_embedding = get_embedding(query_text, model_name=model_name)

    # NOTE(review): assumes get_embedding returns a list (or None/empty) —
    # a numpy array here would make this truthiness check raise. Confirm.
    if not query_embedding:
        logger.error("Could not generate embedding for query text.")
        return []

    # Only compare against embeddings produced by the currently configured
    # model; cosine distances across different models are meaningless.
    recordings_with_embeddings = (
        session.query(Recording, RecordingEmbedding)
        .join(RecordingEmbedding)
        .filter(RecordingEmbedding.model_name == model_name)
        .all()
    )

    results = []
    for recording, embedding_record in recordings_with_embeddings:
        # All per-recording work happens inside the try so a single malformed
        # record (bad JSON, mismatched vector size, context-embedding failure)
        # is logged and skipped instead of aborting the entire search.
        try:
            enhanced_context = _extract_context(recording)
            context_embedding = get_embedding(enhanced_context, model_name=model_name)

            # Stored embeddings may be JSON-serialized strings or native lists.
            stored_embedding = (
                json.loads(embedding_record.embedding)
                if isinstance(embedding_record.embedding, str)
                else embedding_record.embedding
            )
            description_similarity = 1 - cosine(query_embedding, stored_embedding)
            context_similarity = (
                (1 - cosine(query_embedding, context_embedding))
                if context_embedding
                else 0.0
            )

            # Weight the task description higher than the derived context.
            combined_similarity = (description_similarity * 0.6) + (
                context_similarity * 0.4
            )
            results.append((recording, combined_similarity))
        except (TypeError, ValueError) as e:
            # ValueError also covers json.JSONDecodeError (its subclass) and
            # scipy's shape-mismatch errors from cosine().
            logger.error(
                f"Error processing embedding for recording {recording.id}: {e}"
            )

    results.sort(key=lambda x: x[1], reverse=True)
    return results[:top_n]
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 7f1b705c1..87fda8cba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,10 @@ pytest = "7.1.3"
rapidocr-onnxruntime = "1.2.3"
scikit-learn = "1.2.2"
scipy = "^1.11.0"
+numpy = "^1.24.0"
torch = "^2.0.0"
tqdm = "4.64.0"
-transformers = "4.29.2"
+transformers = "^4.32.0"
python-dotenv = "1.0.0"
pyinstaller = "6.11.0"
setuptools-lint = "^0.6.0"
@@ -99,6 +100,7 @@ pudb = "^2024.1"
sounddevice = "^0.4.6"
soundfile = "^0.12.1"
posthog = "^3.5.0"
+sentence-transformers = "^2.3.1"
wheel = "^0.43.0"
cython = "^3.0.10"