Skip to content

Load personas dynamically from .jupyter dir #1380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 120 additions & 10 deletions packages/jupyter-ai/jupyter_ai/personas/persona_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import asyncio
import importlib.util
import inspect
import os
import sys
from glob import glob
from logging import Logger
from pathlib import Path
from time import time_ns
from typing import TYPE_CHECKING, Any, ClassVar

Expand Down Expand Up @@ -47,11 +52,14 @@ class PersonaManager(LoggingConfigurable):
type for type checkers.
"""

# TODO: the Persona classes from entry points should be stored as a class
# attribute, since they will not change at runtime.
# That should be injected into this instance attribute when personas defined
# under `.jupyter` are loaded.
_persona_classes: list[type[BasePersona]] | None = None
_personas: dict[str, BasePersona]
file_id: str

# class attrs
_persona_classes: ClassVar[list[type[BasePersona]] | None] = None

def __init__(
self,
Expand Down Expand Up @@ -87,9 +95,9 @@ def __init__(
# This is stored in a class attribute (global to all instances) because
# the entry points are immutable after the server starts, so they only
# need to be loaded once.
if not isinstance(PersonaManager._persona_classes, list):
if not isinstance(self._persona_classes, list):
self._init_persona_classes()
assert isinstance(PersonaManager._persona_classes, list)
assert isinstance(self._persona_classes, list)

self._personas = self._init_personas()

Expand All @@ -98,12 +106,18 @@ def _init_persona_classes(self) -> None:
Initializes the list of persona *classes* by retrieving the
`jupyter-ai.personas` entry points group.

This list is cached in the `PersonaManager._persona_classes` class
attribute, i.e. this method should only run once in the extension
# TODO: fix this part of docs now that we have it as an instance attr.
This list is cached in the `self._persona_classes` instance
attribute, .e. this method should only run once in the extension
lifecycle.
"""
if PersonaManager._persona_classes:
return
# Loading is in two parts:
# 1. Load persona classes from package entry points.
# 2. Load persona classes from local filesystem.
#
# This allows for lightweight development of new personas by the user in
# their local filesystem. The first part is done here, and the second
# part is done in `_init_personas()`.

persona_eps = entry_points().select(group=EPG_NAME)
self.log.info(f"Found {len(persona_eps)} entry points under '{EPG_NAME}'.")
Expand Down Expand Up @@ -140,15 +154,23 @@ def _init_persona_classes(self) -> None:
"ERROR: Jupyter AI has no AI personas available. "
+ "Please verify your server configuration and open a new issue on our GitHub repo if this warning persists."
)
PersonaManager._persona_classes = persona_classes

# Load persona classes from local filesystem
dotjupyter_dir = self.get_dotjupyter_dir()
if dotjupyter_dir is None:
self.log.info("No .jupyter directory found for loading local personas.")
else:
persona_classes.extend(load_from_dir(dotjupyter_dir, self.log))

self._persona_classes = persona_classes

def _init_personas(self) -> dict[str, BasePersona]:
"""
Initializes the list of persona instances for the YChat instance passed
to the constructor.
"""
# Ensure that persona classes were initialized first
persona_classes = PersonaManager._persona_classes
persona_classes = self._persona_classes
assert isinstance(persona_classes, list)

# If no persona classes are available, log a warning and return
Expand Down Expand Up @@ -287,3 +309,91 @@ def get_mcp_config(self) -> dict[str, Any]:
return {}
else:
return self._mcp_config_loader.get_config(jdir)


def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]:
"""
Load _persona class declarations_ from Python files in the local filesystem.

Those class declarations are then used to instantiate personas by the
`PersonaManager`.

Scans the root_dir for .py files containing `persona` in their name that do
_not_ start with a single `_` (i.e. private modules are skipped). Then, it
dynamically imports them, and extracts any class declarations that are
subclasses of `BasePersona`.

Args:
root_dir: Directory to scan for persona Python files.
log: Logger instance for logging messages.

Returns:
List of `BasePersona` subclasses found in the directory.
"""
persona_classes: list[type[BasePersona]] = []

