Skip to content

Commit 484fd20

Browse files
fperezpre-commit-ci[bot]dlqqq
authored
Load personas dynamically from .jupyter dir (#1380)
* First cut at the PersonaLoader implementation, no tests yet. * Scope import logic more tightly for persona files. Only match (case insensitively) Python files with `persona` in the name. * Add first pass of unit tests for the local persona loader. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor into standalone function * Make persona_classes an instance attr so it actually reloads per chat file. Also, rename local loader to a shorter name. * Improve logging so we can debug persona loading issues * Ignore files with leading dot * Add a bit more logging for local persona loading * Fix logic to skip . and _ files from the start instead of in the later loop, clarify logging * Use log.exception correctly to print tracebacks * Fix path loading logic so personas can import local resources * fix mypy errors --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David L. Qiu <david@qiu.dev>
1 parent b597e5c commit 484fd20

File tree

2 files changed

+195
-10
lines changed

2 files changed

+195
-10
lines changed

packages/jupyter-ai/jupyter_ai/personas/persona_manager.py

Lines changed: 120 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import importlib.util
5+
import inspect
46
import os
7+
import sys
8+
from glob import glob
59
from logging import Logger
10+
from pathlib import Path
611
from time import time_ns
712
from typing import TYPE_CHECKING, Any, ClassVar
813

@@ -47,11 +52,14 @@ class PersonaManager(LoggingConfigurable):
4752
type for type checkers.
4853
"""
4954

55+
# TODO: the Persona classes from entry points should be stored as a class
56+
# attribute, since they will not change at runtime.
57+
# That should be injected into this instance attribute when personas defined
58+
# under `.jupyter` are loaded.
59+
_persona_classes: list[type[BasePersona]] | None = None
5060
_personas: dict[str, BasePersona]
5161
file_id: str
5262

53-
# class attrs
54-
_persona_classes: ClassVar[list[type[BasePersona]] | None] = None
5563

5664
def __init__(
5765
self,
@@ -87,9 +95,9 @@ def __init__(
8795
# This is stored in a class attribute (global to all instances) because
8896
# the entry points are immutable after the server starts, so they only
8997
# need to be loaded once.
90-
if not isinstance(PersonaManager._persona_classes, list):
98+
if not isinstance(self._persona_classes, list):
9199
self._init_persona_classes()
92-
assert isinstance(PersonaManager._persona_classes, list)
100+
assert isinstance(self._persona_classes, list)
93101

94102
self._personas = self._init_personas()
95103

@@ -98,12 +106,18 @@ def _init_persona_classes(self) -> None:
98106
Initializes the list of persona *classes* by retrieving the
99107
`jupyter-ai.personas` entry points group.
100108
101-
This list is cached in the `PersonaManager._persona_classes` class
102-
attribute, i.e. this method should only run once in the extension
109+
# TODO: fix this part of docs now that we have it as an instance attr.
110+
This list is cached in the `self._persona_classes` instance
111+
attribute, i.e. this method should only run once in the extension
103112
lifecycle.
104113
"""
105-
if PersonaManager._persona_classes:
106-
return
114+
# Loading is in two parts:
115+
# 1. Load persona classes from package entry points.
116+
# 2. Load persona classes from local filesystem.
117+
#
118+
# This allows for lightweight development of new personas by the user in
119+
# their local filesystem. The first part is done here, and the second
120+
# part is done in `_init_personas()`.
107121

108122
persona_eps = entry_points().select(group=EPG_NAME)
109123
self.log.info(f"Found {len(persona_eps)} entry points under '{EPG_NAME}'.")
@@ -140,15 +154,23 @@ def _init_persona_classes(self) -> None:
140154
"ERROR: Jupyter AI has no AI personas available. "
141155
+ "Please verify your server configuration and open a new issue on our GitHub repo if this warning persists."
142156
)
143-
PersonaManager._persona_classes = persona_classes
157+
158+
# Load persona classes from local filesystem
159+
dotjupyter_dir = self.get_dotjupyter_dir()
160+
if dotjupyter_dir is None:
161+
self.log.info("No .jupyter directory found for loading local personas.")
162+
else:
163+
persona_classes.extend(load_from_dir(dotjupyter_dir, self.log))
164+
165+
self._persona_classes = persona_classes
144166

145167
def _init_personas(self) -> dict[str, BasePersona]:
146168
"""
147169
Initializes the list of persona instances for the YChat instance passed
148170
to the constructor.
149171
"""
150172
# Ensure that persona classes were initialized first
151-
persona_classes = PersonaManager._persona_classes
173+
persona_classes = self._persona_classes
152174
assert isinstance(persona_classes, list)
153175

