"""This file should be imported only and only if you want to run the UI locally.""" import itertools import logging from collections.abc import Iterable from pathlib import Path from typing import Any, TextIO import gradio as gr # type: ignore from fastapi import FastAPI from gradio.themes.utils.colors import slate # type: ignore from llama_index.llms import ChatMessage, ChatResponse, MessageRole from private_gpt.di import root_injector from private_gpt.server.chat.chat_service import ChatService, CompletionGen from private_gpt.server.chunks.chunks_service import ChunksService from private_gpt.server.ingest.ingest_service import IngestService from private_gpt.settings.settings import settings from private_gpt.ui.images import logo_svg logger = logging.getLogger(__name__) UI_TAB_TITLE = "My Private GPT" class PrivateGptUi: def __init__(self) -> None: self._ingest_service = root_injector.get(IngestService) self._chat_service = root_injector.get(ChatService) self._chunks_service = root_injector.get(ChunksService) # Cache the UI blocks self._ui_block = None def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]: full_response: str = "" stream = completion_gen.response for delta in stream: if isinstance(delta, str): full_response += str(delta) elif isinstance(delta, ChatResponse): full_response += delta.delta or "" yield full_response if completion_gen.sources: full_response += "\n\n Sources: \n" sources = ( { "file": chunk.document.doc_metadata["file_name"] if chunk.document.doc_metadata else "", "page": chunk.document.doc_metadata["page_label"] if chunk.document.doc_metadata else "", } for chunk in completion_gen.sources ) sources_text = "\n\n\n".join( f"{index}. {source['file']} (page {source['page']})" for index, source in enumerate(sources, start=1) ) full_response += sources_text yield full_response def build_history() -> list[ChatMessage]: history_messages: list[ChatMessage] = list( itertools.chain( *[ [ ChatMessage(content=interaction[0], role=MessageRole.USER), ChatMessage( content=interaction[1], role=MessageRole.ASSISTANT ), ] for interaction in history ] ) ) # max 20 messages to try to avoid context overflow return history_messages[:20] new_message = ChatMessage(content=message, role=MessageRole.USER) all_messages = [*build_history(), new_message] match mode: case "Query Docs": query_stream = self._chat_service.stream_chat( messages=all_messages, use_context=True, ) yield from yield_deltas(query_stream) case "LLM Chat": llm_stream = self._chat_service.stream_chat( messages=all_messages, use_context=False, ) yield from yield_deltas(llm_stream) case "Search in Docs": response = self._chunks_service.retrieve_relevant( text=message, limit=4, prev_next_chunks=0 ) yield "\n\n\n".join( f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} " f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n " f"{chunk.text}" for index, chunk in enumerate(response, start=1) ) def _list_ingested_files(self) -> list[list[str]]: files = set() for ingested_document in self._ingest_service.list_ingested(): if ingested_document.doc_metadata is None: # Skipping documents without metadata continue file_name = ingested_document.doc_metadata.get( "file_name", "[FILE NAME MISSING]" ) files.add(file_name) return [[row] for row in files] def _upload_file(self, file: TextIO) -> None: path = Path(file.name) self._ingest_service.ingest(file_name=path.name, file_data=path) def _build_ui_blocks(self) -> gr.Blocks: logger.debug("Creating the UI blocks") with gr.Blocks( title=UI_TAB_TITLE, theme=gr.themes.Soft(primary_hue=slate), css=".logo { " "display:flex;" "background-color: #C7BAFF;" "height: 80px;" "border-radius: 8px;" "align-content: center;" "justify-content: center;" "align-items: center;" "}" ".logo img { height: 25% }", ) as blocks: with gr.Row(): gr.HTML(f"