log.info(f"Searching for persona files in {root_dir}")
# Check if root directory exists
if not os.path.exists(root_dir):
return persona_classes

# Find all .py files in the root directory that contain "persona" in the name
try:
all_py_files = glob(os.path.join(root_dir, "*.py"))
py_files = []
for f in all_py_files:
fname_lower = Path(f).stem.lower()
if "persona" in fname_lower and not (fname_lower.startswith("_") or fname_lower.startswith(".")):
py_files.append(f)

except Exception as e:
# On exception with glob operation, return empty list
log.error(f"{type(e).__name__} occurred while searching for Python files in {root_dir}")
return persona_classes

if py_files:
log.info(f"Found files from {root_dir}: {[Path(f).name for f in py_files]}")

# Temporarily add root_dir to sys.path for imports
root_dir_in_path = root_dir in sys.path
if not root_dir_in_path:
sys.path.insert(0, root_dir)

try:
# For each .py file, dynamically import the module and extract all
# BasePersona subclasses.
for py_file in py_files:
try:
# Get module name from file path
module_name = Path(py_file).stem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want a private function that takes a py_file and returns the persona class or None. When we start to cache and enable reloading, we may want that encapsulation (this can wait until we really need it).


# Create module spec and load the module
spec = importlib.util.spec_from_file_location(module_name, py_file)
if spec is None or spec.loader is None:
continue

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find all classes in the module that are BasePersona subclasses
for name, obj in inspect.getmembers(module, inspect.isclass):
# Check if it's a subclass of BasePersona but not BasePersona itself
if (
issubclass(obj, BasePersona)
and obj is not BasePersona
and obj.__module__ == module_name
):
log.info(f"Found persona class '{obj.__name__}' in '{py_file}'")
persona_classes.append(obj)

except Exception as e:
# On exception, log error and continue to next file
log.exception(f"Unable to load persona classes from '{py_file}', exception details printed below.")
continue
finally:
# Remove root_dir from sys.path if we added it
if not root_dir_in_path and root_dir in sys.path:
sys.path.remove(root_dir)

return persona_classes

75 changes: 75 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_personas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Test the local persona manager.
"""

import tempfile
from pathlib import Path
from unittest.mock import Mock

import pytest
from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults
from jupyter_ai.personas.persona_manager import load_from_dir


@pytest.fixture
def tmp_persona_dir():
"""Create a temporary directory for testing LocalPersonaLoader with guaranteed cleanup."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def mock_logger():
"""Create a mock logger for testing."""
return Mock()


class TestLoadPersonaClassesFromDirectory:
"""Test cases for load_from_dir function."""

def test_empty_directory_returns_empty_list(self, tmp_persona_dir, mock_logger):
"""Test that an empty directory returns an empty list of persona classes."""
result = load_from_dir(str(tmp_persona_dir), mock_logger)
assert result == []

def test_non_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger):
"""Test that a Python file without persona classes returns an empty list."""
# Create a file that doesn't contain "persona" in the name
non_persona_file = tmp_persona_dir / "no_personas.py"
non_persona_file.write_text("pass")

result = load_from_dir(str(tmp_persona_dir), mock_logger)
assert result == []

def test_simple_persona_file_returns_persona_class(self, tmp_persona_dir, mock_logger):
"""Test that a file with a BasePersona subclass returns that class."""
# Create a simple persona file
persona_file = tmp_persona_dir / "simple_personas.py"
persona_content = """
from jupyter_ai.personas.base_persona import BasePersona

class TestPersona(BasePersona):
id = "test_persona"
name = "Test Persona"
description = "A simple test persona"

def process_message(self, message):
pass
"""
persona_file.write_text(persona_content)

result = load_from_dir(str(tmp_persona_dir), mock_logger)

assert len(result) == 1
assert result[0].__name__ == "TestPersona"
assert issubclass(result[0], BasePersona)

def test_bad_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger):
"""Test that a file with syntax errors returns empty list."""
# Create a file with invalid Python code
bad_persona_file = tmp_persona_dir / "bad_persona.py"
bad_persona_file.write_text("1/0")

result = load_from_dir(str(tmp_persona_dir), mock_logger)

assert result == []
Loading