@@ -11,6 +11,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import copy_context
+from dataclasses import is_dataclass
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -37,6 +38,8 @@
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
 from langchain_core.tracers.schemas import Run
 from langchain_core.utils.env import env_var_is_set
 
@@ -47,7 +50,7 @@
 
     from langchain_core.agents import AgentAction, AgentFinish
     from langchain_core.documents import Document
-    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
+    from langchain_core.outputs import GenerationChunk
     from langchain_core.runnables.config import RunnableConfig
 
 logger = logging.getLogger(__name__)
@@ -241,6 +244,22 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any:
     return cast("Func", wrapped)
 
 
+def _convert_llm_events(
+    event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> None:
+    if event_name == "on_chat_model_start" and isinstance(args[1], list):
+        for idx, item in enumerate(args[1]):
+            if is_dataclass(item):
+                args[1][idx] = item  # convert to old message
+    elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
+        kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
+        args[0] = args[0].text
+    elif event_name == "on_llm_end" and is_dataclass(args[0]):
+        args[0] = LLMResult(
+            generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
+        )
+
+
 def handle_event(
     handlers: list[BaseCallbackHandler],
     event_name: str,
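For handlers that have not opted into the new message types, `_convert_llm_events` downgrades v1 payloads in place before dispatch; the `is_dataclass` check distinguishes v1 messages (plain dataclasses) from legacy pydantic `BaseMessage` objects. A minimal, self-contained sketch of the same in-place conversion pattern, using a stand-in dataclass and a plain dict instead of the real `AIMessageChunk` and `ChatGenerationChunk`:

from dataclasses import dataclass, is_dataclass
from typing import Any


@dataclass
class StubChunk:
    """Stand-in for langchain_core.messages.v1.AIMessageChunk."""

    text: str


def convert_new_token(args: list[Any], kwargs: dict[str, Any]) -> None:
    # Mirrors the on_llm_new_token branch above: expose the raw text as the
    # positional token and stash a legacy-shaped chunk in kwargs.
    if is_dataclass(args[0]):
        kwargs["chunk"] = {"text": args[0].text, "message": args[0]}
        args[0] = args[0].text


args: list[Any] = [StubChunk(text="Hel")]
kwargs: dict[str, Any] = {}
convert_new_token(args, kwargs)
assert args[0] == "Hel"
assert kwargs["chunk"]["message"].text == "Hel"

Note that the in-place assignment only works when the caller passes a mutable sequence; the `tuple[Any, ...]` annotation on the helper is looser than what the mutation actually requires.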
@@ -269,6 +288,8 @@ def handle_event(
             if ignore_condition_name is None or not getattr(
                 handler, ignore_condition_name
             ):
+                if not handler.accepts_new_messages:
+                    _convert_llm_events(event_name, args, kwargs)
                 event = getattr(handler, event_name)(*args, **kwargs)
                 if asyncio.iscoroutine(event):
                     coros.append(event)
@@ -363,6 +384,8 @@ async def _ahandle_event_for_handler(
 ) -> None:
     try:
         if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
+            if not handler.accepts_new_messages:
+                _convert_llm_events(event_name, args, kwargs)
             event = getattr(handler, event_name)
             if asyncio.iscoroutinefunction(event):
                 await event(*args, **kwargs)
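Both the sync and async dispatch paths gate the conversion on the handler's `accepts_new_messages` attribute, so opting into v1 payloads is a per-handler capability flag. A self-contained sketch of that pattern, with illustrative handler classes (not the real `BaseCallbackHandler` API) and a simplified stand-in for `_convert_llm_events`:

from dataclasses import is_dataclass
from typing import Any


class LegacyTokenHandler:
    accepts_new_messages = False  # receives converted, old-style payloads

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(f"legacy str token: {token!r}")


class V1TokenHandler:
    accepts_new_messages = True  # opted in: receives v1 objects unchanged

    def on_llm_new_token(self, token: Any, **kwargs: Any) -> None:
        print(f"v1 token object: {token!r}")


def dispatch(
    handler: Any, event_name: str, args: list[Any], kwargs: dict[str, Any]
) -> None:
    # Downgrade in place only for handlers still on the legacy path.
    if not handler.accepts_new_messages and is_dataclass(args[0]):
        args[0] = args[0].text  # simplified stand-in for _convert_llm_events
    getattr(handler, event_name)(*args, **kwargs)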
@@ -670,9 +693,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
 
     def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -697,11 +722,11 @@ def on_llm_new_token(
             **kwargs,
         )
 
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:
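For handlers left on the legacy path, the shim in `_convert_llm_events` means `on_llm_end` still sees an `LLMResult`, with the v1 message wrapped as a single chat generation. A stand-in sketch of the wrapped shape, using dicts instead of the real `LLMResult`/`ChatGeneration` models:

from dataclasses import dataclass
from typing import Any


@dataclass
class StubAIMessage:
    """Stand-in for langchain_core.messages.v1.AIMessage."""

    text: str


def wrap_as_legacy_result(message: StubAIMessage) -> dict[str, Any]:
    # Shape mirrors LLMResult(generations=[[ChatGeneration(...)]]) in the diff.
    return {"generations": [[{"text": message.text, "message": message}]]}


result = wrap_as_legacy_result(StubAIMessage(text="final answer"))
assert result["generations"][0][0]["text"] == "final answer"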
@@ -727,8 +752,8 @@ def on_llm_error(
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return
@@ -766,9 +791,11 @@ def get_sync(self) -> CallbackManagerForLLMRun:
 
     async def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -794,11 +821,13 @@ async def on_llm_new_token(
         )
 
     @shielded
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_llm_end(
+        self, response: Union[LLMResult, AIMessage], **kwargs: Any
+    ) -> None:
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:
@@ -825,11 +854,8 @@ async def on_llm_error(
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
-
-
-
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return