Refactor UI state management (#1191)
* Add logs at UI generation time, and generate the UI inside an object
* Make the ingest script more verbose when ingestion fails, displaying full stack traces
* Remove the explicit UI state that tracked ingested files
* Change the browser tab title of the privateGPT UI to `My Private GPT`
parent 55e626eac7
commit a666fd5b73
The first hunks touch the FastAPI app module ("FastAPI app creation, logger configuration and main API routes"), which now logs the UI import and mounts the UI through the new object:

@@ -1,4 +1,5 @@
 """FastAPI app creation, logger configuration and main API routes."""
+import logging
 from typing import Any
 
 import llama_index
@@ -14,6 +15,8 @@ from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
 from private_gpt.settings.settings import settings
 
+logger = logging.getLogger(__name__)
+
 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")
 
@@ -103,6 +106,7 @@ app.include_router(health_router)
 
 
 if settings.ui.enabled:
-    from private_gpt.ui.ui import mount_in_app
+    logger.debug("Importing the UI module")
+    from private_gpt.ui.ui import PrivateGptUi
 
-    mount_in_app(app)
+    PrivateGptUi().mount_in_app(app)
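For readers unfamiliar with how the mounted UI ends up served by FastAPI, here is a minimal, self-contained sketch of the same mounting pattern. The `/ui` path, module name, and page content are illustrative assumptions, not taken from the repo:

```python
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

# A trivial Blocks app standing in for the real PrivateGPT UI.
with gr.Blocks() as blocks:
    gr.Markdown("Hello from the mounted UI")  # placeholder content

blocks.queue()  # enable the event queue used for streamed responses
gr.mount_gradio_app(app, blocks, path="/ui")  # serve the UI under /ui

# Run with, e.g.: uvicorn my_module:app  (module name is hypothetical)
```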
The bulk of the change is in `private_gpt/ui/ui.py` (the module the new import above points at), where the module-level services, global state, and Gradio layout are folded into a `PrivateGptUi` class:

@@ -1,4 +1,6 @@
+"""This file should be imported only and only if you want to run the UI locally."""
 import itertools
+import logging
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, TextIO
@@ -15,151 +17,176 @@ from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
 from private_gpt.ui.images import logo_svg
 
-ingest_service = root_injector.get(IngestService)
-chat_service = root_injector.get(ChatService)
-chunks_service = root_injector.get(ChunksService)
-
-
-def _chat(message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-    def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
-        full_response: str = ""
-        for delta in stream:
-            if isinstance(delta, str):
-                full_response += str(delta)
-            elif isinstance(delta, ChatResponse):
-                full_response += delta.delta or ""
-            yield full_response
-
-    def build_history() -> list[ChatMessage]:
-        history_messages: list[ChatMessage] = list(
-            itertools.chain(
-                *[
-                    [
-                        ChatMessage(content=interaction[0], role=MessageRole.USER),
-                        ChatMessage(content=interaction[1], role=MessageRole.ASSISTANT),
-                    ]
-                    for interaction in history
-                ]
-            )
-        )
-
-        # max 20 messages to try to avoid context overflow
-        return history_messages[:20]
-
-    new_message = ChatMessage(content=message, role=MessageRole.USER)
-    all_messages = [*build_history(), new_message]
-    match mode:
-        case "Query Docs":
-            query_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=True,
-            )
-            yield from yield_deltas(query_stream)
-
-        case "LLM Chat":
-            llm_stream = chat_service.stream_chat(
-                messages=all_messages,
-                use_context=False,
-            )
-            yield from yield_deltas(llm_stream)
-
-        case "Search in Docs":
-            response = chunks_service.retrieve_relevant(
-                text=message, limit=4, prev_next_chunks=0
-            )
-
-            yield "\n\n\n".join(
-                f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
-                f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
-                f"{chunk.text}"
-                for index, chunk in enumerate(response, start=1)
-            )
-
-
-def _list_ingested_files() -> list[str]:
-    files = set()
-    for ingested_document in ingest_service.list_ingested():
-        if ingested_document.doc_metadata is not None:
-            files.add(
-                ingested_document.doc_metadata.get("file_name") or "[FILE NAME MISSING]"
-            )
-    return list(files)
-
-
-# Global state
-_uploaded_file_list = [[row] for row in _list_ingested_files()]
-
-
-def _upload_file(file: TextIO) -> list[list[str]]:
-    path = Path(file.name)
-    ingest_service.ingest(file_name=path.name, file_data=path)
-    _uploaded_file_list.append([path.name])
-    return _uploaded_file_list
-
-
-with gr.Blocks(
-    theme=gr.themes.Soft(primary_hue=slate),
-    css=".logo { "
-    "display:flex;"
-    "background-color: #C7BAFF;"
-    "height: 80px;"
-    "border-radius: 8px;"
-    "align-content: center;"
-    "justify-content: center;"
-    "align-items: center;"
-    "}"
-    ".logo img { height: 25% }",
-) as blocks:
-    with gr.Row():
-        gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
-
-    with gr.Row():
-        with gr.Column(scale=3, variant="compact"):
-            mode = gr.Radio(
-                ["Query Docs", "Search in Docs", "LLM Chat"],
-                label="Mode",
-                value="Query Docs",
-            )
-            upload_button = gr.components.UploadButton(
-                "Upload a File",
-                type="file",
-                file_count="single",
-                size="sm",
-            )
-            ingested_dataset = gr.List(
-                _uploaded_file_list,
-                headers=["File name"],
-                label="Ingested Files",
-                interactive=False,
-                render=False,  # Rendered under the button
-            )
-            upload_button.upload(
-                _upload_file, inputs=upload_button, outputs=ingested_dataset
-            )
-            ingested_dataset.render()
-        with gr.Column(scale=7):
-            chatbot = gr.ChatInterface(
-                _chat,
-                chatbot=gr.Chatbot(
-                    label=f"LLM: {settings.llm.mode}",
-                    show_copy_button=True,
-                    render=False,
-                    avatar_images=(
-                        None,
-                        "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
-                        "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
-                        "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
-                    ),
-                ),
-                additional_inputs=[mode, upload_button],
-            )
-
-
-def mount_in_app(app: FastAPI) -> None:
-    blocks.queue()
-    gr.mount_gradio_app(app, blocks, path=settings.ui.path)
+logger = logging.getLogger(__name__)
+
+UI_TAB_TITLE = "My Private GPT"
+
+
+class PrivateGptUi:
+    def __init__(self) -> None:
+        self._ingest_service = root_injector.get(IngestService)
+        self._chat_service = root_injector.get(ChatService)
+        self._chunks_service = root_injector.get(ChunksService)
+
+        # Cache the UI blocks
+        self._ui_block = None
+
+    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
+        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
+            full_response: str = ""
+            for delta in stream:
+                if isinstance(delta, str):
+                    full_response += str(delta)
+                elif isinstance(delta, ChatResponse):
+                    full_response += delta.delta or ""
+                yield full_response
+
+        def build_history() -> list[ChatMessage]:
+            history_messages: list[ChatMessage] = list(
+                itertools.chain(
+                    *[
+                        [
+                            ChatMessage(content=interaction[0], role=MessageRole.USER),
+                            ChatMessage(
+                                content=interaction[1], role=MessageRole.ASSISTANT
+                            ),
+                        ]
+                        for interaction in history
+                    ]
+                )
+            )
+
+            # max 20 messages to try to avoid context overflow
+            return history_messages[:20]
+
+        new_message = ChatMessage(content=message, role=MessageRole.USER)
+        all_messages = [*build_history(), new_message]
+        match mode:
+            case "Query Docs":
+                query_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=True,
+                )
+                yield from yield_deltas(query_stream)
+
+            case "LLM Chat":
+                llm_stream = self._chat_service.stream_chat(
+                    messages=all_messages,
+                    use_context=False,
+                )
+                yield from yield_deltas(llm_stream)
+
+            case "Search in Docs":
+                response = self._chunks_service.retrieve_relevant(
+                    text=message, limit=4, prev_next_chunks=0
+                )
+
+                yield "\n\n\n".join(
+                    f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
+                    f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
+                    f"{chunk.text}"
+                    for index, chunk in enumerate(response, start=1)
+                )
+
+    def _list_ingested_files(self) -> list[list[str]]:
+        files = set()
+        for ingested_document in self._ingest_service.list_ingested():
+            if ingested_document.doc_metadata is None:
+                # Skipping documents without metadata
+                continue
+            file_name = ingested_document.doc_metadata.get(
+                "file_name", "[FILE NAME MISSING]"
+            )
+            files.add(file_name)
+        return [[row] for row in files]
+
+    def _upload_file(self, file: TextIO) -> None:
+        path = Path(file.name)
+        self._ingest_service.ingest(file_name=path.name, file_data=path)
+
+    def _build_ui_blocks(self) -> gr.Blocks:
+        logger.debug("Creating the UI blocks")
+        with gr.Blocks(
+            title=UI_TAB_TITLE,
+            theme=gr.themes.Soft(primary_hue=slate),
+            css=".logo { "
+            "display:flex;"
+            "background-color: #C7BAFF;"
+            "height: 80px;"
+            "border-radius: 8px;"
+            "align-content: center;"
+            "justify-content: center;"
+            "align-items: center;"
+            "}"
+            ".logo img { height: 25% }",
+        ) as blocks:
+            with gr.Row():
+                gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
+
+            with gr.Row():
+                with gr.Column(scale=3, variant="compact"):
+                    mode = gr.Radio(
+                        ["Query Docs", "Search in Docs", "LLM Chat"],
+                        label="Mode",
+                        value="Query Docs",
+                    )
+                    upload_button = gr.components.UploadButton(
+                        "Upload a File",
+                        type="file",
+                        file_count="single",
+                        size="sm",
+                    )
+                    ingested_dataset = gr.List(
+                        self._list_ingested_files,
+                        headers=["File name"],
+                        label="Ingested Files",
+                        interactive=False,
+                        render=False,  # Rendered under the button
+                    )
+                    upload_button.upload(
+                        self._upload_file,
+                        inputs=upload_button,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.change(
+                        self._list_ingested_files,
+                        outputs=ingested_dataset,
+                    )
+                    ingested_dataset.render()
+                with gr.Column(scale=7):
+                    _ = gr.ChatInterface(
+                        self._chat,
+                        chatbot=gr.Chatbot(
+                            label=f"LLM: {settings.llm.mode}",
+                            show_copy_button=True,
+                            render=False,
+                            avatar_images=(
+                                None,
+                                "https://lh3.googleusercontent.com/drive-viewer/AK7aPa"
+                                "AicXck0k68nsscyfKrb18o9ak3BSaWM_Qzm338cKoQlw72Bp0UKN84"
+                                "IFZjXjZApY01mtnUXDeL4qzwhkALoe_53AhwCg=s2560",
+                            ),
+                        ),
+                        additional_inputs=[mode, upload_button],
+                    )
+        return blocks
+
+    def get_ui_blocks(self) -> gr.Blocks:
+        if self._ui_block is None:
+            self._ui_block = self._build_ui_blocks()
+        return self._ui_block
+
+    def mount_in_app(self, app: FastAPI) -> None:
+        blocks = self.get_ui_blocks()
+        blocks.queue()
+        base_path = settings.ui.path
+        logger.info("Mounting the gradio UI, at path=%s", base_path)
+        gr.mount_gradio_app(app, blocks, path=base_path)
 
 
 if __name__ == "__main__":
-    blocks.queue()
-    blocks.launch(debug=False, show_api=False)
+    ui = PrivateGptUi()
+    _blocks = ui.get_ui_blocks()
+    _blocks.queue()
+    _blocks.launch(debug=False, show_api=False)
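Note how the deleted `_uploaded_file_list` global is replaced by handing `gr.List` the `self._list_ingested_files` method itself: Gradio re-invokes a callable component value when the app loads, so the list is recomputed from the ingest service instead of being tracked by hand (and `get_ui_blocks` memoizes the built layout so it is only constructed once). Below is a minimal sketch of that callable-value pattern, with a stub whose contents are purely illustrative, standing in for the real ingest service:

```python
import gradio as gr

def list_ingested_files() -> list[list[str]]:
    # Stub: the real method queries the ingest service for document
    # metadata and returns one row per file name.
    return [["example.txt"], ["notes.pdf"]]

with gr.Blocks() as demo:
    # Passing the function (not its result) lets Gradio call it again
    # on load, so no module-level cache of rows is needed.
    gr.List(list_ingested_files, headers=["File name"])

demo.launch()
```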
Finally, the ingestion script's error handling is made more verbose by logging the full stack trace on failure:

@@ -69,8 +69,10 @@ def _do_ingest(changed_path: Path) -> None:
             logger.info(f"Started ingesting {changed_path}")
             ingest_service.ingest(changed_path.name, changed_path)
             logger.info(f"Completed ingesting {changed_path}")
-    except Exception as e:
-        logger.error(f"Failed to ingest document: {changed_path}. Error: {e}")
+    except Exception:
+        logger.exception(
+            f"Failed to ingest document: {changed_path}, find the exception attached"
+        )
 
 
 path = Path(args.folder)
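The switch from `logger.error` to `logger.exception` is what delivers the stack traces: called inside an `except` block, `logger.exception` logs at ERROR level and appends the active exception's traceback. A minimal sketch of the difference, where the failing `ingest` helper is hypothetical:

```python
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def ingest(path: Path) -> None:
    # Hypothetical stand-in for the real ingestion call.
    raise ValueError(f"cannot parse {path}")

def do_ingest(changed_path: Path) -> None:
    try:
        logger.info(f"Started ingesting {changed_path}")
        ingest(changed_path)
        logger.info(f"Completed ingesting {changed_path}")
    except Exception:
        # Unlike logger.error(f"... {e}"), which embeds only the
        # exception message, logger.exception also emits the traceback.
        logger.exception(f"Failed to ingest document: {changed_path}")

do_ingest(Path("example.txt"))
```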