Refactor UI state management (#1191)

* Add logs at UI generation time and build the UI inside an object (see the sketch below)
* Make the ingest script more verbose when ingestion fails, displaying full stack traces
* Remove the explicit UI state that held the list of ingested files
* Change the browser tab title of the privateGPT UI to `My Private GPT`
Authored by lopagela, committed via GitHub on 2023-11-10 10:42:43 +01:00
parent 55e626eac7
commit a666fd5b73
3 changed files with 170 additions and 137 deletions
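
At its core, the refactor replaces import-time module globals with a PrivateGptUi class that builds its gr.Blocks lazily and caches the result. Below is a minimal sketch of that pattern, using the same gradio and FastAPI entry points the diff itself uses; the "/ui" mount path and the placeholder layout are illustrative stand-ins, not the real values.

    import logging

    import gradio as gr
    from fastapi import FastAPI

    logger = logging.getLogger(__name__)


    class PrivateGptUi:
        def __init__(self) -> None:
            # Cache for the built UI; populated on first use.
            self._ui_block: gr.Blocks | None = None

        def _build_ui_blocks(self) -> gr.Blocks:
            logger.debug("Creating the UI blocks")
            with gr.Blocks(title="My Private GPT") as blocks:
                gr.Markdown("placeholder for the real layout")
            return blocks

        def get_ui_blocks(self) -> gr.Blocks:
            # Build once, reuse on every subsequent call.
            if self._ui_block is None:
                self._ui_block = self._build_ui_blocks()
            return self._ui_block

        def mount_in_app(self, app: FastAPI) -> None:
            blocks = self.get_ui_blocks()
            blocks.queue()
            gr.mount_gradio_app(app, blocks, path="/ui")  # stand-in for settings.ui.path


    app = FastAPI()
    PrivateGptUi().mount_in_app(app)  # matches the new call site in the first file

Building lazily means importing the module no longer constructs the whole Gradio app as a side effect, and the cached blocks can be mounted into any FastAPI instance.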

Changed file 1 of 3 (the FastAPI app module):

@@ -1,4 +1,5 @@
 """FastAPI app creation, logger configuration and main API routes."""
+import logging
 from typing import Any

 import llama_index
@@ -14,6 +15,8 @@ from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
 from private_gpt.settings.settings import settings

+logger = logging.getLogger(__name__)
+
 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")

@@ -103,6 +106,7 @@ app.include_router(health_router)

 if settings.ui.enabled:
-    from private_gpt.ui.ui import mount_in_app
+    logger.debug("Importing the UI module")
+    from private_gpt.ui.ui import PrivateGptUi

-    mount_in_app(app)
+    PrivateGptUi().mount_in_app(app)

Changed file 2 of 3 (private_gpt/ui/ui.py):

@@ -1,4 +1,6 @@
+"""This file should be imported only and only if you want to run the UI locally."""
 import itertools
+import logging
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, TextIO
@@ -15,151 +17,176 @@ from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
 from private_gpt.ui.images import logo_svg

-ingest_service = root_injector.get(IngestService)
-chat_service = root_injector.get(ChatService)
-chunks_service = root_injector.get(ChunksService)
-
-
-def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-    def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
-        full_response: str = ""
-        for delta in stream:
-            if isinstance(delta, str):
-                full_response += str(delta)
-            elif isinstance(delta, ChatResponse):
-                full_response += delta.delta or ""
-            yield full_response
-
-    def build_history() -> list[ChatMessage]:
-        history_messages: list[ChatMessage] = list(
-            itertools.chain(
-                *[
-                    [
-                        ChatMessage(content=interaction[0], role=MessageRole.USER),
-                        ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT),
-                    ]
-                    for interaction in history
-                ]
-            )
-        )
-        # max 20 messages to try to avoid context overflow
-        return history_messages[:20]
-
-    new_message = ChatMessage(content=message, role=MessageRole.USER)
-    all_messages = [*build_history(), new_message]
-    match mode:
-        case "Query Docs":
-            query_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=True,
-            )
-            yield from yield_deltas(query_stream)
-        case "LLM Chat":
-            llm_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=False,
-            )
-            yield from yield_deltas(llm_stream)
-        case "Search in Docs":
-            response = chunks_service.retrieve_relevant(
-                text=message, limit=4, prev_next_chunks=0
-            )
-            yield "\n\n\n".join(
-                f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
-                f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
-                f"{chunk.text}"
-                for index, chunk in enumerate(response, start=1)
-            )
-
-
-def _list_ingested_files() -> list[str]:
-    files = set()
-    for ingested_document in ingest_service.list_ingested():
-        if ingested_document.doc_metadata is not None:
-            files.add(
-                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
-            )
-    return list(files)
-
-
-# Global state
-_uploaded_file_list = [[row] for row in _list_ingested_files()]
-
-
-def _upload_file(file: TextIO) -> list[list[str]]:
-    path = Path(file.name)
-    ingest_service.ingest(file_name=path.name, file_data=path)
-    _uploaded_file_list.append([path.name])
-    return _uploaded_file_list
-
-
-with gr.Blocks(
-    theme=gr.themes.Soft(primary_hue=slate),
-    css=".logo { "
-    "display:flex;"
-    "background-color: #C7BAFF;"
-    "height: 80px;"
-    "border-radius: 8px;"
-    "align-content: center;"
-    "justify-content: center;"
-    "align-items: center;"
-    "}"
-    ".logo img { height: 25% }",
-) as blocks:
-    with gr.Row():
-        gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
-    with gr.Row():
-        with gr.Column(scale=3, variant="compact"):
-            mode = gr.Radio(
-                ["Query Docs", "Search in Docs", "LLM Chat"],
-                label="Mode",
-                value="Query Docs",
-            )
-            upload_button = gr.components.UploadButton(
-                "Upload a File",
-                type="file",
-                file_count="single",
-                size="sm",
-            )
-            ingested_dataset = gr.List(
-                _uploaded_file_list,
-                headers=["File name"],
-                label="Ingested Files",
-                interactive=False,
-                render=False,  # Rendered under the button
-            )
-            upload_button.upload(
-                _upload_file, inputs=upload_button, outputs=ingested_dataset
-            )
-            ingested_dataset.render()
-        with gr.Column(scale=7):
-            chatbot = gr.ChatInterface(
-                _chat,
-                chatbot=gr.Chatbot(
-                    label=f"LLM: {settings.llm.mode}",
-                    show_copy_button=True,
-                    render=False,
-                    avatar_images=(
-                        None,
-                        "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
-                        "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
-                        "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
-                    ),
-                ),
-                additional_inputs=[mode, upload_button],
-            )
-
-
-def mount_in_app(app: FastAPI) -> None:
-    blocks.queue()
-    gr.mount_gradio_app(app, blocks, path=settings.ui.path)
+logger = logging.getLogger(__name__)
+
+UI_TAB_TITLE = "My Private GPT"
+
+
+class PrivateGptUi:
+    def __init__(self) -> None:
+        self._ingest_service = root_injector.get(IngestService)
+        self._chat_service = root_injector.get(ChatService)
+        self._chunks_service = root_injector.get(ChunksService)
+
+        # Cache the UI blocks
+        self._ui_block = None
+
+    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
+        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
+            full_response: str = ""
+            for delta in stream:
+                if isinstance(delta, str):
+                    full_response += str(delta)
+                elif isinstance(delta, ChatResponse):
+                    full_response += delta.delta or ""
+                yield full_response
+
+        def build_history() -> list[ChatMessage]:
+            history_messages: list[ChatMessage] = list(
+                itertools.chain(
+                    *[
+                        [
+                            ChatMessage(content=interaction[0], role=MessageRole.USER),
+                            ChatMessage(
+                                content=interaction[1], role=MessageRole.ASSISTANT
+                            ),
+                        ]
+                        for interaction in history
+                    ]
+                )
+            )
+            # max 20 messages to try to avoid context overflow
+            return history_messages[:20]
+
+        new_message = ChatMessage(content=message, role=MessageRole.USER)
+        all_messages = [*build_history(), new_message]
+        match mode:
+            case "Query Docs":
+                query_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=True,
+                )
+                yield from yield_deltas(query_stream)
+            case "LLM Chat":
+                llm_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=False,
+                )
+                yield from yield_deltas(llm_stream)
+            case "Search in Docs":
+                response = self._chunks_service.retrieve_relevant(
+                    text=message, limit=4, prev_next_chunks=0
+                )
+                yield "\n\n\n".join(
+                    f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
+                    f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
+                    f"{chunk.text}"
+                    for index, chunk in enumerate(response, start=1)
+                )
+
+    def _list_ingested_files(self) -> list[list[str]]:
+        files = set()
+        for ingested_document in self._ingest_service.list_ingested():
+            if ingested_document.doc_metadata is None:
+                # Skipping documents without metadata
+                continue
+            file_name = ingested_document.doc_metadata.get(
+                "file_name", "[FILE NAME MISSING]"
+            )
+            files.add(file_name)
+        return [[row] for row in files]
+
+    def _upload_file(self, file: TextIO) -> None:
+        path = Path(file.name)
+        self._ingest_service.ingest(file_name=path.name, file_data=path)
+
+    def _build_ui_blocks(self) -> gr.Blocks:
+        logger.debug("Creating the UI blocks")
+        with gr.Blocks(
+            title=UI_TAB_TITLE,
+            theme=gr.themes.Soft(primary_hue=slate),
+            css=".logo { "
+            "display:flex;"
+            "background-color: #C7BAFF;"
+            "height: 80px;"
+            "border-radius: 8px;"
+            "align-content: center;"
+            "justify-content: center;"
+            "align-items: center;"
+            "}"
+            ".logo img { height: 25% }",
+        ) as blocks:
+            with gr.Row():
+                gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
+            with gr.Row():
+                with gr.Column(scale=3, variant="compact"):
+                    mode = gr.Radio(
+                        ["Query Docs", "Search in Docs", "LLM Chat"],
+                        label="Mode",
+                        value="Query Docs",
+                    )
+                    upload_button = gr.components.UploadButton(
+                        "Upload a File",
+                        type="file",
+                        file_count="single",
+                        size="sm",
+                    )
+                    ingested_dataset = gr.List(
+                        self._list_ingested_files,
+                        headers=["File name"],
+                        label="Ingested Files",
+                        interactive=False,
+                        render=False,  # Rendered under the button
+                    )
+                    upload_button.upload(
+                        self._upload_file,
+                        inputs=upload_button,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.change(
+                        self._list_ingested_files,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.render()
+                with gr.Column(scale=7):
+                    _ = gr.ChatInterface(
+                        self._chat,
+                        chatbot=gr.Chatbot(
+                            label=f"LLM: {settings.llm.mode}",
+                            show_copy_button=True,
+                            render=False,
+                            avatar_images=(
+                                None,
+                                "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
+                                "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
+                                "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
+                            ),
+                        ),
+                        additional_inputs=[mode, upload_button],
+                    )
+        return blocks
+
+    def get_ui_blocks(self) -> gr.Blocks:
+        if self._ui_block is None:
+            self._ui_block = self._build_ui_blocks()
+        return self._ui_block
+
+    def mount_in_app(self, app: FastAPI) -> None:
+        blocks = self.get_ui_blocks()
+        blocks.queue()
+        base_path = settings.ui.path
+        logger.info("Mounting the gradio UI, at path=%s", base_path)
+        gr.mount_gradio_app(app, blocks, path=base_path)

 if __name__ == "__main__":
-    blocks.queue()
-    blocks.launch(debug=False, show_api=False)
+    ui = PrivateGptUi()
+    _blocks = ui.get_ui_blocks()
+    _blocks.queue()
+    _blocks.launch(debug=False, show_api=False)
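
The state-management change in this file is the heart of the commit: the old module kept a mutable global _uploaded_file_list that _upload_file had to append to, while the new code hands gr.List a callable, so the table is re-derived from the ingest service whenever it refreshes. A minimal sketch of that callable-backed pattern follows; the in-memory fake_store is a hypothetical stand-in for IngestService, and the Gradio wiring mirrors the diff above.

    import gradio as gr

    fake_store: set[str] = set()  # hypothetical stand-in for the ingest service


    def list_files() -> list[list[str]]:
        # Derived on demand; there is no separate UI-side copy to keep in sync.
        return [[name] for name in sorted(fake_store)]


    def upload(name: str) -> None:
        fake_store.add(name)  # the real code ingests the uploaded file here


    with gr.Blocks() as demo:
        name_box = gr.Textbox(label="File name")
        files_table = gr.List(list_files, headers=["File name"], interactive=False)
        # As in the diff: the upload handler returns nothing into the table, and
        # the table's .change event re-queries list_files to pick up new state.
        name_box.submit(upload, inputs=name_box, outputs=files_table)
        files_table.change(list_files, outputs=files_table)

    if __name__ == "__main__":
        demo.queue()
        demo.launch()

The payoff is that the upload handler no longer returns a list to keep in sync; it just ingests, and the refresh path recomputes the table from the service.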

Changed file 3 of 3 (the ingest script):

@@ -69,8 +69,10 @@ def _do_ingest(changed_path: Path) -> None:
         logger.info(f"Started ingesting {changed_path}")
         ingest_service.ingest(changed_path.name, changed_path)
         logger.info(f"Completed ingesting {changed_path}")
-    except Exception as e:
-        logger.error(f"Failed to ingest document: {changed_path}. Error: {e}")
+    except Exception:
+        logger.exception(
+            f"Failed to ingest document: {changed_path}, find the exception attached"
+        )

 path = Path(args.folder)
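
This last change is small but worth spelling out: logger.exception logs at ERROR level and automatically appends the active exception's stack trace, which is why the handler no longer needs to bind the exception with "as e" just to interpolate it. A minimal, self-contained illustration; the risky function is hypothetical.

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    def risky() -> None:
        raise ValueError("boom")


    try:
        risky()
    except Exception:
        # logger.exception is logger.error plus the current traceback appended,
        # so nothing needs to be bound with "as e" just for the message.
        logger.exception("Failed to ingest document")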