Add sources to completions APIs and UI (#1206)
parent dbd99e7b4b
commit a22969ad1f

File diff suppressed because one or more lines are too long
@@ -5,6 +5,8 @@ from collections.abc import Iterator
from llama_index.llms import ChatResponse, CompletionResponse
from pydantic import BaseModel, Field

from private_gpt.server.chunks.chunks_service import Chunk


class OpenAIDelta(BaseModel):
    """A piece of completion that needs to be concatenated to get the full message."""
@@ -27,11 +29,13 @@ class OpenAIChoice(BaseModel):
    """Response from AI.

    Either the delta or the message will be present, but never both.
    Sources used will be returned in case context retrieval was enabled.
    """

    finish_reason: str | None = Field(examples=["stop"])
    delta: OpenAIDelta | None = None
    message: OpenAIMessage | None = None
    sources: list[Chunk] | None = None
    index: int = 0


@@ -49,7 +53,10 @@ class OpenAICompletion(BaseModel):

    @classmethod
    def from_text(
        cls, text: str | None, finish_reason: str | None = None
        cls,
        text: str | None,
        finish_reason: str | None = None,
        sources: list[Chunk] | None = None,
    ) -> "OpenAICompletion":
        return OpenAICompletion(
            id=str(uuid.uuid4()),
@@ -60,13 +67,18 @@ class OpenAICompletion(BaseModel):
                OpenAIChoice(
                    message=OpenAIMessage(role="assistant", content=text),
                    finish_reason=finish_reason,
                    sources=sources,
                )
            ],
        )

    @classmethod
    def json_from_delta(
        cls, *, text: str | None, finish_reason: str | None = None
        cls,
        *,
        text: str | None,
        finish_reason: str | None = None,
        sources: list[Chunk] | None = None,
    ) -> str:
        chunk = OpenAICompletion(
            id=str(uuid.uuid4()),
@@ -77,6 +89,7 @@ class OpenAICompletion(BaseModel):
                OpenAIChoice(
                    delta=OpenAIDelta(content=text),
                    finish_reason=finish_reason,
                    sources=sources,
                )
            ],
        )
@@ -84,20 +97,25 @@ class OpenAICompletion(BaseModel):
        return chunk.model_dump_json()


def to_openai_response(response: str | ChatResponse) -> OpenAICompletion:
def to_openai_response(
    response: str | ChatResponse, sources: list[Chunk] | None = None
) -> OpenAICompletion:
    if isinstance(response, ChatResponse):
        return OpenAICompletion.from_text(response.delta, finish_reason="stop")
    else:
        return OpenAICompletion.from_text(response, finish_reason="stop")
        return OpenAICompletion.from_text(
            response, finish_reason="stop", sources=sources
        )


def to_openai_sse_stream(
    response_generator: Iterator[str | CompletionResponse | ChatResponse],
    sources: list[Chunk] | None = None,
) -> Iterator[str]:
    for response in response_generator:
        if isinstance(response, CompletionResponse | ChatResponse):
            yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
        else:
            yield f"data: {OpenAICompletion.json_from_delta(text=response)}\n\n"
            yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
    yield f"data: {OpenAICompletion.json_from_delta(text=None, finish_reason='stop')}\n\n"
    yield "data: [DONE]\n\n"
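For orientation only, a hedged sketch of how the extended models above compose; the import path is an assumption (file names are not shown in this diff view) and the argument values are made up:

# Illustrative sketch, not part of the commit. Module path is assumed.
from private_gpt.open_ai.openai_models import OpenAICompletion, to_openai_sse_stream

# `sources` would normally be `[Chunk.from_node(n) for n in source_nodes]`,
# as ChatService does further down; left as None here for brevity.
completion = OpenAICompletion.from_text(
    "Outbound sales increased 20%.",
    finish_reason="stop",
    sources=None,
)
print(completion.choices[0].sources)  # None, or a list[Chunk] when retrieval ran

# The streaming helper now attaches the same sources to every SSE data chunk:
for event in to_openai_sse_stream(iter(["Outbound ", "sales..."]), sources=None):
    print(event, end="")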
@@ -20,6 +20,7 @@ class ChatBody(BaseModel):
    messages: list[OpenAIMessage]
    use_context: bool = False
    context_filter: ContextFilter | None = None
    include_sources: bool = True
    stream: bool = False

    model_config = {
@@ -34,6 +35,7 @@
                    ],
                    "stream": False,
                    "use_context": True,
                    "include_sources": True,
                    "context_filter": {
                        "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
                    },
@@ -58,6 +60,9 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
    Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
    all ingested documents to be used, remove `context_filter` altogether.

    When using `'include_sources': true`, the API will return the source Chunks used
    to create the response, which come from the context provided.

    When using `'stream': true`, the API will return data chunks following [OpenAI's
    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
@@ -71,12 +76,18 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
        ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
    ]
    if body.stream:
        stream = service.stream_chat(
        completion_gen = service.stream_chat(
            all_messages, body.use_context, body.context_filter
        )
        return StreamingResponse(
            to_openai_sse_stream(stream), media_type="text/event-stream"
            to_openai_sse_stream(
                completion_gen.response,
                completion_gen.sources if body.include_sources else None,
            ),
            media_type="text/event-stream",
        )
    else:
        response = service.chat(all_messages, body.use_context, body.context_filter)
        return to_openai_response(response)
        completion = service.chat(all_messages, body.use_context, body.context_filter)
        return to_openai_response(
            completion.response, completion.sources if body.include_sources else None
        )
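A hedged client-side sketch of the new `include_sources` flag; the host, port and request path are assumptions (typical local defaults), not taken from this diff, and only the body fields come from `ChatBody` above:

# Hypothetical client call, endpoint address assumed.
import requests

resp = requests.post(
    "http://localhost:8001/v1/chat/completions",  # assumed local endpoint
    json={
        "messages": [{"role": "user", "content": "How were sales last year?"}],
        "use_context": True,
        "include_sources": True,
        "stream": False,
    },
    timeout=60,
)
choice = resp.json()["choices"][0]
print(choice["message"]["content"])
for source in choice.get("sources") or []:
    print(source["document"]["doc_metadata"])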
@@ -1,13 +1,14 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from injector import inject, singleton
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.chat_engine import ContextChatEngine
from llama_index.chat_engine.types import (
    BaseChatEngine,
)
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
from llama_index.llms import ChatMessage
from llama_index.types import TokenGen
from pydantic import BaseModel

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
@@ -16,12 +17,17 @@ from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk

if TYPE_CHECKING:
    from llama_index.chat_engine.types import (
        AgentChatResponse,
        StreamingAgentChatResponse,
    )

class Completion(BaseModel):
    response: str
    sources: list[Chunk] | None = None


class CompletionGen(BaseModel):
    response: TokenGen
    sources: list[Chunk] | None = None


@singleton
@@ -51,66 +57,64 @@ class ChatService:
            show_progress=True,
        )

    def _chat_with_contex(
        self,
        message: str,
        context_filter: ContextFilter | None = None,
        chat_history: Sequence[ChatMessage] | None = None,
        streaming: bool = False,
    ) -> Any:
    def _chat_engine(
        self, context_filter: ContextFilter | None = None
    ) -> BaseChatEngine:
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=self.index, context_filter=context_filter
        )
        chat_engine = ContextChatEngine.from_defaults(
        return ContextChatEngine.from_defaults(
            retriever=vector_index_retriever,
            service_context=self.service_context,
            node_postprocessors=[
                MetadataReplacementPostProcessor(target_metadata_key="window"),
            ],
        )
        if streaming:
            result = chat_engine.stream_chat(message, chat_history)
        else:
            result = chat_engine.chat(message, chat_history)
        return result

    def stream_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> TokenGen:
    ) -> CompletionGen:
        if use_context:
            last_message = messages[-1].content
            response: StreamingAgentChatResponse = self._chat_with_contex(
            chat_engine = self._chat_engine(context_filter=context_filter)
            streaming_response = chat_engine.stream_chat(
                message=last_message if last_message is not None else "",
                chat_history=messages[:-1],
                context_filter=context_filter,
                streaming=True,
            )
            response_gen = response.response_gen
            sources = [
                Chunk.from_node(node) for node in streaming_response.source_nodes
            ]
            completion_gen = CompletionGen(
                response=streaming_response.response_gen, sources=sources
            )
        else:
            stream = self.llm_service.llm.stream_chat(messages)
            response_gen = stream_chat_response_to_tokens(stream)
        return response_gen
            completion_gen = CompletionGen(
                response=stream_chat_response_to_tokens(stream)
            )
        return completion_gen

    def chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> str:
    ) -> Completion:
        if use_context:
            last_message = messages[-1].content
            wrapped_response: AgentChatResponse = self._chat_with_contex(
            chat_engine = self._chat_engine(context_filter=context_filter)
            wrapped_response = chat_engine.chat(
                message=last_message if last_message is not None else "",
                chat_history=messages[:-1],
                context_filter=context_filter,
                streaming=False,
            )
            response = wrapped_response.response
            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
            completion = Completion(response=wrapped_response.response, sources=sources)
        else:
            chat_response = self.llm_service.llm.chat(messages)
            response_content = chat_response.message.content
            response = response_content if response_content is not None else ""
        return response
            completion = Completion(response=response)
        return completion
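A minimal consumer sketch for the new `Completion` and `CompletionGen` return types; the dependency-injection setup is borrowed from the UI diff below and assumes a running instance with ingested documents:

# Hypothetical caller of the refactored service; setup is an assumption.
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService

service = root_injector.get(ChatService)
messages = [ChatMessage(content="Summarize the sales report", role=MessageRole.USER)]

# Non-streaming: Completion bundles the answer text with its source chunks.
completion = service.chat(messages, use_context=True)
print(completion.response)
for chunk in completion.sources or []:
    print(chunk.score, chunk.document.doc_id)

# Streaming: CompletionGen exposes the token generator plus the same sources.
completion_gen = service.stream_chat(messages, use_context=True)
for token in completion_gen.response:
    print(token, end="")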
@@ -24,15 +24,31 @@ class Chunk(BaseModel):
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    previous_texts: list[str] | None = Field(
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]]
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )


@@ -98,22 +114,11 @@ class ChunksService:

        retrieved_nodes = []
        for node in nodes:
            doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
            retrieved_nodes.append(
                Chunk(
                    object="context.chunk",
                    score=node.score or 0.0,
                    document=IngestedDoc(
                        object="ingest.document",
                        doc_id=doc_id,
                        doc_metadata=node.metadata,
                    ),
                    text=node.get_content(),
                    previous_texts=self._get_sibling_nodes_text(
            chunk = Chunk.from_node(node)
            chunk.previous_texts = self._get_sibling_nodes_text(
                node, prev_next_chunks, False
                    ),
                    next_texts=self._get_sibling_nodes_text(node, prev_next_chunks),
                )
            )
            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
            retrieved_nodes.append(chunk)

        return retrieved_nodes
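A small isolation sketch of the extracted `Chunk.from_node` helper; the `llama_index.schema` import path and the fabricated node values are assumptions for illustration:

# Hypothetical, self-contained use of the new helper.
from llama_index.schema import NodeWithScore, TextNode

from private_gpt.server.chunks.chunks_service import Chunk

node = NodeWithScore(
    node=TextNode(
        text="Outbound sales increased 20%.",
        metadata={"file_name": "sales_report.pdf"},
    ),
    score=0.87,
)
chunk = Chunk.from_node(node)
print(chunk.document.doc_id)  # "-" because this fabricated node has no ref_doc_id
print(chunk.score, chunk.text)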
@@ -16,6 +16,7 @@ class CompletionsBody(BaseModel):
    prompt: str
    use_context: bool = False
    context_filter: ContextFilter | None = None
    include_sources: bool = True
    stream: bool = False

    model_config = {
@@ -25,6 +26,7 @@
                    "prompt": "How do you fry an egg?",
                    "stream": False,
                    "use_context": False,
                    "include_sources": False,
                }
            ]
        }
@@ -48,6 +50,9 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
    can be found using `/ingest/list` endpoint. If you want all ingested documents to
    be used, remove `context_filter` altogether.

    When using `'include_sources': true`, the API will return the source Chunks used
    to create the response, which come from the context provided.

    When using `'stream': true`, the API will return data chunks following [OpenAI's
    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
@@ -61,6 +66,7 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
        messages=[message],
        use_context=body.use_context,
        stream=body.stream,
        include_sources=body.include_sources,
        context_filter=body.context_filter,
    )
    return chat_completion(chat_body)
@@ -11,7 +11,7 @@ from gradio.themes.utils.colors import slate  # type: ignore
from llama_index.llms import ChatMessage, ChatResponse, MessageRole

from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
@@ -33,8 +33,9 @@ class PrivateGptUi:
        self._ui_block = None

    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
            full_response: str = ""
            stream = completion_gen.response
            for delta in stream:
                if isinstance(delta, str):
                    full_response += str(delta)
@@ -42,6 +43,26 @@ class PrivateGptUi:
                    full_response += delta.delta or ""
                yield full_response

            if completion_gen.sources:
                full_response += "\n\n Sources: \n"
                sources = (
                    {
                        "file": chunk.document.doc_metadata["file_name"]
                        if chunk.document.doc_metadata
                        else "",
                        "page": chunk.document.doc_metadata["page_label"]
                        if chunk.document.doc_metadata
                        else "",
                    }
                    for chunk in completion_gen.sources
                )
                sources_text = "\n\n\n".join(
                    f"{index}. {source['file']} (page {source['page']})"
                    for index, source in enumerate(sources, start=1)
                )
                full_response += sources_text
            yield full_response

        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
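For reference, a rough standalone sketch of the footer the UI change above appends to the chat response; the file names and page labels are made up:

# Standalone sketch of the sources footer built in yield_deltas above.
# The chunk data here is invented purely for illustration.
sources = [
    {"file": "sales_report_2023.pdf", "page": "3"},
    {"file": "marketing_summary.pdf", "page": "12"},
]
footer = "\n\n Sources: \n" + "\n\n\n".join(
    f"{index}. {source['file']} (page {source['page']})"
    for index, source in enumerate(sources, start=1)
)
print(footer)
# 1. sales_report_2023.pdf (page 3)
#
#
# 2. marketing_summary.pdf (page 12)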