
Commit 87c87f0

jkwatson, ewilliams-cloudera, mliu-cloudera, baasitsharief, and actions-user authored
Streaming!!!! (#202)
* wip on simple streaming
* simple poc for streaming
* remove usage from RagChatQueryInput.tsx
* remove stream hypothetical
* remove unused import
* wip on doing something once the gen is done
* progress on generators
* go back to simple streaming only endpoint
* wip lastFile:llm-service/app/services/chat.py
* add response id on every chunk returned lastFile:llm-service/app/routers/index/sessions/__init__.py
* remove duplicate calls, but still not rendering
* getting there
* Consolidate response_id generation
* wip lastFile:ui/src/api/chatApi.ts
* drop databases lastFile:ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx
* mob next [ci-skip] [ci skip] [skip ci] lastFile:ui/src/pages/RagChatTab/ChatOutput/Loaders/PendingRagOutputSkeleton.tsx
* mob next [ci-skip] [ci skip] [skip ci] lastFile:llm-service/app/routers/index/sessions/__init__.py
* small refactor
* remove deps
* things are getting close
* wip lastFile:ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx
* drop databases lastFile:llm-service/app/services/chat.py
* wip lastFile:llm-service/app/services/chat.py
* mob next [ci-skip] [ci skip] [skip ci] lastFile:llm-service/app/services/chat.py
* drop databases lastFile:llm-service/app/services/chat.py
* wip lastFile:llm-service/app/services/chat.py
* fixing scrolling
* only show loading nodes if kb
* remove unused
* removing active loading state
* fix mypy issues
* ruff
* Update release version to dev-testing
* handle file not found error for summaries when local
* remove log
* renaming
* better error handling
* bump bedrock to use max tokens of 1024
* python refactoring lastFile:llm-service/app/routers/index/sessions/__init__.py
* mob next [ci-skip] [ci skip] [skip ci] lastFile:llm-service/app/routers/index/sessions/__init__.py
* nits

---------

Co-authored-by: Elijah Williams <ewilliams@cloudera.com>
Co-authored-by: Michael Liu <mliu@cloudera.com>
Co-authored-by: Baasit Sharief <baasitsharief@gmail.com>
Co-authored-by: actions-user <actions@github.com>
1 parent 4613ff7 commit 87c87f0

File tree

28 files changed: +1051 −405 lines


llm-service/app/ai/indexing/summary_indexer.py

Lines changed: 15 additions & 9 deletions
@@ -218,13 +218,17 @@ def create_storage_context(
     @classmethod
     def get_all_data_source_summaries(cls) -> dict[str, str]:
         root_dir = cls.__persist_root_dir()
-        # if not os.path.exists(root_dir):
-        #     return {}
-        storage_context = SummaryIndexer.create_storage_context(
-            persist_dir=root_dir,
-            vector_store=SimpleVectorStore(),
-        )
-        indices = load_indices_from_storage(storage_context=storage_context, index_ids=None,
+        try:
+            storage_context = SummaryIndexer.create_storage_context(
+                persist_dir=root_dir,
+                vector_store=SimpleVectorStore(),
+            )
+        except FileNotFoundError:
+            # If the directory doesn't exist, we don't have any summaries.
+            return {}
+        indices = load_indices_from_storage(
+            storage_context=storage_context,
+            index_ids=None,
             **{
                 "llm": models.LLM.get_noop(),
                 "response_synthesizer": models.LLM.get_noop(),
@@ -234,11 +238,13 @@ def get_all_data_source_summaries(cls) -> dict[str, str]:
                 "summary_query": "None",
                 "data_source_id": 0,
             },
-            )
+        )
         if len(indices) == 0:
             return {}
 
-        global_summary_store: DocumentSummaryIndex = cast(DocumentSummaryIndex, indices[0])
+        global_summary_store: DocumentSummaryIndex = cast(
+            DocumentSummaryIndex, indices[0]
+        )
 
         summary_ids = global_summary_store.index_struct.doc_id_to_summary_id.values()
         nodes = global_summary_store.docstore.get_nodes(list(summary_ids))
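The change above replaces a commented-out existence check with an EAFP-style guard: attempt to open the persisted storage context and treat a missing directory as "no summaries yet". A minimal, self-contained sketch of the same try/except pattern, using a JSON file as a stand-in (persist_path and the JSON layout are illustrative, not the indexer's real persistence format):

import json

def load_summaries(persist_path: str) -> dict[str, str]:
    """EAFP: try to read the persisted store; a missing file means no summaries."""
    try:
        with open(persist_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Same fallback as get_all_data_source_summaries: nothing persisted yet.
        return {}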

llm-service/app/routers/index/chat/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@
 from pydantic import BaseModel
 
 from app import exceptions
-from app.services.chat import generate_suggested_questions
+from app.services.chat.suggested_questions import generate_suggested_questions
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/chat", tags=["Chat"])

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 59 additions & 5 deletions
@@ -38,15 +38,17 @@
 import base64
 import json
 import logging
-from typing import Optional
+from typing import Optional, Generator
 
-from fastapi import APIRouter, Header
+from fastapi import APIRouter, Header, HTTPException
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
+from app.services.chat.streaming_chat import stream_chat
 from .... import exceptions
 from ....rag_types import RagPredictConfiguration
-from ....services.chat import (
-    v2_chat,
+from ....services.chat.chat import (
+    chat as run_chat,
 )
 from ....services.chat_history.chat_history_manager import (
     RagStudioChatMessage,
@@ -100,6 +102,24 @@ def chat_history(
     )
 
 
+@router.get(
+    "/chat-history/{message_id}",
+    summary="Returns a specific chat message for the provided session.",
+)
+@exceptions.propagates
+def get_message_by_id(session_id: int, message_id: str) -> RagStudioChatMessage:
+    results: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(
+        session_id=session_id
+    )
+    for message in results:
+        if message.id == message_id:
+            return message
+    raise HTTPException(
+        status_code=404,
+        detail=f"Message with id {message_id} not found in session {session_id}",
+    )
+
+
 @router.delete(
     "/chat-history", summary="Deletes the chat history for the provided session."
 )
@@ -161,6 +181,10 @@ class RagStudioChatRequest(BaseModel):
     configuration: RagPredictConfiguration | None = None
 
 
+class StreamCompletionRequest(BaseModel):
+    query: str
+
+
 def parse_jwt_cookie(jwt_cookie: str | None) -> str:
     if jwt_cookie is None:
         return "unknown"
@@ -187,4 +211,34 @@ def chat(
     session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
 
     configuration = request.configuration or RagPredictConfiguration()
-    return v2_chat(session, request.query, configuration, user_name=origin_remote_user)
+    return run_chat(session, request.query, configuration, user_name=origin_remote_user)
+
+
+@router.post(
+    "/stream-completion", summary="Stream completion responses for the given query"
+)
+@exceptions.propagates
+def stream_chat_completion(
+    session_id: int,
+    request: RagStudioChatRequest,
+    origin_remote_user: Optional[str] = Header(None),
+) -> StreamingResponse:
+    session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
+    configuration = request.configuration or RagPredictConfiguration()
+
+    def generate_stream() -> Generator[str, None, None]:
+        response_id: str = ""
+        try:
+            # Stream deltas as Server-Sent Events, then emit the response id
+            # so the client can fetch the finished message by id.
+            for response in stream_chat(
+                session, request.query, configuration, user_name=origin_remote_user
+            ):
+                response_id = response.additional_kwargs["response_id"]
+                json_delta = json.dumps({"text": response.delta})
+                yield f"data: {json_delta}\n\n"
+            yield f'data: {{"response_id": "{response_id}"}}\n\n'
+        except Exception as e:
+            logger.exception("Failed to stream chat completion")
+            yield f'data: {{"error": "{e}"}}\n\n'
+
+    return StreamingResponse(generate_stream(), media_type="text/event-stream")
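For reference, a minimal sketch of a client consuming this endpoint. It assumes the service is reachable at http://localhost:8000, that the route is mounted under a /sessions/{session_id} prefix (the prefix is not shown in this diff), and it uses the third-party requests library for illustration:

import json
import requests

def stream_completion(session_id: int, query: str) -> str:
    """Read the SSE stream line by line and reassemble the answer text."""
    url = f"http://localhost:8000/sessions/{session_id}/stream-completion"
    answer = ""
    with requests.post(url, json={"query": query}, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip the blank separator lines between events
            payload = json.loads(line[len("data: "):])
            if "text" in payload:
                answer += payload["text"]  # incremental delta
            elif "response_id" in payload:
                # Final event: the id of the persisted chat message,
                # usable with GET /chat-history/{message_id}.
                print("response_id:", payload["response_id"])
    return answer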
llm-service/app/services/chat/__init__.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+#
+# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
+# (C) Cloudera, Inc. 2025
+# All rights reserved.
+#
+# Applicable Open Source License: Apache 2.0
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted. Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera. Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+#
+# This code is provided to you pursuant a written agreement with
+# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+# this code. If you do not have a written agreement with Cloudera nor
+# with an authorized and properly licensed third party, you do not
+# have any rights to access nor to use this code.
+#
+# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
+# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
+# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+# DATA.
+#
+
llm-service/app/services/chat/chat.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+#
+# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
+# (C) Cloudera, Inc. 2025
+# All rights reserved.
+#
+# Applicable Open Source License: Apache 2.0
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted. Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera. Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+#
+# This code is provided to you pursuant a written agreement with
+# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+# this code. If you do not have a written agreement with Cloudera nor
+# with an authorized and properly licensed third party, you do not
+# have any rights to access nor to use this code.
+#
+# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
+# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
+# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+# DATA.
+#
+
+import time
+import uuid
+from typing import Optional
+
+from fastapi import HTTPException
+
+from app.services import evaluators, llm_completion
+from app.services.chat.utils import retrieve_chat_history, format_source_nodes
+from app.services.chat_history.chat_history_manager import (
+    Evaluation,
+    RagMessage,
+    RagStudioChatMessage,
+    chat_history_manager,
+)
+from app.services.metadata_apis.session_metadata_api import Session
+from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run
+from app.services.query import querier
+from app.services.query.query_configuration import QueryConfiguration
+from app.ai.vector_stores.vector_store_factory import VectorStoreFactory
+from app.rag_types import RagPredictConfiguration
+
+
+def chat(
+    session: Session,
+    query: str,
+    configuration: RagPredictConfiguration,
+    user_name: Optional[str],
+) -> RagStudioChatMessage:
+    query_configuration = QueryConfiguration(
+        top_k=session.response_chunks,
+        model_name=session.inference_model,
+        rerank_model_name=session.rerank_model,
+        exclude_knowledge_base=configuration.exclude_knowledge_base,
+        use_question_condensing=configuration.use_question_condensing,
+        use_hyde=session.query_configuration.enable_hyde,
+        use_summary_filter=session.query_configuration.enable_summary_filter,
+    )
+
+    response_id = str(uuid.uuid4())
+
+    if configuration.exclude_knowledge_base or len(session.data_source_ids) == 0:
+        return direct_llm_chat(session, response_id, query, user_name)
+
+    total_data_sources_size: int = sum(
+        map(
+            lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0,
+            session.data_source_ids,
+        )
+    )
+    if total_data_sources_size == 0:
+        return direct_llm_chat(session, response_id, query, user_name)
+
+    new_chat_message: RagStudioChatMessage = _run_chat(
+        session, response_id, query, query_configuration, user_name
+    )
+
+    chat_history_manager.append_to_history(session.id, [new_chat_message])
+    return new_chat_message
+
+
+def _run_chat(
+    session: Session,
+    response_id: str,
+    query: str,
+    query_configuration: QueryConfiguration,
+    user_name: Optional[str],
+) -> RagStudioChatMessage:
+    if len(session.data_source_ids) != 1:
+        raise HTTPException(
+            status_code=400, detail="Only one datasource is supported for chat."
+        )
+
+    data_source_id: int = session.data_source_ids[0]
+    response, condensed_question = querier.query(
+        data_source_id,
+        query,
+        query_configuration,
+        retrieve_chat_history(session.id),
+    )
+    if condensed_question and (condensed_question.strip() == query.strip()):
+        condensed_question = None
+    relevance, faithfulness = evaluators.evaluate_response(
+        query, response, session.inference_model
+    )
+    response_source_nodes = format_source_nodes(response, data_source_id)
+    new_chat_message = RagStudioChatMessage(
+        id=response_id,
+        session_id=session.id,
+        source_nodes=response_source_nodes,
+        inference_model=session.inference_model,
+        rag_message=RagMessage(
+            user=query,
+            assistant=response.response,
+        ),
+        evaluations=[
+            Evaluation(name="relevance", value=relevance),
+            Evaluation(name="faithfulness", value=faithfulness),
+        ],
+        timestamp=time.time(),
+        condensed_question=condensed_question,
+    )
+
+    record_rag_mlflow_run(
+        new_chat_message, query_configuration, response_id, session, user_name
+    )
+    return new_chat_message
+
+
+def direct_llm_chat(
+    session: Session, response_id: str, query: str, user_name: Optional[str]
+) -> RagStudioChatMessage:
+    record_direct_llm_mlflow_run(response_id, session, user_name)
+
+    chat_response = llm_completion.completion(
+        session.id, query, session.inference_model
+    )
+    new_chat_message = RagStudioChatMessage(
+        id=response_id,
+        session_id=session.id,
+        source_nodes=[],
+        inference_model=session.inference_model,
+        evaluations=[],
+        rag_message=RagMessage(
+            user=query,
+            assistant=str(chat_response.message.content),
+        ),
+        timestamp=time.time(),
+        condensed_question=None,
+    )
+    chat_history_manager.append_to_history(session.id, [new_chat_message])
+    return new_chat_message
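The routing rule in chat() above reduces to a small predicate: bypass retrieval when the caller excludes the knowledge base, when the session has no data sources, or when every attached vector store is empty. A compact restatement of that logic, where should_use_rag is an illustrative name and a plain dict stands in for VectorStoreFactory.for_chunks(ds_id).size():

from typing import Optional

# Stand-in for per-data-source chunk counts (None mirrors "size unknown").
CHUNK_COUNTS: dict[int, Optional[int]] = {1: 0, 2: 42}

def should_use_rag(exclude_knowledge_base: bool, data_source_ids: list[int]) -> bool:
    if exclude_knowledge_base or not data_source_ids:
        return False
    total = sum(CHUNK_COUNTS.get(ds_id) or 0 for ds_id in data_source_ids)
    return total > 0

assert should_use_rag(False, [2]) is True   # populated store -> RAG path
assert should_use_rag(False, [1]) is False  # empty store -> direct LLM fallback
assert should_use_rag(True, [2]) is False   # explicit opt-out -> direct LLM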

0 commit comments
