
Commit 2418ff4

Update main.py for new API Catalog models
1 parent 0be681f

1 file changed: community/5_mins_rag_no_gpu/main.py (+31 additions, −48 deletions)
@@ -22,56 +22,44 @@

 import streamlit as st
 import os
+from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_community.document_loaders import DirectoryLoader
+from langchain_community.vectorstores import FAISS
+import pickle
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate

-st.set_page_config(layout = "wide")
+st.set_page_config(layout="wide")

+# Component #1 - Document Upload
 with st.sidebar:
     DOCS_DIR = os.path.abspath("./uploaded_docs")
     if not os.path.exists(DOCS_DIR):
         os.makedirs(DOCS_DIR)
     st.subheader("Add to the Knowledge Base")
     with st.form("my-form", clear_on_submit=True):
-        uploaded_files = st.file_uploader("Upload a file to the Knowledge Base:", accept_multiple_files = True)
+        uploaded_files = st.file_uploader("Upload a file to the Knowledge Base:", accept_multiple_files=True)
         submitted = st.form_submit_button("Upload!")

     if uploaded_files and submitted:
         for uploaded_file in uploaded_files:
             st.success(f"File {uploaded_file.name} uploaded successfully!")
-            with open(os.path.join(DOCS_DIR, uploaded_file.name),"wb") as f:
+            with open(os.path.join(DOCS_DIR, uploaded_file.name), "wb") as f:
                 f.write(uploaded_file.read())

-############################################
 # Component #2 - Embedding Model and LLM
-############################################
+llm = ChatNVIDIA(model="meta/llama3-70b-instruct")
+document_embedder = NVIDIAEmbeddings(model="NV-Embed-QA", model_type="passage")
+#query_embedder = NVIDIAEmbeddings(model="NV-Embed-QA", model_type="query")

-from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
-
-# make sure to export your NVIDIA AI Playground key as NVIDIA_API_KEY!
-llm = ChatNVIDIA(model="ai-llama3-70b")
-document_embedder = NVIDIAEmbeddings(model="ai-embed-qa-4", model_type="passage")
-query_embedder = NVIDIAEmbeddings(model="ai-embed-qa-4", model_type="query")
-
-############################################
 # Component #3 - Vector Database Store
-############################################
-
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.document_loaders import DirectoryLoader
-from langchain_community.vectorstores import FAISS
-import pickle
-
 with st.sidebar:
-    # Option for using an existing vector store
     use_existing_vector_store = st.radio("Use existing vector store if available", ["Yes", "No"], horizontal=True)

-# Path to the vector store file
 vector_store_path = "vectorstore.pkl"
-
-# Load raw documents from the directory
 raw_documents = DirectoryLoader(DOCS_DIR).load()

-
-# Check for existing vector store file
 vector_store_exists = os.path.exists(vector_store_path)
 vectorstore = None
 if use_existing_vector_store == "Yes" and vector_store_exists:
@@ -81,9 +69,9 @@
         st.success("Existing vector store loaded successfully.")
 else:
     with st.sidebar:
-        if raw_documents:
+        if raw_documents and use_existing_vector_store == "Yes":
            with st.spinner("Splitting documents into chunks..."):
-                text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
+                text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=200)
                documents = text_splitter.split_documents(raw_documents)

            with st.spinner("Adding document chunks to vector database..."):
@@ -96,10 +84,7 @@
        else:
            st.warning("No documents available to process!", icon="⚠️")

-############################################
 # Component #4 - LLM Response Generation and Chat
-############################################
-
 st.subheader("Chat with your AI Assistant, Envie!")

 if "messages" not in st.session_state:
@@ -109,34 +94,32 @@
     with st.chat_message(message["role"]):
         st.markdown(message["content"])

-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate
-
-prompt_template = ChatPromptTemplate.from_messages(
-    [("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
-)
-user_input = st.chat_input("Can you tell me what NVIDIA is known for?")
-llm = ChatNVIDIA(model="ai-llama3-70b")
+prompt_template = ChatPromptTemplate.from_messages([
+    ("system", "You are a helpful AI assistant named Envie. If provided with context, use it to inform your responses. If no context is available, use your general knowledge to provide a helpful response."),
+    ("human", "{input}")
+])

 chain = prompt_template | llm | StrOutputParser()

-if user_input and vectorstore!=None:
+user_input = st.chat_input("Can you tell me what NVIDIA is known for?")
+
+if user_input:
     st.session_state.messages.append({"role": "user", "content": user_input})
-    retriever = vectorstore.as_retriever()
-    docs = retriever.invoke(user_input)
     with st.chat_message("user"):
         st.markdown(user_input)

-    context = ""
-    for doc in docs:
-        context += doc.page_content + "\n\n"
-
-    augmented_user_input = "Context: " + context + "\n\nQuestion: " + user_input + "\n"
-
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""

+        if vectorstore is not None and use_existing_vector_store == "Yes":
+            retriever = vectorstore.as_retriever()
+            docs = retriever.invoke(user_input)
+            context = "\n\n".join([doc.page_content for doc in docs])
+            augmented_user_input = f"Context: {context}\n\nQuestion: {user_input}\n"
+        else:
+            augmented_user_input = f"Question: {user_input}\n"
+
         for response in chain.stream({"input": augmented_user_input}):
             full_response += response
             message_placeholder.markdown(full_response + "▌")
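
For reference, a minimal sketch of how the renamed API Catalog chat model plugs into the same prompt | llm | parser chain used above, runnable outside Streamlit. Only the model identifier and chain wiring come from the diff; the exported NVIDIA_API_KEY requirement and the sample question string are assumptions.

# Hedged sketch (not part of the commit): exercises the new model name from the diff.
# Assumes NVIDIA_API_KEY is exported and langchain-nvidia-ai-endpoints is installed.
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

llm = ChatNVIDIA(model="meta/llama3-70b-instruct")  # new API Catalog identifier

prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant named Envie."),
    ("human", "{input}"),
])
chain = prompt_template | llm | StrOutputParser()

# Stream the answer chunk by chunk, as the app does with message_placeholder.
for chunk in chain.stream({"input": "Question: What is NVIDIA known for?\n"}):
    print(chunk, end="", flush=True)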

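A similar hedged sketch for the retrieval side, using the renamed NV-Embed-QA embedder and the 512-character chunking from the diff. The directory path and query string are placeholders, and FAISS.from_documents is the standard LangChain constructor rather than a line shown in this hunk.

# Hedged sketch (not part of the commit): builds and queries a FAISS store with the
# renamed embedding model. Assumes NVIDIA_API_KEY is exported and ./uploaded_docs
# contains at least one document.
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS

document_embedder = NVIDIAEmbeddings(model="NV-Embed-QA", model_type="passage")

raw_documents = DirectoryLoader("./uploaded_docs").load()
splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=200)
documents = splitter.split_documents(raw_documents)

vectorstore = FAISS.from_documents(documents, document_embedder)

# Retrieve context for a query and join it the way the updated app does.
docs = vectorstore.as_retriever().invoke("What is NVIDIA known for?")
context = "\n\n".join(doc.page_content for doc in docs)
print(context[:500])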