11from enum import Enum
2- from typing import Any , cast
2+ from typing import Annotated , Any , Literal , cast , get_args , get_origin , overload
33
4- from pydantic import BaseModel , ConfigDict , Field
4+ from pydantic import BaseModel , ConfigDict , Field , RootModel
55
66from ragbits .chat .interface .forms import UserSettings
77from ragbits .chat .interface .ui_customization import UICustomization
@@ -133,13 +133,217 @@ class ChatContext(BaseModel):
133133 model_config = ConfigDict (extra = "allow" )
134134
135135
136- class ChatResponse (BaseModel ):
137- """Container for different types of chat responses."""
136+ _CHAT_RESPONSE_REGISTRY : dict [ChatResponseType , type [BaseModel ]] = {}
137+
138+
139+ class ChatResponseBase (BaseModel ):
140+ """Base class for all ChatResponse variants with auto-registration."""
138141
139142 type : ChatResponseType
140- content : (
141- str | Reference | StateUpdate | LiveUpdate | list [str ] | Image | dict [str , MessageUsage ] | ChunkedContent | None
142- )
143+
144+ def __init_subclass__ (cls , ** kwargs : Any ):
145+ super ().__init_subclass__ (** kwargs )
146+ type_ann = cls .model_fields ["type" ].annotation
147+ origin = get_origin (type_ann )
148+ value = get_args (type_ann )[0 ] if origin is Literal else getattr (cls , "type" , None )
149+
150+ if value is None :
151+ raise ValueError (f"Cannot determine ChatResponseType for { cls .__name__ } " )
152+
153+ _CHAT_RESPONSE_REGISTRY [value ] = cls
154+
155+
156+ class TextChatResponse (ChatResponseBase ):
157+ """Represents text chat response"""
158+
159+ type : Literal [ChatResponseType .TEXT ] = ChatResponseType .TEXT
160+ content : str
161+
162+
163+ class ReferenceChatResponse (ChatResponseBase ):
164+ """Represents reference chat response"""
165+
166+ type : Literal [ChatResponseType .REFERENCE ] = ChatResponseType .REFERENCE
167+ content : Reference
168+
169+
170+ class StateUpdateChatResponse (ChatResponseBase ):
171+ """Represents state update chat response"""
172+
173+ type : Literal [ChatResponseType .STATE_UPDATE ] = ChatResponseType .STATE_UPDATE
174+ content : StateUpdate
175+
176+
177+ class ConversationIdChatResponse (ChatResponseBase ):
178+ """Represents conversation_id chat response"""
179+
180+ type : Literal [ChatResponseType .CONVERSATION_ID ] = ChatResponseType .CONVERSATION_ID
181+ content : str
182+
183+
184+ class LiveUpdateChatResponse (ChatResponseBase ):
185+ """Represents live update chat response"""
186+
187+ type : Literal [ChatResponseType .LIVE_UPDATE ] = ChatResponseType .LIVE_UPDATE
188+ content : LiveUpdate
189+
190+
191+ class FollowupMessagesChatResponse (ChatResponseBase ):
192+ """Represents followup messages chat response"""
193+
194+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ] = ChatResponseType .FOLLOWUP_MESSAGES
195+ content : list [str ]
196+
197+
198+ class ImageChatResponse (ChatResponseBase ):
199+ """Represents image chat response"""
200+
201+ type : Literal [ChatResponseType .IMAGE ] = ChatResponseType .IMAGE
202+ content : Image
203+
204+
205+ class ClearMessageChatResponse (ChatResponseBase ):
206+ """Represents clear message event"""
207+
208+ type : Literal [ChatResponseType .CLEAR_MESSAGE ] = ChatResponseType .CLEAR_MESSAGE
209+ content : None = None
210+
211+
212+ class UsageChatResponse (ChatResponseBase ):
213+ """Represents usage chat response"""
214+
215+ type : Literal [ChatResponseType .USAGE ] = ChatResponseType .USAGE
216+ content : dict [str , MessageUsage ]
217+
218+
219+ class MessageIdChatResponse (ChatResponseBase ):
220+ """Represents message_id chat response"""
221+
222+ type : Literal [ChatResponseType .MESSAGE_ID ] = ChatResponseType .MESSAGE_ID
223+ content : str
224+
225+
226+ class ChunkedContentChatResponse (ChatResponseBase ):
227+ """Represents chunked_content event that contains chunked event of different type"""
228+
229+ type : Literal [ChatResponseType .CHUNKED_CONTENT ] = ChatResponseType .CHUNKED_CONTENT
230+ content : ChunkedContent
231+
232+
233+ ChatResponseUnion = Annotated [
234+ TextChatResponse
235+ | ReferenceChatResponse
236+ | StateUpdateChatResponse
237+ | ConversationIdChatResponse
238+ | LiveUpdateChatResponse
239+ | FollowupMessagesChatResponse
240+ | ImageChatResponse
241+ | ClearMessageChatResponse
242+ | UsageChatResponse
243+ | MessageIdChatResponse
244+ | ChunkedContentChatResponse ,
245+ Field (discriminator = "type" ),
246+ ]
247+
248+
249+ class ChatResponse (RootModel [ChatResponseUnion ]):
250+ """Container for different types of chat responses."""
251+
252+ root : ChatResponseUnion
253+
254+ @property
255+ def content (self ) -> object :
256+ """Returns content of a response, use dedicated `as_*` methods to get type hints."""
257+ return self .root .content
258+
259+ @property
260+ def type (self ) -> ChatResponseType :
261+ """Returns type of the ChatResponse"""
262+ return self .root .type
263+
264+ @overload
265+ def __init__ (
266+ self ,
267+ type : Literal [ChatResponseType .TEXT ],
268+ content : str ,
269+ ) -> None : ...
270+ @overload
271+ def __init__ (
272+ self ,
273+ type : Literal [ChatResponseType .REFERENCE ],
274+ content : Reference ,
275+ ) -> None : ...
276+ @overload
277+ def __init__ (
278+ self ,
279+ type : Literal [ChatResponseType .STATE_UPDATE ],
280+ content : StateUpdate ,
281+ ) -> None : ...
282+ @overload
283+ def __init__ (
284+ self ,
285+ type : Literal [ChatResponseType .CONVERSATION_ID ],
286+ content : str ,
287+ ) -> None : ...
288+ @overload
289+ def __init__ (
290+ self ,
291+ type : Literal [ChatResponseType .LIVE_UPDATE ],
292+ content : LiveUpdate ,
293+ ) -> None : ...
294+ @overload
295+ def __init__ (
296+ self ,
297+ type : Literal [ChatResponseType .FOLLOWUP_MESSAGES ],
298+ content : list [str ],
299+ ) -> None : ...
300+ @overload
301+ def __init__ (
302+ self ,
303+ type : Literal [ChatResponseType .IMAGE ],
304+ content : Image ,
305+ ) -> None : ...
306+ @overload
307+ def __init__ (
308+ self ,
309+ type : Literal [ChatResponseType .CLEAR_MESSAGE ],
310+ content : None ,
311+ ) -> None : ...
312+ @overload
313+ def __init__ (
314+ self ,
315+ type : Literal [ChatResponseType .USAGE ],
316+ content : dict [str , MessageUsage ],
317+ ) -> None : ...
318+ @overload
319+ def __init__ (
320+ self ,
321+ type : Literal [ChatResponseType .MESSAGE_ID ],
322+ content : str ,
323+ ) -> None : ...
324+ @overload
325+ def __init__ (
326+ self ,
327+ type : Literal [ChatResponseType .CHUNKED_CONTENT ],
328+ content : ChunkedContent ,
329+ ) -> None : ...
330+ def __init__ (
331+ self ,
332+ type : ChatResponseType ,
333+ content : Any ,
334+ ) -> None :
335+ """
336+ Backward-compatible constructor.
337+
338+ Allows creating a ChatResponse directly with:
339+ ChatResponse(type=ChatResponseType.TEXT, content="hello")
340+ """
341+ model_cls = _CHAT_RESPONSE_REGISTRY .get (type )
342+ if model_cls is None :
343+ raise ValueError (f"Unsupported ChatResponseType: { type } " )
344+
345+ model_instance = model_cls (type = type , content = content )
346+ super ().__init__ (root = cast (ChatResponseUnion , model_instance ))
143347
144348 def as_text (self ) -> str | None :
145349 """
@@ -149,7 +353,7 @@ def as_text(self) -> str | None:
149353 if text := response.as_text():
150354 print(f"Got text: {text}")
151355 """
152- return str ( self .content ) if self .type == ChatResponseType . TEXT else None
356+ return self .root . content if isinstance ( self .root , TextChatResponse ) else None
153357
154358 def as_reference (self ) -> Reference | None :
155359 """
@@ -159,7 +363,7 @@ def as_reference(self) -> Reference | None:
159363 if ref := response.as_reference():
160364 print(f"Got reference: {ref.title}")
161365 """
162- return cast ( Reference , self .content ) if self .type == ChatResponseType . REFERENCE else None
366+ return self .root . content if isinstance ( self .root , ReferenceChatResponse ) else None
163367
164368 def as_state_update (self ) -> StateUpdate | None :
165369 """
@@ -169,13 +373,13 @@ def as_state_update(self) -> StateUpdate | None:
169373 if state_update := response.as_state_update():
170374 state = verify_state(state_update)
171375 """
172- return cast ( StateUpdate , self .content ) if self .type == ChatResponseType . STATE_UPDATE else None
376+ return self .root . content if isinstance ( self .root , StateUpdateChatResponse ) else None
173377
174378 def as_conversation_id (self ) -> str | None :
175379 """
176380 Return the content as ConversationID if this is a conversation id, else None.
177381 """
178- return cast ( str , self .content ) if self .type == ChatResponseType . CONVERSATION_ID else None
382+ return self .root . content if isinstance ( self .root , ConversationIdChatResponse ) else None
179383
180384 def as_live_update (self ) -> LiveUpdate | None :
181385 """
@@ -185,7 +389,7 @@ def as_live_update(self) -> LiveUpdate | None:
185389 if live_update := response.as_live_update():
186390 print(f"Got live update: {live_update.content.label}")
187391 """
188- return cast ( LiveUpdate , self .content ) if self .type == ChatResponseType . LIVE_UPDATE else None
392+ return self .root . content if isinstance ( self .root , LiveUpdateChatResponse ) else None
189393
190394 def as_followup_messages (self ) -> list [str ] | None :
191395 """
@@ -195,25 +399,25 @@ def as_followup_messages(self) -> list[str] | None:
195399 if followup_messages := response.as_followup_messages():
196400 print(f"Got followup messages: {followup_messages}")
197401 """
198- return cast ( list [ str ], self .content ) if self .type == ChatResponseType . FOLLOWUP_MESSAGES else None
402+ return self .root . content if isinstance ( self .root , FollowupMessagesChatResponse ) else None
199403
200404 def as_image (self ) -> Image | None :
201405 """
202406 Return the content as Image if this is an image response, else None.
203407 """
204- return cast ( Image , self .content ) if self .type == ChatResponseType . IMAGE else None
408+ return self .root . content if isinstance ( self .root , ImageChatResponse ) else None
205409
206410 def as_clear_message (self ) -> None :
207411 """
208412 Return the content of clear_message response, which is None
209413 """
210- return cast ( None , self .content )
414+ return self .root . content if isinstance ( self . root , ClearMessageChatResponse ) else None
211415
212416 def as_usage (self ) -> dict [str , MessageUsage ] | None :
213417 """
214418 Return the content as dict from model name to Usage if this is an usage response, else None
215419 """
216- return cast ( dict [ str , MessageUsage ], self .content ) if self .type == ChatResponseType . USAGE else None
420+ return self .root . content if isinstance ( self .root , UsageChatResponse ) else None
217421
218422
219423class ChatMessageRequest (BaseModel ):
0 commit comments