"""This file should be imported only and only if you want to run the UI locally.""" import itertools import logging from collections.abc import Iterable from pathlib import Path from typing import Any import gradio as gr # type: ignore from fastapi import FastAPI from gradio.themes.utils.colors import slate # type: ignore from injector import inject, singleton from llama_index.llms import ChatMessage, ChatResponse, MessageRole from pydantic import BaseModel from private_gpt.constants import PROJECT_ROOT_PATH from private_gpt.di import global_injector from private_gpt.server.chat.chat_service import ChatService, CompletionGen from private_gpt.server.chunks.chunks_service import Chunk, ChunksService from private_gpt.server.ingest.ingest_service import IngestService from private_gpt.settings.settings import settings from private_gpt.ui.images import logo_svg logger = logging.getLogger(__name__) THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) # Should be "private_gpt/ui/avatar-bot.ico" AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico" UI_TAB_TITLE = "My Private GPT" SOURCES_SEPARATOR = "\n\n Sources: \n" MODES = ["Query Docs", "Search in Docs", "LLM Chat"] class Source(BaseModel): file: str page: str text: str class Config: frozen = True @staticmethod def curate_sources(sources: list[Chunk]) -> set["Source"]: curated_sources = set() for chunk in sources: doc_metadata = chunk.document.doc_metadata file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-" page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-" source = Source(file=file_name, page=page_label, text=chunk.text) curated_sources.add(source) return curated_sources @singleton class PrivateGptUi: @inject def __init__( self, ingest_service: IngestService, chat_service: ChatService, chunks_service: ChunksService, ) -> None: self._ingest_service = ingest_service self._chat_service = chat_service self._chunks_service = chunks_service # Cache the UI blocks self._ui_block = None # Initialize system prompt based on default mode self.mode = MODES[0] self._system_prompt = self._get_default_system_prompt(self.mode) def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]: full_response: str = "" stream = completion_gen.response for delta in stream: if isinstance(delta, str): full_response += str(delta) elif isinstance(delta, ChatResponse): full_response += delta.delta or "" yield full_response if completion_gen.sources: full_response += SOURCES_SEPARATOR cur_sources = Source.curate_sources(completion_gen.sources) sources_text = "\n\n\n".join( f"{index}. 


@singleton
class PrivateGptUi:
    @inject
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
        chunks_service: ChunksService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service
        self._chunks_service = chunks_service

        # Cache the UI blocks
        self._ui_block = None

        # Initialize system prompt based on default mode
        self.mode = MODES[0]
        self._system_prompt = self._get_default_system_prompt(self.mode)

    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
        # Yield the full response accumulated so far on every delta, since
        # Gradio re-renders the whole chat message on each yield.
        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
            full_response: str = ""
            stream = completion_gen.response
            for delta in stream:
                if isinstance(delta, str):
                    full_response += str(delta)
                elif isinstance(delta, ChatResponse):
                    full_response += delta.delta or ""
                yield full_response

            if completion_gen.sources:
                full_response += SOURCES_SEPARATOR
                cur_sources = Source.curate_sources(completion_gen.sources)
                sources_text = "\n\n\n".join(
                    f"{index}. {source.file} (page {source.page})"
                    for index, source in enumerate(cur_sources, start=1)
                )
                full_response += sources_text
            yield full_response

        # Flatten Gradio's [user, assistant] interaction pairs into a single
        # list of ChatMessages.
        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Remove from history content the Sources information
                                content=interaction[1].split(SOURCES_SEPARATOR)[0],
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )

            # max 20 messages to try to avoid context overflow
            return history_messages[:20]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )
        match mode:
            case "Query Docs":
                query_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=True,
                )
                yield from yield_deltas(query_stream)

            case "LLM Chat":
                llm_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=False,
                )
                yield from yield_deltas(llm_stream)

            case "Search in Docs":
                response = self._chunks_service.retrieve_relevant(
                    text=message, limit=4, prev_next_chunks=0
                )
                sources = Source.curate_sources(response)
                yield "\n\n\n".join(
                    f"{index}. **{source.file} "
                    f"(page {source.page})**\n "
                    f"{source.text}"
                    for index, source in enumerate(sources, start=1)
                )

    # On initialization and on mode change, this function sets the system prompt
    # to the default prompt based on the mode (and user settings).
    @staticmethod
    def _get_default_system_prompt(mode: str) -> str:
        p = ""
        match mode:
            # For query chat mode, obtain default system prompt from settings
            case "Query Docs":
                p = settings().ui.default_query_system_prompt
            # For chat mode, obtain default system prompt from settings
            case "LLM Chat":
                p = settings().ui.default_chat_system_prompt
            # For any other mode, clear the system prompt
            case _:
                p = ""
        return p

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        # Use lazy %-formatting, consistent with the other logger calls.
        logger.info("Setting system prompt to: %s", system_prompt_input)
        self._system_prompt = system_prompt_input

    def _set_current_mode(self, mode: str) -> Any:
        self.mode = mode
        self._set_system_prompt(self._get_default_system_prompt(mode))
        # Update placeholder and allow interaction if default system prompt is set
        if self._system_prompt:
            return gr.update(placeholder=self._system_prompt, interactive=True)
        # Update placeholder and disable interaction if no default system prompt is set
        else:
            return gr.update(placeholder=self._system_prompt, interactive=False)

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skip documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
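
    # NOTE: the "#component-*" selectors in the CSS below target element ids
    # that Gradio auto-generates from the order in which components are
    # declared, so they may need updating if the layout changes.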

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "background-color: #C7BAFF;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"