Commit 3d9e694

ccurme and nfcampos authored
feat(core): start on v1 chat model (#32276)
Co-authored-by: Nuno Campos <nuno@langchain.dev>
1 parent c921d08 commit 3d9e694

File tree: 9 files changed (+1146, −39 lines)


libs/core/langchain_core/callbacks/base.py

Lines changed: 25 additions & 16 deletions
@@ -7,6 +7,8 @@

 from typing_extensions import Self

+from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
+
 if TYPE_CHECKING:
     from collections.abc import Sequence
     from uuid import UUID

@@ -64,9 +66,11 @@ class LLMManagerMixin:

     def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,

@@ -75,16 +79,16 @@ def on_llm_new_token(

         Args:
             token (str): The new token.
-            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
-                containing content and other information.
+            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+                generated chunk, containing content and other information.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             kwargs (Any): Additional keyword arguments.
         """

     def on_llm_end(
         self,
-        response: LLMResult,
+        response: Union[LLMResult, AIMessage],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,

@@ -93,7 +97,7 @@ def on_llm_end(
         """Run when LLM ends running.

         Args:
-            response (LLMResult): The response which was generated.
+            response (LLMResult | AIMessage): The response which was generated.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             kwargs (Any): Additional keyword arguments.

@@ -261,7 +265,7 @@ def on_llm_start(
     def on_chat_model_start(
         self,
         serialized: dict[str, Any],
-        messages: list[list[BaseMessage]],
+        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,

@@ -439,6 +443,9 @@ class BaseCallbackHandler(
     run_inline: bool = False
     """Whether to run the callback inline."""

+    accepts_new_messages: bool = False
+    """Whether the callback accepts new message format."""
+
     @property
     def ignore_llm(self) -> bool:
         """Whether to ignore LLM callbacks."""

@@ -509,7 +516,7 @@ async def on_llm_start(
     async def on_chat_model_start(
         self,
         serialized: dict[str, Any],
-        messages: list[list[BaseMessage]],
+        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,

@@ -538,9 +545,11 @@ async def on_chat_model_start(

     async def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[list[str]] = None,

@@ -550,8 +559,8 @@ async def on_llm_new_token(

         Args:
             token (str): The new token.
-            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
-                containing content and other information.
+            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+                generated chunk, containing content and other information.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             tags (Optional[list[str]]): The tags.

@@ -560,7 +569,7 @@

     async def on_llm_end(
         self,
-        response: LLMResult,
+        response: Union[LLMResult, AIMessage],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,

@@ -570,7 +579,7 @@
         """Run when LLM ends running.

         Args:
-            response (LLMResult): The response which was generated.
+            response (LLMResult | AIMessage): The response which was generated.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             tags (Optional[list[str]]): The tags.

@@ -594,8 +603,8 @@ async def on_llm_error(
             parent_run_id: The parent run ID. This is the ID of the parent run.
             tags: The tags.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """

     async def on_chain_start(
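A minimal sketch (not part of this commit) of a handler that opts into the new message format via the accepts_new_messages flag added above. Handlers that leave the flag at its False default keep receiving the legacy str / chunk / LLMResult payloads; the class name below is hypothetical, and the .text accessor on the v1 types is assumed from its use elsewhere in this diff.

    from typing import Any, Union

    from langchain_core.callbacks.base import BaseCallbackHandler
    from langchain_core.messages.v1 import AIMessage, AIMessageChunk
    from langchain_core.outputs import LLMResult


    class V1PrintHandler(BaseCallbackHandler):
        """Hypothetical handler that opts into v1 message payloads."""

        # With this flag set, callback dispatch skips the legacy-event
        # conversion and passes v1 objects straight through.
        accepts_new_messages: bool = True

        def on_llm_new_token(
            self, token: Union[str, AIMessageChunk], **kwargs: Any
        ) -> None:
            # token may now be an AIMessageChunk rather than a plain string.
            text = token.text if isinstance(token, AIMessageChunk) else token
            print(f"token: {text!r}")

        def on_llm_end(
            self, response: Union[LLMResult, AIMessage], **kwargs: Any
        ) -> None:
            if isinstance(response, AIMessage):
                print(f"final v1 message: {response.text!r}")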

libs/core/langchain_core/callbacks/manager.py

Lines changed: 42 additions & 16 deletions
@@ -11,6 +11,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import copy_context
+from dataclasses import is_dataclass
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -37,6 +38,8 @@
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
 from langchain_core.tracers.schemas import Run
 from langchain_core.utils.env import env_var_is_set

@@ -47,7 +50,7 @@

     from langchain_core.agents import AgentAction, AgentFinish
     from langchain_core.documents import Document
-    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
+    from langchain_core.outputs import GenerationChunk
     from langchain_core.runnables.config import RunnableConfig

 logger = logging.getLogger(__name__)

@@ -241,6 +244,22 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any:
     return cast("Func", wrapped)


+def _convert_llm_events(
+    event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> None:
+    if event_name == "on_chat_model_start" and isinstance(args[1], list):
+        for idx, item in enumerate(args[1]):
+            if is_dataclass(item):
+                args[1][idx] = item  # convert to old message
+    elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
+        kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
+        args[0] = args[0].text
+    elif event_name == "on_llm_end" and is_dataclass(args[0]):
+        args[0] = LLMResult(
+            generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
+        )
+
+
 def handle_event(
     handlers: list[BaseCallbackHandler],
     event_name: str,

@@ -269,6 +288,8 @@ def handle_event(
             if ignore_condition_name is None or not getattr(
                 handler, ignore_condition_name
             ):
+                if not handler.accepts_new_messages:
+                    _convert_llm_events(event_name, args, kwargs)
                 event = getattr(handler, event_name)(*args, **kwargs)
                 if asyncio.iscoroutine(event):
                     coros.append(event)

@@ -363,6 +384,8 @@ async def _ahandle_event_for_handler(
 ) -> None:
     try:
         if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
+            if not handler.accepts_new_messages:
+                _convert_llm_events(event_name, args, kwargs)
             event = getattr(handler, event_name)
             if asyncio.iscoroutinefunction(event):
                 await event(*args, **kwargs)

@@ -670,9 +693,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):

     def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.

@@ -697,11 +722,11 @@ def on_llm_new_token(
             **kwargs,
         )

-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
         """Run when LLM ends running.

         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:

@@ -727,8 +752,8 @@ def on_llm_error(
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return

@@ -766,9 +791,11 @@ def get_sync(self) -> CallbackManagerForLLMRun:

     async def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.

@@ -794,11 +821,13 @@ async def on_llm_new_token(
         )

     @shielded
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_llm_end(
+        self, response: Union[LLMResult, AIMessage], **kwargs: Any
+    ) -> None:
         """Run when LLM ends running.

         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:

@@ -825,11 +854,8 @@ async def on_llm_error(
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
-
-
-
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return
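For illustration, a rough mirror of the conversions _convert_llm_events applies for handlers that have not set accepts_new_messages. This is a hypothetical standalone helper, not the function the commit ships; it assumes, as the diff does, that the v1 message types are dataclasses exposing .text:

    from dataclasses import is_dataclass
    from typing import Any

    from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult


    def demo_convert(event_name: str, payload: Any) -> Any:
        """Illustrative-only mirror of the legacy-event downgrade."""
        if event_name == "on_llm_new_token" and is_dataclass(payload):
            # A v1 AIMessageChunk is split into the legacy (token, chunk) pair.
            return payload.text, ChatGenerationChunk(text=payload.text, message=payload)
        if event_name == "on_llm_end" and is_dataclass(payload):
            # A v1 AIMessage is wrapped in the legacy LLMResult envelope.
            return LLMResult(
                generations=[[ChatGeneration(text=payload.text, message=payload)]]
            )
        return payload

Note that the shipped _convert_llm_events mutates its args/kwargs in place rather than returning new values, and its on_chat_model_start branch is still a placeholder: dataclass items are reassigned unchanged under a "# convert to old message" comment.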

libs/core/langchain_core/language_models/_utils.py

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Sequence
 from typing import Optional

@@ -127,7 +128,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
                 and _is_openai_data_block(block)
             ):
                 if formatted_message is message:
-                    formatted_message = message.model_copy()
+                    if isinstance(message, BaseMessage):
+                        formatted_message = message.model_copy()
+                    else:
+                        formatted_message = copy.copy(message)
                     # Also shallow-copy content
                     formatted_message.content = list(formatted_message.content)

libs/core/langchain_core/language_models/base.py

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
     MessageLikeRepresentation,
     get_buffer_string,
 )
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names

@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
 LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
-LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
+LanguageModelOutputVar = TypeVar(
+    "LanguageModelOutputVar", BaseMessage, str, AIMessageV1
+)


 def _get_verbosity() -> bool:
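With AIMessageV1 added to the TypeVar's constraints, generic model classes can now be parameterized by the v1 message type. A hedged sketch of the pattern this enables (the class name is hypothetical, not from this commit; the abstract methods of BaseLanguageModel are omitted, so the class type-checks but cannot be instantiated as-is):

    from langchain_core.language_models import BaseLanguageModel
    from langchain_core.messages.v1 import AIMessage as AIMessageV1


    class MyV1ChatModel(BaseLanguageModel[AIMessageV1]):
        """Hypothetical v1 chat model: invoke() is typed to return AIMessageV1."""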
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""LangChain v1.0 chat models."""
