Skip to content

Commit 019bab6

Browse files
authored
Allow personas to get chat path and directory (#1379)
* add warning on v2 persona model * add get_chat_path() and get_chat_dir() methods to persona manager * make PersonaManager logging configurable * make BasePersona logging configurable * add chat path & dir methods to BasePersona * pre-commit * fix mypy errors * pre-commit * fix mypy * make get_chat_path() absolute by default * fixup previous commit
1 parent d773532 commit 019bab6

File tree

5 files changed

+160
-33
lines changed

5 files changed

+160
-33
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/models/persona.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# NOTE: This is the outdated `Persona` model used by Jupyter AI v2.
2+
# This is deprecated and will be removed by Jupyter AI v3.
3+
# The latest definition of a persona is located in
4+
# `jupyter_ai/personas/base_persona.py`.
5+
#
6+
# TODO: Delete this file once v3 model API changes are complete. The current model
7+
# API still depends on this, so that work must be done first.
8+
19
from pydantic import BaseModel
210

311

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
1111
from jupyter_events import EventLogger
1212
from jupyter_server.extension.application import ExtensionApp
13+
from jupyter_server_fileid.manager import ( # type: ignore[import-untyped]
14+
BaseFileIdManager,
15+
)
1316
from jupyterlab_chat.models import Message
1417
from jupyterlab_chat.ychat import YChat
1518
from pycrdt import ArrayEvent
1619
from tornado.web import StaticFileHandler
17-
from traitlets import Integer, List, Unicode
20+
from traitlets import Integer, List, Type, Unicode
1821

1922
from .completions.handlers import DefaultInlineCompletionHandler
2023
from .config_manager import ConfigManager
@@ -71,6 +74,13 @@ class AiExtension(ExtensionApp):
7174
),
7275
]
7376

