diff --git a/llm-service/app/ai/indexing/summary_indexer.py b/llm-service/app/ai/indexing/summary_indexer.py index 9189f5045..7d1290df3 100644 --- a/llm-service/app/ai/indexing/summary_indexer.py +++ b/llm-service/app/ai/indexing/summary_indexer.py @@ -218,13 +218,17 @@ def create_storage_context( @classmethod def get_all_data_source_summaries(cls) -> dict[str, str]: root_dir = cls.__persist_root_dir() - # if not os.path.exists(root_dir): - # return {} - storage_context = SummaryIndexer.create_storage_context( - persist_dir=root_dir, - vector_store=SimpleVectorStore(), - ) - indices = load_indices_from_storage(storage_context=storage_context, index_ids=None, + try: + storage_context = SummaryIndexer.create_storage_context( + persist_dir=root_dir, + vector_store=SimpleVectorStore(), + ) + except FileNotFoundError: + # If the directory doesn't exist, we don't have any summaries. + return {} + indices = load_indices_from_storage( + storage_context=storage_context, + index_ids=None, **{ "llm": models.LLM.get_noop(), "response_synthesizer": models.LLM.get_noop(), @@ -234,11 +238,13 @@ def get_all_data_source_summaries(cls) -> dict[str, str]: "summary_query": "None", "data_source_id": 0, }, - ) + ) if len(indices) == 0: return {} - global_summary_store: DocumentSummaryIndex = cast(DocumentSummaryIndex, indices[0]) + global_summary_store: DocumentSummaryIndex = cast( + DocumentSummaryIndex, indices[0] + ) summary_ids = global_summary_store.index_struct.doc_id_to_summary_id.values() nodes = global_summary_store.docstore.get_nodes(list(summary_ids)) diff --git a/llm-service/app/routers/index/chat/__init__.py b/llm-service/app/routers/index/chat/__init__.py index 2302d5afd..e85cdad94 100644 --- a/llm-service/app/routers/index/chat/__init__.py +++ b/llm-service/app/routers/index/chat/__init__.py @@ -42,7 +42,7 @@ from pydantic import BaseModel from app import exceptions -from app.services.chat import generate_suggested_questions +from app.services.chat.suggested_questions import generate_suggested_questions logger = logging.getLogger(__name__) router = APIRouter(prefix="/chat", tags=["Chat"]) diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py index 83a38de1c..4223500cf 100644 --- a/llm-service/app/routers/index/sessions/__init__.py +++ b/llm-service/app/routers/index/sessions/__init__.py @@ -38,15 +38,17 @@ import base64 import json import logging -from typing import Optional +from typing import Optional, Generator -from fastapi import APIRouter, Header +from fastapi import APIRouter, Header, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel +from app.services.chat.streaming_chat import stream_chat from .... 
import exceptions
 from ....rag_types import RagPredictConfiguration
-from ....services.chat import (
-    v2_chat,
+from ....services.chat.chat import (
+    chat as run_chat,
 )
 from ....services.chat_history.chat_history_manager import (
     RagStudioChatMessage,
@@ -100,6 +102,24 @@ def chat_history(
     )
 
 
+@router.get(
+    "/chat-history/{message_id}",
+    summary="Returns a specific chat message for the provided session.",
+)
+@exceptions.propagates
+def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
+    results: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(
+        session_id=session_id
+    )
+    for message in results:
+        if message.id == message_id:
+            return message
+    raise HTTPException(
+        status_code=404,
+        detail=f"Message with id {message_id} not found in session {session_id}",
+    )
+
+
 @router.delete(
     "/chat-history", summary="Deletes the chat history for the provided session."
 )
@@ -161,6 +181,10 @@ class RagStudioChatRequest(BaseModel):
     configuration: RagPredictConfiguration | None = None
 
 
+class StreamCompletionRequest(BaseModel):
+    query: str
+
+
 def parse_jwt_cookie(jwt_cookie: str | None) -> str:
     if jwt_cookie is None:
         return "unknown"
@@ -187,4 +211,33 @@ def chat(
     session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
     configuration = request.configuration or RagPredictConfiguration()
-    return v2_chat(session, request.query, configuration, user_name=origin_remote_user)
+    return run_chat(session, request.query, configuration, user_name=origin_remote_user)
+
+
+@router.post(
+    "/stream-completion", summary="Stream completion responses for the given query"
+)
+@exceptions.propagates
+def stream_chat_completion(
+    session_id: int,
+    request: RagStudioChatRequest,
+    origin_remote_user: Optional[str] = Header(None),
+) -> StreamingResponse:
+    session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
+    configuration = request.configuration or RagPredictConfiguration()
+
+    def generate_stream() -> Generator[str, None, None]:
+        response_id: str = ""
+        try:
+            for response in stream_chat(
+                session, request.query, configuration, user_name=origin_remote_user
+            ):
+                response_id = response.additional_kwargs["response_id"]
+                json_delta = json.dumps({"text": response.delta})
+                yield f"data: {json_delta}" + "\n\n"
+            yield f'data: {{"response_id" : "{response_id}"}}\n\n'
+        except Exception as e:
+            logger.exception("Failed to stream chat completion")
+            yield f'data: {{"error" : "{e}"}}\n\n'
+
+    return StreamingResponse(generate_stream(), media_type="text/event-stream")
 
diff --git a/llm-service/app/services/chat/__init__.py b/llm-service/app/services/chat/__init__.py
new file mode 100644
index 000000000..9c5987844
--- /dev/null
+++ b/llm-service/app/services/chat/__init__.py
@@ -0,0 +1,38 @@
+#
+# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
+# (C) Cloudera, Inc. 2025
+# All rights reserved.
+#
+# Applicable Open Source License: Apache 2.0
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted. Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera. Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# + diff --git a/llm-service/app/services/chat/chat.py b/llm-service/app/services/chat/chat.py new file mode 100644 index 000000000..0256d4ae3 --- /dev/null +++ b/llm-service/app/services/chat/chat.py @@ -0,0 +1,171 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. 
+# + +import time +import uuid +from typing import Optional + +from fastapi import HTTPException + +from app.services import evaluators, llm_completion +from app.services.chat.utils import retrieve_chat_history, format_source_nodes +from app.services.chat_history.chat_history_manager import ( + Evaluation, + RagMessage, + RagStudioChatMessage, + chat_history_manager, +) +from app.services.metadata_apis.session_metadata_api import Session +from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run +from app.services.query import querier +from app.services.query.query_configuration import QueryConfiguration +from app.ai.vector_stores.vector_store_factory import VectorStoreFactory +from app.rag_types import RagPredictConfiguration + + +def chat( + session: Session, + query: str, + configuration: RagPredictConfiguration, + user_name: Optional[str], +) -> RagStudioChatMessage: + query_configuration = QueryConfiguration( + top_k=session.response_chunks, + model_name=session.inference_model, + rerank_model_name=session.rerank_model, + exclude_knowledge_base=configuration.exclude_knowledge_base, + use_question_condensing=configuration.use_question_condensing, + use_hyde=session.query_configuration.enable_hyde, + use_summary_filter=session.query_configuration.enable_summary_filter, + ) + + response_id = str(uuid.uuid4()) + + if configuration.exclude_knowledge_base or len(session.data_source_ids) == 0: + return direct_llm_chat(session, response_id, query, user_name) + + total_data_sources_size: int = sum( + map( + lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0, + session.data_source_ids, + ) + ) + if total_data_sources_size == 0: + return direct_llm_chat(session, response_id, query, user_name) + + new_chat_message: RagStudioChatMessage = _run_chat( + session, response_id, query, query_configuration, user_name + ) + + chat_history_manager.append_to_history(session.id, [new_chat_message]) + return new_chat_message + + +def _run_chat( + session: Session, + response_id: str, + query: str, + query_configuration: QueryConfiguration, + user_name: Optional[str], +) -> RagStudioChatMessage: + if len(session.data_source_ids) != 1: + raise HTTPException( + status_code=400, detail="Only one datasource is supported for chat." 
+ ) + + data_source_id: int = session.data_source_ids[0] + response, condensed_question = querier.query( + data_source_id, + query, + query_configuration, + retrieve_chat_history(session.id), + ) + if condensed_question and (condensed_question.strip() == query.strip()): + condensed_question = None + relevance, faithfulness = evaluators.evaluate_response( + query, response, session.inference_model + ) + response_source_nodes = format_source_nodes(response, data_source_id) + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=response_source_nodes, + inference_model=session.inference_model, + rag_message=RagMessage( + user=query, + assistant=response.response, + ), + evaluations=[ + Evaluation(name="relevance", value=relevance), + Evaluation(name="faithfulness", value=faithfulness), + ], + timestamp=time.time(), + condensed_question=condensed_question, + ) + + record_rag_mlflow_run( + new_chat_message, query_configuration, response_id, session, user_name + ) + return new_chat_message + + +def direct_llm_chat( + session: Session, response_id: str, query: str, user_name: Optional[str] +) -> RagStudioChatMessage: + record_direct_llm_mlflow_run(response_id, session, user_name) + + chat_response = llm_completion.completion( + session.id, query, session.inference_model + ) + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=[], + inference_model=session.inference_model, + evaluations=[], + rag_message=RagMessage( + user=query, + assistant=str(chat_response.message.content), + ), + timestamp=time.time(), + condensed_question=None, + ) + chat_history_manager.append_to_history(session.id, [new_chat_message]) + return new_chat_message + + diff --git a/llm-service/app/services/chat/streaming_chat.py b/llm-service/app/services/chat/streaming_chat.py new file mode 100644 index 000000000..ff173448b --- /dev/null +++ b/llm-service/app/services/chat/streaming_chat.py @@ -0,0 +1,186 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. 
("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# + +import time +import uuid +from typing import Optional, Generator + +from fastapi import HTTPException +from llama_index.core.base.llms.types import ChatResponse, ChatMessage +from llama_index.core.chat_engine.types import AgentChatResponse + +from app.ai.vector_stores.vector_store_factory import VectorStoreFactory +from app.rag_types import RagPredictConfiguration +from app.services import evaluators, llm_completion +from app.services.chat.utils import retrieve_chat_history, format_source_nodes +from app.services.chat_history.chat_history_manager import ( + RagStudioChatMessage, + RagMessage, + Evaluation, + chat_history_manager, +) +from app.services.metadata_apis.session_metadata_api import Session +from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run +from app.services.query import querier +from app.services.query.query_configuration import QueryConfiguration + + +def stream_chat( + session: Session, + query: str, + configuration: RagPredictConfiguration, + user_name: Optional[str], +) -> Generator[ChatResponse, None, None]: + query_configuration = QueryConfiguration( + top_k=session.response_chunks, + model_name=session.inference_model, + rerank_model_name=session.rerank_model, + exclude_knowledge_base=configuration.exclude_knowledge_base, + use_question_condensing=configuration.use_question_condensing, + use_hyde=session.query_configuration.enable_hyde, + use_summary_filter=session.query_configuration.enable_summary_filter, + ) + + response_id = str(uuid.uuid4()) + + if configuration.exclude_knowledge_base or len(session.data_source_ids) == 0: + return _stream_direct_llm_chat(session, response_id, query, user_name) + + total_data_sources_size: int = sum( + map( + lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0, + session.data_source_ids, + ) + ) + if total_data_sources_size == 0: + return _stream_direct_llm_chat(session, response_id, query, user_name) + + return _run_streaming_chat( + session, response_id, query, query_configuration, user_name + ) + + +def _run_streaming_chat( + session: Session, + response_id: str, + query: str, + query_configuration: QueryConfiguration, + user_name: Optional[str], +) -> Generator[ChatResponse, None, None]: + if len(session.data_source_ids) != 1: + raise HTTPException( + status_code=400, detail="Only one datasource is supported for chat." 
+ ) + + data_source_id: int = session.data_source_ids[0] + streaming_chat_response, condensed_question = querier.streaming_query( + data_source_id, + query, + query_configuration, + retrieve_chat_history(session.id), + ) + + response: ChatResponse = ChatResponse(message=ChatMessage(content=query)) + if streaming_chat_response.chat_stream: + for response in streaming_chat_response.chat_stream: + response.additional_kwargs["response_id"] = response_id + yield response + + chat_response = AgentChatResponse( + response=response.message.content or "", + sources=streaming_chat_response.sources, + source_nodes=streaming_chat_response.source_nodes, + ) + + if condensed_question and (condensed_question.strip() == query.strip()): + condensed_question = None + relevance, faithfulness = evaluators.evaluate_response( + query, chat_response, session.inference_model + ) + response_source_nodes = format_source_nodes(chat_response, data_source_id) + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=response_source_nodes, + inference_model=session.inference_model, + rag_message=RagMessage( + user=query, + assistant=chat_response.response, + ), + evaluations=[ + Evaluation(name="relevance", value=relevance), + Evaluation(name="faithfulness", value=faithfulness), + ], + timestamp=time.time(), + condensed_question=condensed_question, + ) + + chat_history_manager.append_to_history(session.id, [new_chat_message]) + + record_rag_mlflow_run( + new_chat_message, query_configuration, response_id, session, user_name + ) + + +def _stream_direct_llm_chat( + session: Session, response_id: str, query: str, user_name: Optional[str] +) -> Generator[ChatResponse, None, None]: + record_direct_llm_mlflow_run(response_id, session, user_name) + + chat_response = llm_completion.stream_completion( + session.id, query, session.inference_model + ) + response: ChatResponse = ChatResponse(message=ChatMessage(content=query)) + for response in chat_response: + response.additional_kwargs["response_id"] = response_id + yield response + + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=[], + inference_model=session.inference_model, + evaluations=[], + rag_message=RagMessage( + user=query, + assistant=response.message.content or "", + ), + timestamp=time.time(), + condensed_question=None, + ) + chat_history_manager.append_to_history(session.id, [new_chat_message]) diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat/suggested_questions.py similarity index 50% rename from llm-service/app/services/chat.py rename to llm-service/app/services/chat/suggested_questions.py index 0eb385e8c..bdd409e65 100644 --- a/llm-service/app/services/chat.py +++ b/llm-service/app/services/chat/suggested_questions.py @@ -1,6 +1,6 @@ -# ############################################################################## +# # CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) -# (C) Cloudera, Inc. 2024 +# (C) Cloudera, Inc. 2025 # All rights reserved. # # Applicable Open Source License: Apache 2.0 @@ -20,7 +20,7 @@ # with an authorized and properly licensed third party, you do not # have any rights to access nor to use this code. # -# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the +# Absent a written agreement with Cloudera, Inc. 
("Cloudera") to the # contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY # KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED # WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO @@ -34,160 +34,20 @@ # RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. -# ############################################################################## -import time -import uuid +# + from random import shuffle -from typing import List, Iterable, Optional +from typing import List, Optional from fastapi import HTTPException -from llama_index.core.base.llms.types import MessageRole -from llama_index.core.chat_engine.types import AgentChatResponse -from pydantic import BaseModel - -from . import evaluators, llm_completion -from .chat_history.chat_history_manager import ( - RagPredictSourceNode, - Evaluation, - RagMessage, - RagStudioChatMessage, - chat_history_manager, -) -from .metadata_apis import session_metadata_api -from .metadata_apis.session_metadata_api import Session -from .mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run -from .query import querier -from .query.query_configuration import QueryConfiguration -from ..ai.vector_stores.vector_store_factory import VectorStoreFactory -from ..rag_types import RagPredictConfiguration - - -class RagContext(BaseModel): - role: MessageRole - content: str - - -def v2_chat( - session: Session, - query: str, - configuration: RagPredictConfiguration, - user_name: Optional[str], -) -> RagStudioChatMessage: - query_configuration = QueryConfiguration( - top_k=session.response_chunks, - model_name=session.inference_model, - rerank_model_name=session.rerank_model, - exclude_knowledge_base=configuration.exclude_knowledge_base, - use_question_condensing=configuration.use_question_condensing, - use_hyde=session.query_configuration.enable_hyde, - use_summary_filter=session.query_configuration.enable_summary_filter, - ) - - if configuration.exclude_knowledge_base or len(session.data_source_ids) == 0: - return direct_llm_chat(session, query, user_name=user_name) - - total_data_sources_size: int = sum( - map( - lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0, - session.data_source_ids, - ) - ) - if total_data_sources_size == 0: - return direct_llm_chat(session, query, user_name) - - response_id = str(uuid.uuid4()) - - new_chat_message: RagStudioChatMessage = _run_chat( - session, response_id, query, query_configuration, user_name - ) - - chat_history_manager.append_to_history(session.id, [new_chat_message]) - return new_chat_message - - -def _run_chat( - session: Session, - response_id: str, - query: str, - query_configuration: QueryConfiguration, - user_name: Optional[str], -) -> RagStudioChatMessage: - if len(session.data_source_ids) != 1: - raise HTTPException( - status_code=400, detail="Only one datasource is supported for chat." 
- ) - - data_source_id: int = session.data_source_ids[0] - response, condensed_question = querier.query( - data_source_id, - query, - query_configuration, - retrieve_chat_history(session.id), - ) - if condensed_question and (condensed_question.strip() == query.strip()): - condensed_question = None - relevance, faithfulness = evaluators.evaluate_response( - query, response, session.inference_model - ) - response_source_nodes = format_source_nodes(response, data_source_id) - new_chat_message = RagStudioChatMessage( - id=response_id, - session_id=session.id, - source_nodes=response_source_nodes, - inference_model=session.inference_model, - rag_message=RagMessage( - user=query, - assistant=response.response, - ), - evaluations=[ - Evaluation(name="relevance", value=relevance), - Evaluation(name="faithfulness", value=faithfulness), - ], - timestamp=time.time(), - condensed_question=condensed_question, - ) - - record_rag_mlflow_run( - new_chat_message, query_configuration, response_id, session, user_name - ) - return new_chat_message - - -def retrieve_chat_history(session_id: int) -> List[RagContext]: - chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10] - history: List[RagContext] = [] - for message in chat_history: - history.append( - RagContext(role=MessageRole.USER, content=message.rag_message.user) - ) - history.append( - RagContext( - role=MessageRole.ASSISTANT, content=message.rag_message.assistant - ) - ) - return history - - -def format_source_nodes( - response: AgentChatResponse, data_source_id: int -) -> List[RagPredictSourceNode]: - response_source_nodes = [] - for source_node in response.source_nodes: - doc_id = source_node.node.metadata.get("document_id", source_node.node.node_id) - response_source_nodes.append( - RagPredictSourceNode( - node_id=source_node.node.node_id, - doc_id=doc_id, - source_file_name=source_node.node.metadata["file_name"], - score=source_node.score or 0.0, - dataSourceId=data_source_id, - ) - ) - response_source_nodes = sorted( - response_source_nodes, key=lambda x: x.score, reverse=True - ) - return response_source_nodes +from app.ai.vector_stores.vector_store_factory import VectorStoreFactory +from app.services import llm_completion +from app.services.chat.utils import retrieve_chat_history, process_response +from app.services.metadata_apis import session_metadata_api +from app.services.metadata_apis.session_metadata_api import Session +from app.services.query import querier +from app.services.query.query_configuration import QueryConfiguration SAMPLE_QUESTIONS = [ "What is Cloudera, and how does it support organizations in managing big data?", @@ -296,44 +156,3 @@ def generate_suggested_questions( ) suggested_questions = process_response(response.response) return suggested_questions - - -def process_response(response: str | None) -> list[str]: - if response is None: - return [] - - sentences: Iterable[str] = response.splitlines() - sentences = map(lambda x: x.strip(), sentences) - sentences = map(lambda x: x.removeprefix("*").strip(), sentences) - sentences = map(lambda x: x.removeprefix("-").strip(), sentences) - sentences = map(lambda x: x.strip("*"), sentences) - sentences = filter(lambda x: len(x.split()) <= 60, sentences) - sentences = filter(lambda x: x != "Empty Response", sentences) - sentences = filter(lambda x: x != "", sentences) - return list(sentences)[:5] - - -def direct_llm_chat( - session: Session, query: str, user_name: Optional[str] -) -> RagStudioChatMessage: - response_id = str(uuid.uuid4()) - 
record_direct_llm_mlflow_run(response_id, session, user_name) - - chat_response = llm_completion.completion( - session.id, query, session.inference_model - ) - new_chat_message = RagStudioChatMessage( - id=response_id, - session_id=session.id, - source_nodes=[], - inference_model=session.inference_model, - evaluations=[], - rag_message=RagMessage( - user=query, - assistant=str(chat_response.message.content), - ), - timestamp=time.time(), - condensed_question=None, - ) - chat_history_manager.append_to_history(session.id, [new_chat_message]) - return new_chat_message diff --git a/llm-service/app/services/chat/utils.py b/llm-service/app/services/chat/utils.py new file mode 100644 index 000000000..f626be384 --- /dev/null +++ b/llm-service/app/services/chat/utils.py @@ -0,0 +1,101 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. 
+# + +from typing import List, Iterable + +from llama_index.core.base.llms.types import MessageRole +from llama_index.core.chat_engine.types import AgentChatResponse +from pydantic import BaseModel + +from app.services.chat_history.chat_history_manager import chat_history_manager, RagPredictSourceNode + + +class RagContext(BaseModel): + role: MessageRole + content: str + + +def retrieve_chat_history(session_id: int) -> List[RagContext]: + chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10] + history: List[RagContext] = [] + for message in chat_history: + history.append( + RagContext(role=MessageRole.USER, content=message.rag_message.user) + ) + history.append( + RagContext( + role=MessageRole.ASSISTANT, content=message.rag_message.assistant + ) + ) + return history + + +def format_source_nodes( + response: AgentChatResponse, data_source_id: int +) -> List[RagPredictSourceNode]: + response_source_nodes = [] + for source_node in response.source_nodes: + doc_id = source_node.node.metadata.get("document_id", source_node.node.node_id) + response_source_nodes.append( + RagPredictSourceNode( + node_id=source_node.node.node_id, + doc_id=doc_id, + source_file_name=source_node.node.metadata["file_name"], + score=source_node.score or 0.0, + dataSourceId=data_source_id, + ) + ) + response_source_nodes = sorted( + response_source_nodes, key=lambda x: x.score, reverse=True + ) + return response_source_nodes + + +def process_response(response: str | None) -> list[str]: + if response is None: + return [] + + sentences: Iterable[str] = response.splitlines() + sentences = map(lambda x: x.strip(), sentences) + sentences = map(lambda x: x.removeprefix("*").strip(), sentences) + sentences = map(lambda x: x.removeprefix("-").strip(), sentences) + sentences = map(lambda x: x.strip("*"), sentences) + sentences = filter(lambda x: len(x.split()) <= 60, sentences) + sentences = filter(lambda x: x != "Empty Response", sentences) + sentences = filter(lambda x: x != "", sentences) + return list(sentences)[:5] diff --git a/llm-service/app/services/llm_completion.py b/llm-service/app/services/llm_completion.py index 34541ee9f..922ea5861 100644 --- a/llm-service/app/services/llm_completion.py +++ b/llm-service/app/services/llm_completion.py @@ -36,8 +36,12 @@ # DATA. # import itertools +from typing import Generator -from llama_index.core.base.llms.types import ChatMessage, ChatResponse +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, +) from llama_index.core.llms import LLM from . import models @@ -66,6 +70,26 @@ def completion(session_id: int, question: str, model_name: str) -> ChatResponse: return model.chat(messages) +def stream_completion( + session_id: int, question: str, model_name: str +) -> Generator[ChatResponse, None, None]: + """ + Streamed version of the completion function. + Returns a generator that yields ChatResponse objects as they become available. 
+ """ + model = models.LLM.get(model_name) + chat_history = chat_history_manager.retrieve_chat_history(session_id)[:10] + messages = list( + itertools.chain.from_iterable( + map(lambda x: make_chat_messages(x), chat_history) + ) + ) + messages.append(ChatMessage.from_str(question, role="user")) + + stream = model.stream_chat(messages) + return stream + + def hypothetical(question: str, configuration: QueryConfiguration) -> str: model: LLM = models.LLM.get(configuration.model_name) prompt: str = ( diff --git a/llm-service/app/services/models/llm.py b/llm-service/app/services/models/llm.py index a270c841f..e09ec4a17 100644 --- a/llm-service/app/services/models/llm.py +++ b/llm-service/app/services/models/llm.py @@ -79,6 +79,7 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM: model=model_name, messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, + max_tokens=1024, ) @staticmethod diff --git a/llm-service/app/services/query/querier.py b/llm-service/app/services/query/querier.py index 816acff55..717e8e74f 100644 --- a/llm-service/app/services/query/querier.py +++ b/llm-service/app/services/query/querier.py @@ -32,7 +32,7 @@ import typing if typing.TYPE_CHECKING: - from ..chat import RagContext + from ..chat.utils import RagContext import logging from typing import List, Optional @@ -43,7 +43,10 @@ from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.base.llms.types import ChatMessage -from llama_index.core.chat_engine.types import AgentChatResponse +from llama_index.core.chat_engine.types import ( + AgentChatResponse, + StreamingAgentChatResponse, +) from llama_index.core.indices import VectorStoreIndex from llama_index.core.llms import LLM from llama_index.core.postprocessor.types import BaseNodePostprocessor @@ -76,6 +79,55 @@ CUSTOM_PROMPT = PromptTemplate(CUSTOM_TEMPLATE) +def streaming_query( + data_source_id: int, + query_str: str, + configuration: QueryConfiguration, + chat_history: list[RagContext], +) -> tuple[StreamingAgentChatResponse, str | None]: + qdrant_store = VectorStoreFactory.for_chunks(data_source_id) + vector_store = qdrant_store.llama_vector_store() + embedding_model = qdrant_store.get_embedding_model() + index = VectorStoreIndex.from_vector_store( + vector_store=vector_store, + embed_model=embedding_model, + ) + logger.info("fetched Qdrant index") + llm = models.LLM.get(model_name=configuration.model_name) + + retriever = _create_retriever( + configuration, embedding_model, index, data_source_id, llm + ) + chat_engine = _build_flexible_chat_engine( + configuration, llm, retriever, data_source_id + ) + + logger.info("querying chat engine") + chat_messages = list( + map( + lambda message: ChatMessage(role=message.role, content=message.content), + chat_history, + ) + ) + + condensed_question: str = chat_engine.condense_question( + chat_messages, query_str + ).strip() + try: + chat_response: StreamingAgentChatResponse = chat_engine.stream_chat( + query_str, chat_messages + ) + logger.info("query response received from chat engine") + return chat_response, condensed_question + except botocore.exceptions.ClientError as error: + logger.warning(error.response) + json_error = error.response + raise HTTPException( + status_code=json_error["ResponseMetadata"]["HTTPStatusCode"], + detail=json_error["message"], + ) from error + + def query( data_source_id: int, query_str: str, diff --git a/llm-service/app/services/session.py 
b/llm-service/app/services/session.py index 274780ca1..7020dfd12 100644 --- a/llm-service/app/services/session.py +++ b/llm-service/app/services/session.py @@ -40,7 +40,10 @@ from fastapi import HTTPException from . import models -from .chat_history.chat_history_manager import chat_history_manager +from .chat_history.chat_history_manager import ( + chat_history_manager, + RagStudioChatMessage, +) from .metadata_apis import session_metadata_api RENAME_SESSION_PROMPT_TEMPLATE = """ @@ -78,7 +81,7 @@ def rename_session(session_id: int, user_name: Optional[str]) -> str: - chat_history = chat_history_manager.retrieve_chat_history(session_id=session_id) + chat_history: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(session_id=session_id) if not chat_history: raise HTTPException(status_code=400, detail="No chat history found") first_interaction = chat_history[0].rag_message diff --git a/llm-service/app/tests/services/test_chat.py b/llm-service/app/tests/services/test_chat.py index 32e4ae147..ce0136f50 100644 --- a/llm-service/app/tests/services/test_chat.py +++ b/llm-service/app/tests/services/test_chat.py @@ -40,7 +40,7 @@ from hypothesis import example, given from hypothesis import strategies as st -from app.services.chat import process_response +from app.services.chat.utils import process_response @st.composite diff --git a/scripts/release_version.txt b/scripts/release_version.txt index 114c45f6b..32150ea1e 100644 --- a/scripts/release_version.txt +++ b/scripts/release_version.txt @@ -1 +1 @@ -export RELEASE_TAG=1.17.0 +export RELEASE_TAG=dev-testing diff --git a/ui/package.json b/ui/package.json index 30038622c..e6dead992 100644 --- a/ui/package.json +++ b/ui/package.json @@ -21,6 +21,7 @@ "@ant-design/icons": "^5.5.1", "@emotion/react": "^11.14.0", "@emotion/styled": "^11.14.0", + "@microsoft/fetch-event-source": "^2.0.1", "@mui/material": "^6.4.3", "@mui/x-charts": "^7.26.0", "@tanstack/react-query": "^5.59.20", diff --git a/ui/pnpm-lock.yaml b/ui/pnpm-lock.yaml index 0a9c6b260..92a152595 100644 --- a/ui/pnpm-lock.yaml +++ b/ui/pnpm-lock.yaml @@ -17,6 +17,9 @@ importers: '@emotion/styled': specifier: ^11.14.0 version: 11.14.0(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react@18.3.1) + '@microsoft/fetch-event-source': + specifier: ^2.0.1 + version: 2.0.1 '@mui/material': specifier: ^6.4.3 version: 6.4.3(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@emotion/styled@11.14.0(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) @@ -768,6 +771,9 @@ packages: '@jridgewell/trace-mapping@0.3.25': resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==} + '@microsoft/fetch-event-source@2.0.1': + resolution: {integrity: sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==} + '@mui/core-downloads-tracker@6.4.3': resolution: {integrity: sha512-hlyOzo2ObarllAOeT1ZSAusADE5NZNencUeIvXrdQ1Na+FL1lcznhbxfV5He1KqGiuR8Az3xtCUcYKwMVGFdzg==} @@ -4719,6 +4725,8 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.0 + '@microsoft/fetch-event-source@2.0.1': {} + '@mui/core-downloads-tracker@6.4.3': {} 
'@mui/material@6.4.3(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@emotion/styled@11.14.0(@emotion/react@11.14.0(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react@18.3.1))(@types/react@18.3.12)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts index fdb9abb18..7411d2f25 100644 --- a/ui/src/api/chatApi.ts +++ b/ui/src/api/chatApi.ts @@ -37,6 +37,7 @@ ******************************************************************************/ import { + commonHeaders, getRequest, llmServicePath, MutationKeys, @@ -52,6 +53,11 @@ import { useQueryClient, } from "@tanstack/react-query"; import { suggestedQuestionKey } from "src/api/ragQueryApi.ts"; +import { + EventSourceMessage, + EventStreamContentType, + fetchEventSource, +} from "@microsoft/fetch-event-source"; export interface SourceNode { node_id: string; @@ -103,7 +109,7 @@ export interface ChatResponseFeedback { rating: boolean; } -const placeholderChatResponseId = "placeholder"; +export const placeholderChatResponseId = "placeholder"; export const isPlaceholder = (chatMessage: ChatMessageType): boolean => { return chatMessage.id === placeholderChatResponseId; @@ -253,75 +259,6 @@ export const replacePlaceholderInChatHistory = ( }; }; -export const useChatMutation = ({ - onSuccess, - onError, -}: UseMutationType) => { - const queryClient = useQueryClient(); - return useMutation({ - mutationKey: [MutationKeys.chatMutation], - mutationFn: chatMutation, - onMutate: (variables) => { - queryClient.setQueryData>( - chatHistoryQueryKey({ - session_id: variables.session_id, - }), - (cachedData) => - appendPlaceholderToChatHistory(variables.query, cachedData), - ); - }, - onSuccess: (data, variables) => { - queryClient.setQueryData>( - chatHistoryQueryKey({ - session_id: variables.session_id, - }), - (cachedData) => replacePlaceholderInChatHistory(data, cachedData), - ); - queryClient - .invalidateQueries({ - queryKey: suggestedQuestionKey(variables.session_id), - }) - .catch((error: unknown) => { - console.error(error); - }); - onSuccess?.(data); - }, - onError: (error: Error, variables) => { - const uuid = crypto.randomUUID(); - const errorMessage: ChatMessageType = { - id: `error-${uuid}`, - session_id: variables.session_id, - source_nodes: [], - rag_message: { - user: variables.query, - assistant: error.message, - }, - evaluations: [], - timestamp: Date.now(), - }; - queryClient.setQueryData>( - chatHistoryQueryKey({ - session_id: variables.session_id, - offset: 0, - }), - (cachedData) => - replacePlaceholderInChatHistory(errorMessage, cachedData), - ); - - onError?.(error); - }, - }); -}; - -const chatMutation = async ( - request: ChatMutationRequest, -): Promise => { - return await postRequest( - `${llmServicePath}/sessions/${request.session_id.toString()}/chat`, - request, - ); -}; - export const createQueryConfiguration = ( excludeKnowledgeBase: boolean, ): QueryConfiguration => { @@ -384,3 +321,154 @@ const feedbackMutation = async ({ { feedback }, ); }; + +export interface ChatMutationResponse { + text?: string; + response_id?: string; + error?: string; +} + +const errorChatMessage = (variables: ChatMutationRequest, error: Error) => { + const uuid = crypto.randomUUID(); + const errorMessage: ChatMessageType = { + id: `error-${uuid}`, + session_id: variables.session_id, + source_nodes: [], + rag_message: { + user: variables.query, + assistant: error.message, + }, + evaluations: [], + timestamp: Date.now(), + }; + return errorMessage; +}; + +export const 
useStreamingChatMutation = ({ + onError, + onSuccess, + onChunk, +}: UseMutationType & { onChunk: (msg: string) => void }) => { + const queryClient = useQueryClient(); + const handleError = (variables: ChatMutationRequest, error: Error) => { + const errorMessage = errorChatMessage(variables, error); + queryClient.setQueryData>( + chatHistoryQueryKey({ + session_id: variables.session_id, + offset: 0, + }), + (cachedData) => replacePlaceholderInChatHistory(errorMessage, cachedData), + ); + }; + return useMutation({ + mutationKey: [MutationKeys.chatMutation], + mutationFn: (request: ChatMutationRequest) => { + const convertError = (errorMessage: string) => { + const error = new Error(errorMessage); + handleError(request, error); + onError?.(error); + }; + return streamChatMutation(request, onChunk, convertError); + }, + onMutate: (variables) => { + queryClient.setQueryData>( + chatHistoryQueryKey({ + session_id: variables.session_id, + }), + (cachedData) => + appendPlaceholderToChatHistory(variables.query, cachedData), + ); + }, + onSuccess: (messageId, variables) => { + if (!messageId) { + return; + } + fetch( + `${llmServicePath}/sessions/${variables.session_id.toString()}/chat-history/${messageId}`, + ) + .then(async (res) => { + const message = (await res.json()) as ChatMessageType; + queryClient.setQueryData>( + chatHistoryQueryKey({ + session_id: variables.session_id, + }), + (cachedData) => + replacePlaceholderInChatHistory(message, cachedData), + ); + queryClient + .invalidateQueries({ + queryKey: suggestedQuestionKey(variables.session_id), + }) + .catch((error: unknown) => { + console.error(error); + }); + onSuccess?.(message); + }) + .catch((error: unknown) => { + handleError(variables, error as Error); + onError?.(error as Error); + }); + }, + onError: (error: Error, variables) => { + handleError(variables, error); + onError?.(error); + }, + }); +}; + +const streamChatMutation = async ( + request: ChatMutationRequest, + onChunk: (chunk: string) => void, + onError: (error: string) => void, +): Promise => { + const ctrl = new AbortController(); + let responseId = ""; + await fetchEventSource( + `${llmServicePath}/sessions/${request.session_id.toString()}/stream-completion`, + { + method: "POST", + headers: commonHeaders, + body: JSON.stringify({ + query: request.query, + configuration: request.configuration, + }), + signal: ctrl.signal, + onmessage(msg: EventSourceMessage) { + const data = JSON.parse(msg.data) as ChatMutationResponse; + + if (data.error) { + ctrl.abort(); + onError(data.error); + } + + if (data.text) { + onChunk(data.text); + } + if (data.response_id) { + responseId = data.response_id; + } + }, + onerror(err: unknown) { + ctrl.abort(); + onError(String(err)); + }, + async onopen(response) { + if ( + response.ok && + response.headers.get("content-type")?.includes(EventStreamContentType) + ) { + await Promise.resolve(); + } else if ( + response.status >= 400 && + response.status < 500 && + response.status !== 429 + ) { + onError("An error occurred: " + response.statusText); + } else { + onError("An error occurred: " + response.statusText); + } + }, + }, + ); + return responseId; +}; diff --git a/ui/src/api/modelsApi.ts b/ui/src/api/modelsApi.ts index e36c8f8c3..70759e4fa 100644 --- a/ui/src/api/modelsApi.ts +++ b/ui/src/api/modelsApi.ts @@ -57,13 +57,11 @@ export const useGetModelById = (model_id?: string) => { return useQuery({ queryKey: [QueryKeys.getModelById, { model_id }], queryFn: async () => { - if (!model_id) { - return undefined; - } const llmModels = await 
getLlmModels(); return llmModels.find((model) => model.model_id === model_id); }, staleTime: 1000 * 60 * 60, + enabled: !!model_id, }); }; diff --git a/ui/src/api/utils.ts b/ui/src/api/utils.ts index a6c06b5a1..ccdce209a 100644 --- a/ui/src/api/utils.ts +++ b/ui/src/api/utils.ts @@ -80,6 +80,7 @@ export enum MutationKeys { "removeDataSourceFromProject" = "removeDataSourceFromProject", "updateAmpConfig" = "updateAmpConfig", "restartApplication" = "restartApplication", + "streamChatMutation" = "streamChatMutation", } export enum QueryKeys { diff --git a/ui/src/pages/RagChatTab/ChatLayout.tsx b/ui/src/pages/RagChatTab/ChatLayout.tsx index 0042f44b2..ffb5247ed 100644 --- a/ui/src/pages/RagChatTab/ChatLayout.tsx +++ b/ui/src/pages/RagChatTab/ChatLayout.tsx @@ -73,6 +73,7 @@ function ChatLayout() { const { data: dataSources, status: dataSourcesStatus } = useGetDataSourcesForProject(+projectId); const [excludeKnowledgeBase, setExcludeKnowledgeBase] = useState(false); + const [streamedChat, setStreamedChat] = useState(""); const { status: chatHistoryStatus, data: chatHistory, @@ -108,6 +109,7 @@ function ChatLayout() { isFetching, isFetchingPreviousPage, }, + streamedChatState: [streamedChat, setStreamedChat], dataSourceSize, dataSourcesQuery: { dataSources: dataSources ?? [], diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx index bf4fab206..46b756c94 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx @@ -37,19 +37,13 @@ ******************************************************************************/ import { Alert, Divider, Flex, Typography } from "antd"; -import SourceNodes from "pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx"; import PendingRagOutputSkeleton from "pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx"; import { ChatMessageType, isPlaceholder } from "src/api/chatApi.ts"; -import { cdlBlue500, cdlGray200 } from "src/cuix/variables.ts"; import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion.tsx"; -import { Evaluations } from "pages/RagChatTab/ChatOutput/ChatMessages/Evaluations.tsx"; -import Images from "src/components/images/Images.ts"; -import RatingFeedbackWrapper from "pages/RagChatTab/ChatOutput/ChatMessages/RatingFeedbackWrapper.tsx"; -import Remark from "remark-gfm"; -import Markdown from "react-markdown"; import "../tableMarkdown.css"; import { ExclamationCircleTwoTone } from "@ant-design/icons"; +import { ChatMessageBody } from "pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx"; const isError = (data: ChatMessageType) => { return data.id.startsWith("error-"); @@ -93,64 +87,7 @@ const ChatMessage = ({ data }: { data: ChatMessageType }) => { return ; } - return ( -
- {data.rag_message.user ? ( -
- - -
- {data.source_nodes.length > 0 ? ( - - ) : ( - - )} -
- - - - - {data.rag_message.assistant.trimStart()} - - - - - - - -
-
- ) : null} - -
- ); + return ; }; export default ChatMessage; diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx new file mode 100644 index 000000000..d64cb3bc1 --- /dev/null +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx @@ -0,0 +1,109 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2025 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. ("Cloudera") to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +import { ChatMessageType } from "src/api/chatApi.ts"; +import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion.tsx"; +import { Divider, Flex, Typography } from "antd"; +import Images from "src/components/images/Images.ts"; +import { cdlBlue500, cdlGray200 } from "src/cuix/variables.ts"; +import SourceNodes from "pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx"; +import Markdown from "react-markdown"; +import Remark from "remark-gfm"; +import { Evaluations } from "pages/RagChatTab/ChatOutput/ChatMessages/Evaluations.tsx"; +import RatingFeedbackWrapper from "pages/RagChatTab/ChatOutput/ChatMessages/RatingFeedbackWrapper.tsx"; + +export const ChatMessageBody = ({ data }: { data: ChatMessageType }) => { + return ( +
+ {data.rag_message.user ? ( +
+ + +
+ {data.source_nodes.length > 0 ? ( + + ) : ( + + )} +
+ + + + + {data.rag_message.assistant.trimStart()} + + + + + + + +
+
+ ) : null} + +
+ ); +}; diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx index 5ddaccc90..7797c5451 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx @@ -50,7 +50,7 @@ import messageQueue from "src/utils/messageQueue.ts"; import { createQueryConfiguration, isPlaceholder, - useChatMutation, + useStreamingChatMutation, } from "src/api/chatApi.ts"; import { useRenameNameMutation } from "src/api/sessionApi.ts"; import NoDataSourcesState from "pages/RagChatTab/ChatOutput/Placeholders/NoDataSourcesState.tsx"; @@ -64,6 +64,7 @@ const ChatMessageController = () => { isFetching: isFetchingHistory, isFetchingPreviousPage, }, + streamedChatState: [, setStreamedChat], activeSession, } = useContext(RagChatContext); const { ref: refToFetchNextPage, inView } = useInView({ threshold: 0 }); @@ -77,8 +78,13 @@ const ChatMessageController = () => { messageQueue.error(err.message); }, }); - const { mutate: chatMutation } = useChatMutation({ + + const { mutate: chatMutation } = useStreamingChatMutation({ + onChunk: (chunk) => { + setStreamedChat((prev) => prev + chunk); + }, onSuccess: () => { + setStreamedChat(""); const url = new URL(window.location.href); url.searchParams.delete("question"); window.history.pushState(null, "", url.toString()); @@ -127,26 +133,15 @@ const ChatMessageController = () => { }, [fetchPreviousPage, inView]); useEffect(() => { - // scroll to bottom when changing the active session if (bottomElement.current) { - setTimeout(() => { - if (bottomElement.current) { - bottomElement.current.scrollIntoView({ behavior: "auto" }); - } - }, 50); - } - }, [bottomElement.current, activeSession?.id]); - - useEffect(() => { - if ( - flatChatHistory.length > 0 && - isPlaceholder(flatChatHistory[flatChatHistory.length - 1]) - ) { - setTimeout(() => { - if (bottomElement.current) { - bottomElement.current.scrollIntoView({ behavior: "auto" }); - } - }, 50); + if ( + flatChatHistory.length > 0 && + isPlaceholder(flatChatHistory[flatChatHistory.length - 1]) + ) { + bottomElement.current.scrollIntoView({ behavior: "smooth" }); + } else { + bottomElement.current.scrollIntoView({ behavior: "auto" }); + } } }, [bottomElement.current, flatChatHistory.length, activeSession?.id]); @@ -181,18 +176,27 @@ const ChatMessageController = () => {
{isFetchingPreviousPage && } {flatChatHistory.map((historyMessage, index) => { - // trigger fetching on second to la`st item + const isLast = index === flatChatHistory.length - 1; + // trigger fetching on second to last item if (index === 2) { return (
+ {isLast &&
}
); } - return ; + return ( +
+ +
+ ); })} -
); }; diff --git a/ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx b/ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx index b39f23f19..105ae9016 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx @@ -36,23 +36,29 @@ * DATA. ******************************************************************************/ -import { Divider, Row, Skeleton } from "antd"; -import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion.tsx"; +import { useContext } from "react"; +import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; +import { ChatMessageType, placeholderChatResponseId } from "src/api/chatApi.ts"; +import { ChatMessageBody } from "pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx"; const PendingRagOutputSkeleton = ({ question }: { question: string }) => { - return ( -
-
- - - - - - -
- -
- ); + const { + streamedChatState: [streamedChat], + } = useContext(RagChatContext); + + const streamedMessage: ChatMessageType = { + id: placeholderChatResponseId, + session_id: 0, + source_nodes: [], + rag_message: { + user: question, + assistant: streamedChat, + }, + evaluations: [], + timestamp: Date.now(), + }; + + return ; }; export default PendingRagOutputSkeleton; diff --git a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx index 94276ff1f..84694a3d1 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx @@ -40,7 +40,10 @@ import { Card, Flex, Skeleton, Typography } from "antd"; import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; import { useContext } from "react"; import { useSuggestQuestions } from "src/api/ragQueryApi.ts"; -import { createQueryConfiguration, useChatMutation } from "src/api/chatApi.ts"; +import { + createQueryConfiguration, + useStreamingChatMutation, +} from "src/api/chatApi.ts"; import useCreateSessionAndRedirect from "pages/RagChatTab/ChatOutput/hooks/useCreateSessionAndRedirect"; const QuestionCard = ({ @@ -77,6 +80,7 @@ const SuggestedQuestionsCards = () => { const { activeSession, excludeKnowledgeBaseState: [excludeKnowledgeBase], + streamedChatState: [, setStreamedChat], } = useContext(RagChatContext); const sessionId = activeSession?.id; const { @@ -88,9 +92,15 @@ const SuggestedQuestionsCards = () => { }); const createSessionAndRedirect = useCreateSessionAndRedirect(); - const { mutate: chatMutation, isPending: askRagIsPending } = useChatMutation( - {}, - ); + const { mutate: chatMutation, isPending: askRagIsPending } = + useStreamingChatMutation({ + onChunk: (chunk) => { + setStreamedChat((prev) => prev + chunk); + }, + onSuccess: () => { + setStreamedChat(""); + }, + }); const handleAskSample = (suggestedQuestion: string) => { if (suggestedQuestion.length > 0) { diff --git a/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx b/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx index 02e1799cc..92eff03fc 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceNodes.tsx @@ -36,15 +36,18 @@ * DATA. 
******************************************************************************/ -import { Flex, Typography } from "antd"; +import { Flex, Skeleton, Typography } from "antd"; import { SourceCard } from "pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx"; -import { ChatMessageType } from "src/api/chatApi.ts"; +import { ChatMessageType, isPlaceholder } from "src/api/chatApi.ts"; import { WarningTwoTone } from "@ant-design/icons"; import { cdlOrange050, cdlOrange500 } from "src/cuix/variables.ts"; import { useGetModelById } from "src/api/modelsApi.ts"; import { useContext } from "react"; import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; +const SkeletonNode = () => { + return ; +}; const SourceNodes = ({ data }: { data: ChatMessageType }) => { const { data: inferenceModel } = useGetModelById(data.inference_model); const { activeSession } = useContext(RagChatContext); @@ -53,6 +56,21 @@ const SourceNodes = ({ data }: { data: ChatMessageType }) => { )); + if ( + isPlaceholder(data) && + activeSession && + activeSession.dataSourceIds.length > 0 + ) { + return ( + + + + + + + ); + } + if ( nodes.length === 0 && activeSession && diff --git a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx index 51cc4a43f..2d51f4d8c 100644 --- a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx +++ b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx @@ -40,7 +40,10 @@ import { Button, Flex, Input, InputRef, Switch, Tooltip } from "antd"; import { DatabaseFilled, SendOutlined } from "@ant-design/icons"; import { useContext, useEffect, useRef, useState } from "react"; import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; -import { createQueryConfiguration, useChatMutation } from "src/api/chatApi.ts"; +import { + createQueryConfiguration, + useStreamingChatMutation, +} from "src/api/chatApi.ts"; import { useParams, useSearch } from "@tanstack/react-router"; import { cdlBlue600 } from "src/cuix/variables.ts"; @@ -58,6 +61,7 @@ const RagChatQueryInput = ({ chatHistoryQuery: { flatChatHistory }, dataSourceSize, dataSourcesQuery: { dataSourcesStatus }, + streamedChatState: [, setStreamedChat], } = useContext(RagChatContext); const [userInput, setUserInput] = useState(""); @@ -66,7 +70,6 @@ const RagChatQueryInput = ({ strict: false, }); const inputRef = useRef(null); - const { data: sampleQuestions, isPending: sampleQuestionsIsPending, @@ -79,9 +82,13 @@ const RagChatQueryInput = ({ !search.question, ); - const chatMutation = useChatMutation({ + const streamChatMutation = useStreamingChatMutation({ + onChunk: (chunk) => { + setStreamedChat((prev) => prev + chunk); + }, onSuccess: () => { setUserInput(""); + setStreamedChat(""); }, }); @@ -97,7 +104,7 @@ const RagChatQueryInput = ({ } if (userInput.length > 0) { if (sessionId) { - chatMutation.mutate({ + streamChatMutation.mutate({ query: userInput, session_id: +sessionId, configuration: createQueryConfiguration(excludeKnowledgeBase), @@ -156,7 +163,7 @@ const RagChatQueryInput = ({ /> } - disabled={chatMutation.isPending} + disabled={streamChatMutation.isPending} />
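
For anyone exercising the new endpoints outside the React client, the sketch below walks the same contract this diff introduces: POST to /sessions/{session_id}/stream-completion, read the SSE frames (repeated data: {"text": ...} deltas, then a final data: {"response_id": ...}, or data: {"error": ...} on failure), and fetch the persisted message from the new chat-history/{message_id} route once the stream ends. It is a minimal illustration only; the base URL, port, session id, and the absence of auth headers are assumptions, not part of the change.

# sse_smoke_test.py -- minimal consumer for the stream-completion endpoint.
# BASE_URL, SESSION_ID, and the lack of auth headers are assumptions; adjust
# them to match a real deployment of the llm-service.
import json

import requests

BASE_URL = "http://localhost:8000"  # assumed llm-service address
SESSION_ID = 1                      # assumed existing session id


def stream_completion(query: str) -> str:
    """Stream one completion and return the response_id sent at the end."""
    url = f"{BASE_URL}/sessions/{SESSION_ID}/stream-completion"
    response_id = ""
    with requests.post(url, json={"query": query}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank SSE separators
            payload = json.loads(line[len("data: "):])
            if "error" in payload:
                raise RuntimeError(payload["error"])
            if "text" in payload:
                print(payload["text"], end="", flush=True)  # incremental delta
            if "response_id" in payload:
                response_id = payload["response_id"]
    return response_id


def fetch_final_message(response_id: str) -> dict:
    """Mirror the UI: pull the stored message once streaming has finished."""
    url = f"{BASE_URL}/sessions/{SESSION_ID}/chat-history/{response_id}"
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    rid = stream_completion("What is Cloudera?")
    print()
    print(fetch_final_message(rid))

As in the UI's onSuccess handler, the source nodes and evaluations only become available through that follow-up GET, after the server has appended the completed message to the chat history.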