How to stream LLM response with streamlit?
fabmeyer opened this issue · 7 comments
I am following this script, which uses a RetrievalQA chain.
Code:
llm = OpenAI(client=OpenAI, streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)
chain = RetrievalQA.from_chain_type(llm=llm, chain_type='refine', retriever=docsearch.as_retriever())
...
if 'user_input' not in st.session_state:
    st.session_state['user_input'] = []
if 'generated_text' not in st.session_state:
    st.session_state['generated_text'] = []

user_input = st.text_area('Enter a question', value=f"What are trends in {st.session_state['thematic']['term']}?")
button_2 = st.button('Get answer')

if user_input and button_2:
    st.session_state.user_input.append(user_input)
    with st.spinner('Running LLM...'):
        st.session_state.generated_text.append(st.session_state['chain'].run(user_input))

if 'generated_text' in st.session_state and len(st.session_state['generated_text']) > 0:
    for i in range(len(st.session_state['generated_text']) - 1, -1, -1):
        message(st.session_state['user_input'][i], is_user=True, key=str(i) + '_user')
        message(st.session_state['generated_text'][i], key=str(i))
How can I stream the response of the LLM in real time (like on the console)?
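For comparison, the console streaming above comes from StreamingStdOutCallbackHandler; pointing a custom callback at a Streamlit placeholder instead pushes the same tokens into the page. A minimal sketch (the StreamHandler class and the st.empty() placeholder are illustrative, not part of the code above; OpenAI, RetrievalQA, and docsearch are assumed to be the objects already set up in the snippet):

from langchain.callbacks.base import BaseCallbackHandler

class StreamHandler(BaseCallbackHandler):          # illustrative helper, not from the post
    def __init__(self, container):
        self.container = container                 # a st.empty() placeholder
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token                          # accumulate tokens as they arrive
        self.container.markdown(self.text)          # re-render the growing answer

placeholder = st.empty()
llm = OpenAI(streaming=True, callbacks=[StreamHandler(placeholder)], temperature=0)
chain = RetrievalQA.from_chain_type(llm=llm, chain_type='refine', retriever=docsearch.as_retriever())
chain.run(user_input)                               # the placeholder fills in while the chain runs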
Take a look at the new Streamlit chat elements. They may help:
https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
Use these instead: https://docs.streamlit.io/library/api-reference/chat
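The pattern in those docs boils down to yielding text chunks into st.write_stream. A rough, self-contained sketch (the fake_token_stream generator is just a stand-in for a real LLM stream):

import time
import streamlit as st

def fake_token_stream(text):                        # stand-in for a real token stream
    for word in text.split():
        yield word + " "
        time.sleep(0.02)

if prompt := st.chat_input("Ask a question"):
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        st.write_stream(fake_token_stream("This reply streams word by word."))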
@tractorjuice Yes, I read that documentation, but I can't get it running. Streamlit raises:
StreamlitAPIException: st.write_stream expects a generator or stream-like object as input not <class 'str'>. Please use st.write instead for this data type.
My code is below; do you have any idea how to rewrite it?
def prepare_llm(prompt):
    st.chat_message(name="user", avatar=IMAGE["user"]).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    msg = []
    msg.append({"role": "assistant", "content": description})
    for x in st.session_state.messages:
        msg.append(x)

    embeddings = OpenAIEmbeddings()
    docsearch = Pinecone.from_existing_index(
        index_name="langchain-index", embedding=embeddings
    )

    with st.chat_message(name="assistant", avatar=IMAGE["assistant"]):
        llm = ChatOpenAI(
            streaming=True,
            callbacks=[StreamingStdOutCallbackHandler()],
            temperature=0,
        )
        qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=docsearch.as_retriever(),
            chain_type_kwargs={
                "prompt": prompt,
                "memory": ConversationBufferMemory(
                    memory_key="history",
                    input_key="question",
                ),
            },
        )
        response = st.write_stream(
            qa.run({"query": "What Atlas Client?"})  # qa.run() returns a completed str, which triggers the exception above
        )
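The exception happens because st.write_stream wants something it can iterate over, while qa.run() hands back the finished string. One hedged workaround (not from this thread): run the chain in a background thread, let a callback push tokens onto a queue, and give st.write_stream a generator that drains that queue. All names below are illustrative:

import queue
import threading
from langchain.callbacks.base import BaseCallbackHandler

class QueueHandler(BaseCallbackHandler):
    def __init__(self, q):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)              # hand each token to the main thread

    def on_llm_end(self, response, **kwargs) -> None:
        self.q.put(None)               # sentinel: streaming finished

def drain(q):
    while (token := q.get()) is not None:
        yield token

q = queue.Queue()
llm = ChatOpenAI(streaming=True, callbacks=[QueueHandler(q)], temperature=0)
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=docsearch.as_retriever())
threading.Thread(target=qa.run, args=("What Atlas Client?",), daemon=True).start()
with st.chat_message(name="assistant", avatar=IMAGE["assistant"]):
    response = st.write_stream(drain(q))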
@ucola I'm facing the same error. Were you able to resolve it?
This is working for me (with groq):
full_response = "" # Initialize outside the generator
def generate_responses(completion):
global full_response
for chunk in completion:
response = chunk.choices[0].delta.content or ""
if response:
full_response += response # Append to the full response
yield response
st.chat_message("assistant")
stream = generate_responses(completion)
st.write_stream(stream)
# After streaming
if full_response: # Check and use the full_response as needed
response_message = {"role": "assistant", "content": full_response}
# with st.chat_message("assistant"):
# st.markdown(full_response)
st.session_state.messages.append(response_message)
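For context, completion in the snippet above is assumed to come from a streaming chat-completions call. A hedged sketch using the Groq Python SDK's OpenAI-compatible client (the model name is illustrative):

from groq import Groq

client = Groq()  # reads GROQ_API_KEY from the environment
completion = client.chat.completions.create(
    model="llama3-8b-8192",                  # illustrative model name
    messages=st.session_state.messages,
    stream=True,                             # yields chunks with .choices[0].delta.content
)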