77+
persona_manager_class = Type(
78+
klass=PersonaManager,
79+
default_value=PersonaManager,
80+
config=True,
81+
help="The `PersonaManager` class.",
82+
)
83+
7484
allowed_providers = List(
7585
Unicode(),
7686
default_value=None,
@@ -230,7 +240,7 @@ async def connect_chat(
230240
return
231241

232242
# initialize persona manager
233-
persona_manager = self._init_persona_manager(ychat)
243+
persona_manager = self._init_persona_manager(room_id, ychat)
234244
if not persona_manager:
235245
self.log.error(
236246
"Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
@@ -372,7 +382,9 @@ async def _stop_extension(self):
372382
"""
373383
# TODO: explore if cleanup is necessary
374384

375-
def _init_persona_manager(self, ychat: YChat) -> Optional[PersonaManager]:
385+
def _init_persona_manager(
386+
self, room_id: str, ychat: YChat
387+
) -> Optional[PersonaManager]:
376388
"""
377389
Initializes a `PersonaManager` instance scoped to a `YChat`.
378390
@@ -390,11 +402,27 @@ def _init_persona_manager(self, ychat: YChat) -> Optional[PersonaManager]:
390402
message_interrupted, dict
391403
)
392404

393-
persona_manager = PersonaManager(
405+
assert self.serverapp
406+
assert self.serverapp.web_app
407+
assert self.serverapp.web_app.settings
408+
fileid_manager = self.serverapp.web_app.settings.get(
409+
"file_id_manager", None
410+
)
411+
assert isinstance(fileid_manager, BaseFileIdManager)
412+
413+
contents_manager = self.serverapp.contents_manager
414+
root_dir = getattr(contents_manager, "root_dir", None)
415+
assert isinstance(root_dir, str)
416+
417+
PersonaManagerClass = self.persona_manager_class
418+
persona_manager = PersonaManagerClass(
419+
parent=self,
420+
room_id=room_id,
394421
ychat=ychat,
395422
config_manager=config_manager,
423+
fileid_manager=fileid_manager,
424+
root_dir=root_dir,
396425
event_loop=self.event_loop,
397-
log=self.log,
398426
message_interrupted=message_interrupted,
399427
)
400428
except Exception as e:

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from abc import ABC, abstractmethod
2+
from abc import ABC, ABCMeta, abstractmethod
33
from dataclasses import asdict
44
from logging import Logger
55
from time import time
@@ -9,6 +9,8 @@
99
from jupyterlab_chat.models import Message, NewMessage, User
1010
from jupyterlab_chat.ychat import YChat
1111
from pydantic import BaseModel
12+
from traitlets import MetaHasTraits
13+
from traitlets.config import LoggingConfigurable
1214

1315
from .persona_awareness import PersonaAwareness
1416

@@ -44,7 +46,15 @@ class PersonaDefaults(BaseModel):
4446
# ^^^ set this to automatically default to a model after a fresh start, no config file
4547

4648

47-
class BasePersona(ABC):
49+
class ABCLoggingConfigurableMeta(ABCMeta, MetaHasTraits):
50+
"""
51+
Metaclass required for `BasePersona` to inherit from both `ABC` and
52+
`LoggingConfigurable`. This pattern is also followed by `BaseFileIdManager`
53+
from `jupyter_server_fileid`.
54+
"""
55+
56+
57+
class BasePersona(ABC, LoggingConfigurable, metaclass=ABCLoggingConfigurableMeta):
4858
"""
4959
Abstract base class that defines a persona when implemented.
5060
"""
@@ -55,21 +65,22 @@ class BasePersona(ABC):
5565
Automatically set by `BasePersona`.
5666
"""
5767

58-
manager: "PersonaManager"
68+
parent: "PersonaManager" # type: ignore
5969
"""
6070
Reference to the `PersonaManager` for this `YChat`, which manages this
61-
instance. Automatically set by `BasePersona`.
71+
instance. Automatically set by the `LoggingConfigurable` parent class.
6272
"""
6373

64-
config: ConfigManager
74+
config_manager: ConfigManager
6575
"""
6676
Reference to the `ConfigManager` singleton, which is used to read & write from
6777
the Jupyter AI settings. Automatically set by `BasePersona`.
6878
"""
6979

70-
log: Logger
80+
log: Logger # type: ignore
7181
"""
72-
The logger for this instance. Automatically set by `BasePersona`.
82+
The `logging.Logger` instance used by this class. Automatically set by the
83+
`LoggingConfigurable` parent class.
7384
"""
7485

7586
awareness: PersonaAwareness
@@ -92,22 +103,26 @@ class BasePersona(ABC):
92103
################################################
93104
def __init__(
94105
self,
95-
*,
106+
*args,
96107
ychat: YChat,
97-
manager: "PersonaManager",
98-
config: ConfigManager,
99-
log: Logger,
108+
config_manager: ConfigManager,
100109
message_interrupted: dict[str, asyncio.Event],
110+
**kwargs,
101111
):
112+
# Forward other arguments to parent class
113+
super().__init__(*args, **kwargs)
114+
115+
# Bind arguments to instance attributes
102116
self.ychat = ychat
103-
self.manager = manager
104-
self.config = config
105-
self.log = log
117+
self.config_manager = config_manager
106118
self.message_interrupted = message_interrupted
119+
120+
# Initialize custom awareness object for this persona
107121
self.awareness = PersonaAwareness(
108122
ychat=self.ychat, log=self.log, user=self.as_user()
109123
)
110124

125+
# Register this persona as a user in the chat
111126
self.ychat.set_user(self.as_user())
112127

113128
################################################
@@ -298,6 +313,22 @@ def send_message(self, body: str) -> None:
298313
"""
299314
self.ychat.add_message(NewMessage(body=body, sender=self.id))
300315

316+
def get_chat_path(self, relative: bool = False) -> str:
317+
"""
318+
Returns the absolute path of the chat file assigned to this persona.
319+
320+
To get a path relative to the `ContentsManager` root directory, call
321+
this method with `relative=True`.
322+
"""
323+
return self.parent.get_chat_path(relative=relative)
324+
325+
def get_chat_dir(self) -> str:
326+
"""
327+
Returns the absolute path to the parent directory of the chat file
328+
assigned to this persona.
329+
"""
330+
return self.parent.get_chat_dir()
331+
301332

302333
class GenerationInterrupted(asyncio.CancelledError):
303334
"""Exception raised when streaming is cancelled by the user"""

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def defaults(self):
2727
)
2828

2929
async def process_message(self, message: Message) -> None:
30-
provider_name = self.config.lm_provider.name
31-
model_id = self.config.lm_provider_params["model_id"]
30+
provider_name = self.config_manager.lm_provider.name
31+
model_id = self.config_manager.lm_provider_params["model_id"]
3232

3333
runnable = self.build_runnable()
3434
variables = JupyternautVariables(
@@ -43,7 +43,7 @@ async def process_message(self, message: Message) -> None:
4343

4444
def build_runnable(self) -> Any:
4545
# TODO: support model parameters. maybe we just add it to lm_provider_params in both 2.x and 3.x
46-
llm = self.config.lm_provider(**self.config.lm_provider_params)
46+
llm = self.config_manager.lm_provider(**self.config_manager.lm_provider_params)
4747
runnable = JUPYTERNAUT_PROMPT_TEMPLATE | llm | StrOutputParser()
4848

4949
runnable = RunnableWithMessageHistory(

packages/jupyter-ai/jupyter_ai/personas/persona_manager.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,92 @@
1+
from __future__ import annotations
2+
13
import asyncio
4+
import os
25
from logging import Logger
36
from time import time_ns
4-
from typing import TYPE_CHECKING, ClassVar, Optional
7+
from typing import TYPE_CHECKING, ClassVar
58

69
from importlib_metadata import entry_points
710
from jupyterlab_chat.models import Message
811
from jupyterlab_chat.ychat import YChat
12+
from traitlets.config import LoggingConfigurable
913

1014
from ..config_manager import ConfigManager
1115
from .base_persona import BasePersona
1216

1317
if TYPE_CHECKING:
1418
from asyncio import AbstractEventLoop
1519

20+
from jupyter_server_fileid.manager import ( # type: ignore[import-untyped]
21+
BaseFileIdManager,
22+
)
23+
1624
# EPG := entry point group
1725
EPG_NAME = "jupyter_ai.personas"
1826

1927

20-
class PersonaManager:
28+
class PersonaManager(LoggingConfigurable):
2129
"""
2230
Class that manages all personas for a single chat.
2331
"""
2432

2533
# instance attrs
2634
ychat: YChat
2735
config_manager: ConfigManager
28-
event_loop: "AbstractEventLoop"
29-
log: Logger
36+
fileid_manager: BaseFileIdManager
37+
root_dir: str
38+
event_loop: AbstractEventLoop
39+
40+
log: Logger # type: ignore
41+
"""
42+
The `logging.Logger` instance used by this class. This is automatically set
43+
by the `LoggingConfigurable` parent class; this declaration only hints the
44+
type for type checkers.
45+
"""
46+
3047
_personas: dict[str, BasePersona]
48+
file_id: str
3149

3250
# class attrs
33-
_persona_classes: ClassVar[Optional[list[type[BasePersona]]]] = None
51+
_persona_classes: ClassVar[list[type[BasePersona]] | None] = None
3452

3553
def __init__(
3654
self,
55+
*args,
56+
room_id: str,
3757
ychat: YChat,
3858
config_manager: ConfigManager,
39-
event_loop: "AbstractEventLoop",
40-
log: Logger,
59+
fileid_manager: BaseFileIdManager,
60+
root_dir: str,
61+
event_loop: AbstractEventLoop,
4162
message_interrupted: dict[str, asyncio.Event],
63+
**kwargs,
4264
):
65+
# Forward other arguments to parent class
66+
super().__init__(*args, **kwargs)
67+
68+
# Bind instance attributes
69+
self.room_id = room_id
4370
self.ychat = ychat
4471
self.config_manager = config_manager
72+
self.fileid_manager = fileid_manager
73+
self.root_dir = root_dir
4574
self.event_loop = event_loop
46-
self.log = log
4775
self.message_interrupted = message_interrupted
4876

77+
# Store file ID
78+
self.file_id = room_id.split(":")[2]
79+
80+
# Load persona classes from entry points.
81+
# This is stored in a class attribute (global to all instances) because
82+
# the entry points are immutable after the server starts, so they only
83+
# need to be loaded once.
4984
if not isinstance(PersonaManager._persona_classes, list):
5085
self._init_persona_classes()
5186
assert isinstance(PersonaManager._persona_classes, list)
5287

5388
self._personas = self._init_personas()
89+
self.log.error(self.get_chat_dir())
5490

5591
def _init_persona_classes(self) -> None:
5692
"""
@@ -124,10 +160,9 @@ def _init_personas(self) -> dict[str, BasePersona]:
124160
for Persona in persona_classes:
125161
try:
126162
persona = Persona(
163+
parent=self,
127164
ychat=self.ychat,
128-
manager=self,
129-
config=self.config_manager,
130-
log=self.log,
165+
config_manager=self.config_manager,
131166
message_interrupted=self.message_interrupted,
132167
)
133168
except Exception:
@@ -200,3 +235,28 @@ def route_message(self, new_message: Message):
200235
)
201236
for persona in mentioned_personas:
202237
self.event_loop.create_task(persona.process_message(new_message))
238+
239+
def get_chat_path(self, relative: bool = False) -> str:
240+
"""
241+
Returns the absolute path of the chat file assigned to this
242+
`PersonaManager`.
243+
244+
To get a path relative to the `ContentsManager` root directory, call
245+
this method with `relative=True`.
246+
"""
247+
relpath = self.fileid_manager.get_path(self.file_id)
248+
if not relpath:
249+
raise Exception(f"Unable to locate chat with file ID: '{self.file_id}'.")
250+
if relative:
251+
return relpath
252+
253+
abspath = os.path.join(self.root_dir, relpath)
254+
return abspath
255+
256+
def get_chat_dir(self) -> str:
257+
"""
258+
Returns the absolute path of the parent directory of the chat file
259+
assigned to this `PersonaManager`.
260+
"""
261+
abspath = self.get_chat_path(absolute=True)
262+
return os.path.dirname(abspath)

0 commit comments

Comments
 (0)