Refactor UI state management (#1191)

* Add logging around UI generation, and build the UI inside an object
* Remove the explicit UI state that held the list of ingested files
* Make the ingestion script more verbose on errors by displaying stack traces
* Change the browser tab title of the privateGPT UI to `My Private GPT`
parent 55e626eac7
commit a666fd5b73
--- a/private_gpt/main.py
+++ b/private_gpt/main.py
@@ -1,4 +1,5 @@
 """FastAPI app creation, logger configuration and main API routes."""
+import logging
 from typing import Any
 
 import llama_index
@@ -14,6 +15,8 @@ from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
 from private_gpt.settings.settings import settings
 
+logger = logging.getLogger(__name__)
+
 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")
 
@@ -103,6 +106,7 @@ app.include_router(health_router)
 
 
 if settings.ui.enabled:
-    from private_gpt.ui.ui import mount_in_app
+    logger.debug("Importing the UI module")
+    from private_gpt.ui.ui import PrivateGptUi
 
-    mount_in_app(app)
+    PrivateGptUi().mount_in_app(app)
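For context, this is roughly how a Gradio `Blocks` app ends up served under a FastAPI route; a minimal standalone sketch (the `demo` name and `/ui` path are illustrative — privateGPT takes the path from `settings.ui.path`):

```python
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

with gr.Blocks() as demo:
    gr.Markdown("Hello from the mounted UI")

demo.queue()  # enable the event queue; required for generator (streaming) callbacks
gr.mount_gradio_app(app, demo, path="/ui")
# Serve with e.g.: uvicorn my_module:app
```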
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -1,4 +1,6 @@
 """This file should be imported only and only if you want to run the UI locally."""
+import itertools
+import logging
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, TextIO
@@ -15,12 +17,22 @@ from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
 from private_gpt.ui.images import logo_svg
 
-ingest_service = root_injector.get(IngestService)
-chat_service = root_injector.get(ChatService)
-chunks_service = root_injector.get(ChunksService)
+logger = logging.getLogger(__name__)
 
 
-def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
+UI_TAB_TITLE = "My Private GPT"
+
+
+class PrivateGptUi:
+    def __init__(self) -> None:
+        self._ingest_service = root_injector.get(IngestService)
+        self._chat_service = root_injector.get(ChatService)
+        self._chunks_service = root_injector.get(ChunksService)
+
+        # Cache the UI blocks
+        self._ui_block = None
+
+    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
         def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
             full_response: str = ""
             for delta in stream:
@@ -36,7 +48,9 @@ def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
                     *[
                         [
                             ChatMessage(content=interaction[0], role=MessageRole.USER),
-                        ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT),
+                            ChatMessage(
+                                content=interaction[1], role=MessageRole.ASSISTANT
+                            ),
                         ]
                         for interaction in history
                     ]
@@ -50,21 +64,21 @@ def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
         all_messages = [*build_history(), new_message]
         match mode:
             case "Query Docs":
-            query_stream = chat_service.stream_chat(
+                query_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=True,
                 )
                 yield from yield_deltas(query_stream)
 
             case "LLM Chat":
-            llm_stream = chat_service.stream_chat(
+                llm_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=False,
                 )
                 yield from yield_deltas(llm_stream)
 
             case "Search in Docs":
-            response = chunks_service.retrieve_relevant(
+                response = self._chunks_service.retrieve_relevant(
                     text=message, limit=4, prev_next_chunks=0
                 )
 
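The `_chat` method streams responses through `yield_deltas`; note that a Gradio chat callback must yield the accumulated message so far, not each raw delta, since the whole assistant message is re-rendered on every yield. A stripped-down sketch of that accumulation (assuming plain string deltas):

```python
from collections.abc import Iterable


def yield_deltas(stream: Iterable[str]) -> Iterable[str]:
    # Gradio re-renders the whole assistant message on each yield,
    # so emit the running concatenation, not the individual delta.
    full_response = ""
    for delta in stream:
        full_response += delta
        yield full_response
```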
@@ -75,29 +89,26 @@ def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
                     for index, chunk in enumerate(response, start=1)
                 )
 
-
-def _list_ingested_files() -> list[str]:
+    def _list_ingested_files(self) -> list[list[str]]:
         files = set()
-    for ingested_document in ingest_service.list_ingested():
-        if ingested_document.doc_metadata is not None:
-            files.add(
-                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
+        for ingested_document in self._ingest_service.list_ingested():
+            if ingested_document.doc_metadata is None:
+                # Skipping documents without metadata
+                continue
+            file_name = ingested_document.doc_metadata.get(
+                "file_name", "[FILE NAME MISSING]"
             )
-    return list(files)
-
-
-# Global state
-_uploaded_file_list = [[row] for row in _list_ingested_files()]
-
-
-def _upload_file(file: TextIO) -> list[list[str]]:
+            files.add(file_name)
+        return [[row] for row in files]
+
+    def _upload_file(self, file: TextIO) -> None:
         path = Path(file.name)
-    ingest_service.ingest(file_name=path.name, file_data=path)
-    _uploaded_file_list.append([path.name])
-    return _uploaded_file_list
+        self._ingest_service.ingest(file_name=path.name, file_data=path)
 
-
-with gr.Blocks(
+    def _build_ui_blocks(self) -> gr.Blocks:
+        logger.debug("Creating the UI blocks")
+        with gr.Blocks(
+            title=UI_TAB_TITLE,
             theme=gr.themes.Soft(primary_hue=slate),
             css=".logo { "
             "display:flex;"
@@ -109,7 +120,7 @@ with gr.Blocks(
             "align-items: center;"
             "}"
             ".logo img { height: 25% }",
-) as blocks:
+        ) as blocks:
             with gr.Row():
                 gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
 
@@ -127,19 +138,25 @@ with gr.Blocks(
                         size="sm",
                     )
                     ingested_dataset = gr.List(
-                _uploaded_file_list,
+                        self._list_ingested_files,
                         headers=["File name"],
                         label="Ingested Files",
                         interactive=False,
                         render=False,  # Rendered under the button
                     )
                     upload_button.upload(
-                _upload_file, inputs=upload_button, outputs=ingested_dataset
+                        self._upload_file,
+                        inputs=upload_button,
+                        outputs=ingested_dataset,
                     )
+                    ingested_dataset.change(
+                        self._list_ingested_files,
+                        outputs=ingested_dataset,
+                    )
                     ingested_dataset.render()
                 with gr.Column(scale=7):
-            chatbot = gr.ChatInterface(
-                _chat,
+                    _ = gr.ChatInterface(
+                        self._chat,
                         chatbot=gr.Chatbot(
                             label=f"LLM: {settings.llm.mode}",
                             show_copy_button=True,
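Two Gradio details carry the state refactor in the hunk above: passing a function (rather than a precomputed list) as a component's value makes Gradio call it again on each page load, and the `change` listener re-queries the ingested files whenever the list's value updates — so the old `_uploaded_file_list` global becomes unnecessary. A reduced sketch of the same wiring (`list_files` is a hypothetical stand-in for `self._list_ingested_files`):

```python
import gradio as gr


def list_files() -> list[list[str]]:
    # Hypothetical stand-in: one row (list of cells) per ingested file.
    return [["doc1.pdf"], ["notes.txt"]]


with gr.Blocks() as blocks:
    # A callable value is re-evaluated on every page load,
    # so no module-level cache of the rows is needed.
    ingested = gr.List(list_files, headers=["File name"], interactive=False)
```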
@@ -153,13 +170,23 @@ with gr.Blocks(
                         ),
                         additional_inputs=[mode, upload_button],
                     )
+        return blocks
 
+    def get_ui_blocks(self) -> gr.Blocks:
+        if self._ui_block is None:
+            self._ui_block = self._build_ui_blocks()
+        return self._ui_block
 
-def mount_in_app(app: FastAPI) -> None:
+    def mount_in_app(self, app: FastAPI) -> None:
+        blocks = self.get_ui_blocks()
         blocks.queue()
-    gr.mount_gradio_app(app, blocks, path=settings.ui.path)
+        base_path = settings.ui.path
+        logger.info("Mounting the gradio UI, at path=%s", base_path)
+        gr.mount_gradio_app(app, blocks, path=base_path)
 
 
 if __name__ == "__main__":
-    blocks.queue()
-    blocks.launch(debug=False, show_api=False)
+    ui = PrivateGptUi()
+    _blocks = ui.get_ui_blocks()
+    _blocks.queue()
+    _blocks.launch(debug=False, show_api=False)
--- a/scripts/ingest_folder.py
+++ b/scripts/ingest_folder.py
@@ -69,8 +69,10 @@ def _do_ingest(changed_path: Path) -> None:
             logger.info(f"Started ingesting {changed_path}")
             ingest_service.ingest(changed_path.name, changed_path)
             logger.info(f"Completed ingesting {changed_path}")
-    except Exception as e:
-        logger.error(f"Failed to ingest document: {changed_path}. Error: {e}")
+    except Exception:
+        logger.exception(
+            f"Failed to ingest document: {changed_path}, find the exception attached"
+        )
 
 
 path = Path(args.folder)
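The ingest-script change swaps `logger.error`, which records only the exception's message, for `logger.exception`, which appends the full traceback to the log record (that is the "stack traces" bullet in the commit message). A small self-contained demonstration of the difference:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")


def risky() -> None:
    raise ValueError("boom")


try:
    risky()
except Exception as e:
    logger.error("Failed to ingest: %s", e)  # message only, no traceback

try:
    risky()
except Exception:
    logger.exception("Failed to ingest")  # message plus the full stack trace
```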