Skip to content

Commit a6eb979

Browse files
authored
Backport PR #1257 on branch 2.x (Refactor Chat Handlers to Simplify Initialization) (#1266)
* Refactor Chat Handlers to Simplify Initialization (#1257) * simplify-entrypoints-loading * fix-lint * fix-tests * add-retriever-typing * remove-retriever-from-base * fix-circular-import(ydoc-import) * fix-tests * fix-type-check-failure * refactor-retriever-init (cherry picked from commit 055d7b8) * fix-test-backporting * fix-ydoc-unwanted-import * fix-test * remove-exception * reorder-chat-handlers-storage
1 parent 63dcc89 commit a6eb979

File tree

5 files changed

+30
-12
lines changed

5 files changed

+30
-12
lines changed

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_core.prompts import PromptTemplate
99

1010
from .base import BaseChatHandler, SlashCommandRoutingType
11+
from .learn import LearnChatHandler, Retriever
1112

1213
PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
1314
@@ -18,6 +19,16 @@
1819
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
1920

2021

22+
class CustomLearnException(Exception):
23+
"""Exception raised when Jupyter AI's /ask command is used without the required /learn command."""
24+
25+
def __init__(self):
26+
super().__init__(
27+
"Jupyter AI's default /ask command requires the default /learn command. "
28+
"If you are overriding /learn via the entry points API, be sure to also override or disable /ask."
29+
)
30+
31+
2132
class AskChatHandler(BaseChatHandler):
2233
"""Processes messages prefixed with /ask. This actor will
2334
send the message as input to a RetrieverQA chain, that
@@ -33,12 +44,16 @@ class AskChatHandler(BaseChatHandler):
3344

3445
uses_llm = True
3546

36-
def __init__(self, retriever, *args, **kwargs):
47+
def __init__(self, *args, **kwargs):
3748
super().__init__(*args, **kwargs)
3849

39-
self._retriever = retriever
4050
self.parser.prog = "/ask"
4151
self.parser.add_argument("query", nargs=argparse.REMAINDER)
52+
learn_chat_handler = self.chat_handlers.get("/learn")
53+
if not isinstance(learn_chat_handler, LearnChatHandler):
54+
raise CustomLearnException()
55+
56+
self._retriever = Retriever(learn_chat_handler=learn_chat_handler)
4257

4358
def create_llm_chain(
4459
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -51,6 +66,7 @@ def create_llm_chain(
5166
memory = ConversationBufferWindowMemory(
5267
memory_key="chat_history", return_messages=True, k=2
5368
)
69+
5470
self.llm_chain = ConversationalRetrievalChain.from_llm(
5571
self.llm,
5672
self._retriever,

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import time
66
import traceback
7+
from pathlib import Path
78
from typing import (
89
TYPE_CHECKING,
910
Any,
@@ -156,6 +157,7 @@ def __init__(
156157
chat_handlers: Dict[str, "BaseChatHandler"],
157158
context_providers: Dict[str, "BaseCommandContextProvider"],
158159
message_interrupted: Dict[str, asyncio.Event],
160+
log_dir: Optional[str],
159161
):
160162
self.log = log
161163
self.config_manager = config_manager
@@ -178,6 +180,7 @@ def __init__(
178180
self.chat_handlers = chat_handlers
179181
self.context_providers = context_providers
180182
self.message_interrupted = message_interrupted
183+
self.log_dir = Path(log_dir) if log_dir else None
181184

182185
self.llm: Optional[BaseProvider] = None
183186
self.llm_params: Optional[dict] = None

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,8 @@ class GenerateChatHandler(BaseChatHandler):
253253

254254
uses_llm = True
255255

256-
def __init__(self, log_dir: Optional[str], *args, **kwargs):
256+
def __init__(self, *args, **kwargs):
257257
super().__init__(*args, **kwargs)
258-
self.log_dir = Path(log_dir) if log_dir else None
259258
self.llm: Optional[BaseProvider] = None
260259

261260
def create_llm_chain(

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from dask.distributed import Client as DaskClient
77
from importlib_metadata import entry_points
8-
from jupyter_ai.chat_handlers.learn import Retriever
98
from jupyter_ai_magics import BaseProvider, JupyternautPersona
109
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
1110
from jupyter_server.extension.application import ExtensionApp
@@ -367,16 +366,17 @@ def _init_chat_handlers(self):
367366
"chat_handlers": chat_handlers,
368367
"context_providers": self.settings["jai_context_providers"],
369368
"message_interrupted": self.settings["jai_message_interrupted"],
369+
"log_dir": self.error_logs_dir,
370370
}
371+
371372
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
373+
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
372374
clear_chat_handler = ClearChatHandler(**chat_handler_kwargs)
373-
generate_chat_handler = GenerateChatHandler(
374-
**chat_handler_kwargs,
375-
log_dir=self.error_logs_dir,
376-
)
377375
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
378-
retriever = Retriever(learn_chat_handler=learn_chat_handler)
379-
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
376+
# Store learn_chat_handler before initializing AskChatHandler,
377+
# as it is required for initializing the Retriever.
378+
chat_handlers["/learn"] = learn_chat_handler
379+
ask_chat_handler = AskChatHandler(**chat_handler_kwargs)
380380

381381
export_chat_handler = ExportChatHandler(**chat_handler_kwargs)
382382

@@ -386,7 +386,6 @@ def _init_chat_handlers(self):
386386
chat_handlers["/ask"] = ask_chat_handler
387387
chat_handlers["/clear"] = clear_chat_handler
388388
chat_handlers["/generate"] = generate_chat_handler
389-
chat_handlers["/learn"] = learn_chat_handler
390389
chat_handlers["/export"] = export_chat_handler
391390
chat_handlers["/fix"] = fix_chat_handler
392391

packages/jupyter-ai/jupyter_ai/tests/test_handlers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def broadcast_message(message: Message) -> None:
7878
chat_handlers={},
7979
context_providers={},
8080
message_interrupted={},
81+
log_dir="",
8182
)
8283

8384

0 commit comments

Comments
 (0)