Curate sources to avoid the UI crashing (#1212)
* Curate sources to avoid the UI crashing
* Remove sources from chat history to avoid confusing the LLM
This commit is contained in:
		
							parent
							
								
									a579c9bdc5
								
							
						
					
					
						commit
						b7647542f4
					
				|  | @ -9,18 +9,43 @@ import gradio as gr  # type: ignore | ||||||
| from fastapi import FastAPI | from fastapi import FastAPI | ||||||
| from gradio.themes.utils.colors import slate  # type: ignore | from gradio.themes.utils.colors import slate  # type: ignore | ||||||
| from llama_index.llms import ChatMessage, ChatResponse, MessageRole | from llama_index.llms import ChatMessage, ChatResponse, MessageRole | ||||||
|  | from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from private_gpt.di import root_injector | from private_gpt.di import root_injector | ||||||
| from private_gpt.server.chat.chat_service import ChatService, CompletionGen | from private_gpt.server.chat.chat_service import ChatService, CompletionGen | ||||||
| from private_gpt.server.chunks.chunks_service import ChunksService | from private_gpt.server.chunks.chunks_service import Chunk, ChunksService | ||||||
| from private_gpt.server.ingest.ingest_service import IngestService | from private_gpt.server.ingest.ingest_service import IngestService | ||||||
| from private_gpt.settings.settings import settings | from private_gpt.settings.settings import settings | ||||||
| from private_gpt.ui.images import logo_svg | from private_gpt.ui.images import logo_svg | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| UI_TAB_TITLE = "My Private GPT" | UI_TAB_TITLE = "My Private GPT" | ||||||
|  | SOURCES_SEPARATOR = "\n\n Sources: \n" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class Source(BaseModel):
    """Display-ready description of a retrieved chunk: file, page and text."""

    file: str
    page: str
    text: str

    class Config:
        # Instances are stored in a set by curate_sources, so they must be
        # hashable; a frozen pydantic model provides that (and immutability).
        frozen = True

    @staticmethod
    def curate_sources(sources: list[Chunk]) -> set["Source"]:
        """Convert retrieved chunks into a de-duplicated set of sources.

        For every chunk, the file name and page label are read from the
        document metadata. When the metadata is missing or empty, or a key
        is absent, the placeholder "-" is used instead, so formatting the
        result never fails on a missing key.

        Args:
            sources: Chunks whose documents may or may not carry metadata.

        Returns:
            A set of Source objects. Chunks with the same file, page and
            text collapse into a single entry; being a set, the result has
            no guaranteed iteration order.
        """

        def describe(item):
            # An absent or empty metadata mapping behaves like one with no
            # keys: both lookups fall back to the "-" placeholder.
            meta = item.document.doc_metadata or {}
            return Source(
                file=meta.get("file_name", "-"),
                page=meta.get("page_label", "-"),
                text=item.text,
            )

        return {describe(item) for item in sources}
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class PrivateGptUi: | class PrivateGptUi: | ||||||
|  | @ -44,21 +69,11 @@ class PrivateGptUi: | ||||||
|                 yield full_response |                 yield full_response | ||||||
| 
 | 
 | ||||||
|             if completion_gen.sources: |             if completion_gen.sources: | ||||||
|                 full_response += "\n\n Sources: \n" |                 full_response += SOURCES_SEPARATOR | ||||||
|                 sources = ( |                 cur_sources = Source.curate_sources(completion_gen.sources) | ||||||
|                     { |  | ||||||
|                         "file": chunk.document.doc_metadata["file_name"] |  | ||||||
|                         if chunk.document.doc_metadata |  | ||||||
|                         else "", |  | ||||||
|                         "page": chunk.document.doc_metadata["page_label"] |  | ||||||
|                         if chunk.document.doc_metadata |  | ||||||
|                         else "", |  | ||||||
|                     } |  | ||||||
|                     for chunk in completion_gen.sources |  | ||||||
|                 ) |  | ||||||
|                 sources_text = "\n\n\n".join( |                 sources_text = "\n\n\n".join( | ||||||
|                     f"{index}. {source['file']} (page {source['page']})" |                     f"{index}. {source.file} (page {source.page})" | ||||||
|                     for index, source in enumerate(sources, start=1) |                     for index, source in enumerate(cur_sources, start=1) | ||||||
|                 ) |                 ) | ||||||
|                 full_response += sources_text |                 full_response += sources_text | ||||||
|             yield full_response |             yield full_response | ||||||
|  | @ -70,7 +85,9 @@ class PrivateGptUi: | ||||||
|                         [ |                         [ | ||||||
|                             ChatMessage(content=interaction[0], role=MessageRole.USER), |                             ChatMessage(content=interaction[0], role=MessageRole.USER), | ||||||
|                             ChatMessage( |                             ChatMessage( | ||||||
|                                 content=interaction[1], role=MessageRole.ASSISTANT |                                 # Remove from history content the Sources information | ||||||
|  |                                 content=interaction[1].split(SOURCES_SEPARATOR)[0], | ||||||
|  |                                 role=MessageRole.ASSISTANT, | ||||||
|                             ), |                             ), | ||||||
|                         ] |                         ] | ||||||
|                         for interaction in history |                         for interaction in history | ||||||
|  | @ -103,11 +120,13 @@ class PrivateGptUi: | ||||||
|                     text=message, limit=4, prev_next_chunks=0 |                     text=message, limit=4, prev_next_chunks=0 | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|  |                 sources = Source.curate_sources(response) | ||||||
|  | 
 | ||||||
|                 yield "\n\n\n".join( |                 yield "\n\n\n".join( | ||||||
|                     f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} " |                     f"{index}. **{source.file} " | ||||||
|                     f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n " |                     f"(page {source.page})**\n " | ||||||
|                     f"{chunk.text}" |                     f"{source.text}" | ||||||
|                     for index, chunk in enumerate(response, start=1) |                     for index, source in enumerate(sources, start=1) | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|     def _list_ingested_files(self) -> list[list[str]]: |     def _list_ingested_files(self) -> list[list[str]]: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue