diff --git a/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py b/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py index 05b63d4ff..81c3c8dd7 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py +++ b/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py @@ -1,8 +1,13 @@ from __future__ import annotations import asyncio +import importlib.util +import inspect import os +import sys +from glob import glob from logging import Logger +from pathlib import Path from time import time_ns from typing import TYPE_CHECKING, Any, ClassVar @@ -47,11 +52,14 @@ class PersonaManager(LoggingConfigurable): type for type checkers. """ + # TODO: the Persona classes from entry points should be stored as a class + # attribute, since they will not change at runtime. + # That should be injected into this instance attribute when personas defined + # under `.jupyter` are loaded. + _persona_classes: list[type[BasePersona]] | None = None _personas: dict[str, BasePersona] file_id: str - # class attrs - _persona_classes: ClassVar[list[type[BasePersona]] | None] = None def __init__( self, @@ -87,9 +95,9 @@ def __init__( # This is stored in a class attribute (global to all instances) because # the entry points are immutable after the server starts, so they only # need to be loaded once. - if not isinstance(PersonaManager._persona_classes, list): + if not isinstance(self._persona_classes, list): self._init_persona_classes() - assert isinstance(PersonaManager._persona_classes, list) + assert isinstance(self._persona_classes, list) self._personas = self._init_personas() @@ -98,12 +106,18 @@ def _init_persona_classes(self) -> None: Initializes the list of persona *classes* by retrieving the `jupyter-ai.personas` entry points group. - This list is cached in the `PersonaManager._persona_classes` class - attribute, i.e. this method should only run once in the extension + # TODO: fix this part of docs now that we have it as an instance attr. + This list is cached in the `self._persona_classes` instance + attribute, .e. this method should only run once in the extension lifecycle. """ - if PersonaManager._persona_classes: - return + # Loading is in two parts: + # 1. Load persona classes from package entry points. + # 2. Load persona classes from local filesystem. + # + # This allows for lightweight development of new personas by the user in + # their local filesystem. The first part is done here, and the second + # part is done in `_init_personas()`. persona_eps = entry_points().select(group=EPG_NAME) self.log.info(f"Found {len(persona_eps)} entry points under '{EPG_NAME}'.") @@ -140,7 +154,15 @@ def _init_persona_classes(self) -> None: "ERROR: Jupyter AI has no AI personas available. " + "Please verify your server configuration and open a new issue on our GitHub repo if this warning persists." ) - PersonaManager._persona_classes = persona_classes + + # Load persona classes from local filesystem + dotjupyter_dir = self.get_dotjupyter_dir() + if dotjupyter_dir is None: + self.log.info("No .jupyter directory found for loading local personas.") + else: + persona_classes.extend(load_from_dir(dotjupyter_dir, self.log)) + + self._persona_classes = persona_classes def _init_personas(self) -> dict[str, BasePersona]: """ @@ -148,7 +170,7 @@ def _init_personas(self) -> dict[str, BasePersona]: to the constructor. """ # Ensure that persona classes were initialized first - persona_classes = PersonaManager._persona_classes + persona_classes = self._persona_classes assert isinstance(persona_classes, list) # If no persona classes are available, log a warning and return @@ -287,3 +309,91 @@ def get_mcp_config(self) -> dict[str, Any]: return {} else: return self._mcp_config_loader.get_config(jdir) + + +def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]: + """ + Load _persona class declarations_ from Python files in the local filesystem. + + Those class declarations are then used to instantiate personas by the + `PersonaManager`. + + Scans the root_dir for .py files containing `persona` in their name that do + _not_ start with a single `_` (i.e. private modules are skipped). Then, it + dynamically imports them, and extracts any class declarations that are + subclasses of `BasePersona`. + + Args: + root_dir: Directory to scan for persona Python files. + log: Logger instance for logging messages. + + Returns: + List of `BasePersona` subclasses found in the directory. + """ + persona_classes: list[type[BasePersona]] = [] + + log.info(f"Searching for persona files in {root_dir}") + # Check if root directory exists + if not os.path.exists(root_dir): + return persona_classes + + # Find all .py files in the root directory that contain "persona" in the name + try: + all_py_files = glob(os.path.join(root_dir, "*.py")) + py_files = [] + for f in all_py_files: + fname_lower = Path(f).stem.lower() + if "persona" in fname_lower and not (fname_lower.startswith("_") or fname_lower.startswith(".")): + py_files.append(f) + + except Exception as e: + # On exception with glob operation, return empty list + log.error(f"{type(e).__name__} occurred while searching for Python files in {root_dir}") + return persona_classes + + if py_files: + log.info(f"Found files from {root_dir}: {[Path(f).name for f in py_files]}") + + # Temporarily add root_dir to sys.path for imports + root_dir_in_path = root_dir in sys.path + if not root_dir_in_path: + sys.path.insert(0, root_dir) + + try: + # For each .py file, dynamically import the module and extract all + # BasePersona subclasses. + for py_file in py_files: + try: + # Get module name from file path + module_name = Path(py_file).stem + + # Create module spec and load the module + spec = importlib.util.spec_from_file_location(module_name, py_file) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all classes in the module that are BasePersona subclasses + for name, obj in inspect.getmembers(module, inspect.isclass): + # Check if it's a subclass of BasePersona but not BasePersona itself + if ( + issubclass(obj, BasePersona) + and obj is not BasePersona + and obj.__module__ == module_name + ): + log.info(f"Found persona class '{obj.__name__}' in '{py_file}'") + persona_classes.append(obj) + + except Exception as e: + # On exception, log error and continue to next file + log.exception(f"Unable to load persona classes from '{py_file}', exception details printed below.") + continue + finally: + # Remove root_dir from sys.path if we added it + if not root_dir_in_path and root_dir in sys.path: + sys.path.remove(root_dir) + + return persona_classes + diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_personas.py b/packages/jupyter-ai/jupyter_ai/tests/test_personas.py new file mode 100644 index 000000000..17bea6532 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_personas.py @@ -0,0 +1,75 @@ +""" +Test the local persona manager. +""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pytest +from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults +from jupyter_ai.personas.persona_manager import load_from_dir + + +@pytest.fixture +def tmp_persona_dir(): + """Create a temporary directory for testing LocalPersonaLoader with guaranteed cleanup.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def mock_logger(): + """Create a mock logger for testing.""" + return Mock() + + +class TestLoadPersonaClassesFromDirectory: + """Test cases for load_from_dir function.""" + + def test_empty_directory_returns_empty_list(self, tmp_persona_dir, mock_logger): + """Test that an empty directory returns an empty list of persona classes.""" + result = load_from_dir(str(tmp_persona_dir), mock_logger) + assert result == [] + + def test_non_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger): + """Test that a Python file without persona classes returns an empty list.""" + # Create a file that doesn't contain "persona" in the name + non_persona_file = tmp_persona_dir / "no_personas.py" + non_persona_file.write_text("pass") + + result = load_from_dir(str(tmp_persona_dir), mock_logger) + assert result == [] + + def test_simple_persona_file_returns_persona_class(self, tmp_persona_dir, mock_logger): + """Test that a file with a BasePersona subclass returns that class.""" + # Create a simple persona file + persona_file = tmp_persona_dir / "simple_personas.py" + persona_content = """ +from jupyter_ai.personas.base_persona import BasePersona + +class TestPersona(BasePersona): + id = "test_persona" + name = "Test Persona" + description = "A simple test persona" + + def process_message(self, message): + pass +""" + persona_file.write_text(persona_content) + + result = load_from_dir(str(tmp_persona_dir), mock_logger) + + assert len(result) == 1 + assert result[0].__name__ == "TestPersona" + assert issubclass(result[0], BasePersona) + + def test_bad_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger): + """Test that a file with syntax errors returns empty list.""" + # Create a file with invalid Python code + bad_persona_file = tmp_persona_dir / "bad_persona.py" + bad_persona_file.write_text("1/0") + + result = load_from_dir(str(tmp_persona_dir), mock_logger) + + assert result == []