154176
# If no persona classes are available, log a warning and return
@@ -287,3 +309,91 @@ def get_mcp_config(self) -> dict[str, Any]:
287309
return {}
288310
else:
289311
return self._mcp_config_loader.get_config(jdir)
312+
313+
314+
def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]:
    """
    Load _persona class declarations_ from Python files in the local filesystem.

    Those class declarations are then used to instantiate personas by the
    `PersonaManager`.

    Scans `root_dir` (non-recursively) for `.py` files whose name contains
    `persona` (case-insensitively) and does _not_ start with `_` or `.`
    (i.e. private and hidden modules are skipped). Each matching file is
    imported dynamically, in sorted filename order, and every class defined in
    that file that is a strict subclass of `BasePersona` is collected.

    `root_dir` is temporarily prepended to `sys.path` while the files are
    imported so that persona modules can import sibling modules; it is removed
    again afterwards if this function added it.

    A file that raises while being imported is logged (with traceback) and
    skipped; it does not prevent the remaining files from being loaded.

    Args:
        root_dir: Directory to scan for persona Python files.
        log: Logger instance for logging messages.

    Returns:
        List of `BasePersona` subclasses found in the directory. Empty if
        `root_dir` is not an existing directory or contains no matching files.
    """
    # Local import: the module only imports `glob.glob` at file level.
    from glob import escape as glob_escape

    persona_classes: list[type[BasePersona]] = []

    log.info(f"Searching for persona files in {root_dir}")
    # Nothing to scan unless root_dir is an existing directory
    if not os.path.isdir(root_dir):
        return persona_classes

    # Find all .py files in the root directory that contain "persona" in the name
    try:
        # Escape root_dir so glob metacharacters in the directory path itself
        # (e.g. "proj[1]") are matched literally instead of as a pattern.
        pattern = os.path.join(glob_escape(root_dir), "*.py")
        py_files = []
        # glob() returns matches in arbitrary order; sort so that the persona
        # load order is deterministic across runs and platforms.
        for f in sorted(glob(pattern)):
            fname_lower = Path(f).stem.lower()
            if "persona" in fname_lower and not fname_lower.startswith(("_", ".")):
                py_files.append(f)

    except Exception as e:
        # On exception with glob operation, return empty list
        log.error(f"{type(e).__name__} occurred while searching for Python files in {root_dir}")
        return persona_classes

    if not py_files:
        # Nothing to import, so there is no need to touch sys.path at all
        return persona_classes

    log.info(f"Found files from {root_dir}: {[Path(f).name for f in py_files]}")

    # Temporarily add root_dir to sys.path for imports
    root_dir_in_path = root_dir in sys.path
    if not root_dir_in_path:
        sys.path.insert(0, root_dir)

    try:
        # For each .py file, dynamically import the module and extract all
        # BasePersona subclasses.
        for py_file in py_files:
            try:
                # Get module name from file path
                module_name = Path(py_file).stem

                # Create module spec and load the module
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if spec is None or spec.loader is None:
                    continue

                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                # Find all classes in the module that are BasePersona subclasses
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    # Keep strict subclasses of BasePersona that were defined
                    # in this file (not merely imported into it).
                    if (
                        issubclass(obj, BasePersona)
                        and obj is not BasePersona
                        and obj.__module__ == module_name
                    ):
                        log.info(f"Found persona class '{obj.__name__}' in '{py_file}'")
                        persona_classes.append(obj)

            except Exception:
                # On exception, log error (with traceback) and continue to next file
                log.exception(f"Unable to load persona classes from '{py_file}', exception details printed below.")
                continue
    finally:
        # Remove root_dir from sys.path if we added it
        if not root_dir_in_path and root_dir in sys.path:
            sys.path.remove(root_dir)

    return persona_classes
399+
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
Test the local persona manager.
3+
"""
4+
5+
import tempfile
6+
from pathlib import Path
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults
11+
from jupyter_ai.personas.persona_manager import load_from_dir
12+
13+
14+
@pytest.fixture
15+
def tmp_persona_dir():
16+
"""Create a temporary directory for testing `load_from_dir` with guaranteed cleanup."""
17+
with tempfile.TemporaryDirectory() as temp_dir:
18+
yield Path(temp_dir)
19+
20+
21+
@pytest.fixture
22+
def mock_logger():
23+
"""Create a `Mock` that stands in for the logger passed to `load_from_dir`."""
24+
return Mock()
25+
26+
27+
class TestLoadPersonaClassesFromDirectory:
28+
"""Test cases for load_from_dir function."""
29+
30+
def test_empty_directory_returns_empty_list(self, tmp_persona_dir, mock_logger):
31+
"""Test that an empty directory returns an empty list of persona classes."""
32+
result = load_from_dir(str(tmp_persona_dir), mock_logger)
33+
assert result == []
34+
35+
def test_non_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger):
36+
"""Test that a Python file without persona classes returns an empty list."""
37+
# The file name does contain "persona", so the file is imported; it simply declares no persona classes
38+
non_persona_file = tmp_persona_dir / "no_personas.py"
39+
non_persona_file.write_text("pass")
40+
41+
result = load_from_dir(str(tmp_persona_dir), mock_logger)
42+
assert result == []
43+
44+
def test_simple_persona_file_returns_persona_class(self, tmp_persona_dir, mock_logger):
45+
"""Test that a file with a BasePersona subclass returns that class."""
46+
# Create a simple persona file
47+
persona_file = tmp_persona_dir / "simple_personas.py"
48+
persona_content = """
49+
from jupyter_ai.personas.base_persona import BasePersona
50+
51+
class TestPersona(BasePersona):
52+
id = "test_persona"
53+
name = "Test Persona"
54+
description = "A simple test persona"
55+
56+
def process_message(self, message):
57+
pass
58+
"""
59+
persona_file.write_text(persona_content)
60+
61+
result = load_from_dir(str(tmp_persona_dir), mock_logger)
62+
63+
assert len(result) == 1
64+
assert result[0].__name__ == "TestPersona"
65+
assert issubclass(result[0], BasePersona)
66+
67+
def test_bad_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger):
68+
"""Test that a file that raises when imported returns an empty list."""
69+
# Create a file that fails at import time (`1/0` raises ZeroDivisionError)
70+
bad_persona_file = tmp_persona_dir / "bad_persona.py"
71+
bad_persona_file.write_text("1/0")
72+
73+
result = load_from_dir(str(tmp_persona_dir), mock_logger)
74+
75+
assert result == []

0 commit comments

Comments
 (0)