Add sources to completions APIs and UI (#1206)
parent dbd99e7b4b
commit a22969ad1f

File diff suppressed because one or more lines are too long
private_gpt/open_ai/openai_models.py:

```diff
@@ -5,6 +5,8 @@ from collections.abc import Iterator
 from llama_index.llms import ChatResponse, CompletionResponse
 from pydantic import BaseModel, Field
 
+from private_gpt.server.chunks.chunks_service import Chunk
+
 
 class OpenAIDelta(BaseModel):
     """A piece of completion that needs to be concatenated to get the full message."""
@@ -27,11 +29,13 @@ class OpenAIChoice(BaseModel):
     """Response from AI.
 
     Either the delta or the message will be present, but never both.
+    Sources used will be returned in case context retrieval was enabled.
     """
 
     finish_reason: str | None = Field(examples=["stop"])
     delta: OpenAIDelta | None = None
     message: OpenAIMessage | None = None
+    sources: list[Chunk] | None = None
     index: int = 0
 
 
@@ -49,7 +53,10 @@ class OpenAICompletion(BaseModel):
 
     @classmethod
     def from_text(
-        cls, text: str | None, finish_reason: str | None = None
+        cls,
+        text: str | None,
+        finish_reason: str | None = None,
+        sources: list[Chunk] | None = None,
     ) -> "OpenAICompletion":
         return OpenAICompletion(
             id=str(uuid.uuid4()),
@@ -60,13 +67,18 @@ class OpenAICompletion(BaseModel):
                 OpenAIChoice(
                     message=OpenAIMessage(role="assistant", content=text),
                     finish_reason=finish_reason,
+                    sources=sources,
                 )
             ],
         )
 
     @classmethod
     def json_from_delta(
-        cls, *, text: str | None, finish_reason: str | None = None
+        cls,
+        *,
+        text: str | None,
+        finish_reason: str | None = None,
+        sources: list[Chunk] | None = None,
     ) -> str:
         chunk = OpenAICompletion(
             id=str(uuid.uuid4()),
@@ -77,6 +89,7 @@ class OpenAICompletion(BaseModel):
                 OpenAIChoice(
                     delta=OpenAIDelta(content=text),
                     finish_reason=finish_reason,
+                    sources=sources,
                 )
             ],
         )
@@ -84,20 +97,25 @@ class OpenAICompletion(BaseModel):
         return chunk.model_dump_json()
 
 
-def to_openai_response(response: str | ChatResponse) -> OpenAICompletion:
+def to_openai_response(
+    response: str | ChatResponse, sources: list[Chunk] | None = None
+) -> OpenAICompletion:
     if isinstance(response, ChatResponse):
         return OpenAICompletion.from_text(response.delta, finish_reason="stop")
     else:
-        return OpenAICompletion.from_text(response, finish_reason="stop")
+        return OpenAICompletion.from_text(
+            response, finish_reason="stop", sources=sources
+        )
 
 
 def to_openai_sse_stream(
     response_generator: Iterator[str | CompletionResponse | ChatResponse],
+    sources: list[Chunk] | None = None,
 ) -> Iterator[str]:
     for response in response_generator:
         if isinstance(response, CompletionResponse | ChatResponse):
             yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
         else:
-            yield f"data: {OpenAICompletion.json_from_delta(text=response)}\n\n"
+            yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
     yield f"data: {OpenAICompletion.json_from_delta(text=None, finish_reason='stop')}\n\n"
     yield "data: [DONE]\n\n"
```
private_gpt/server/chat/chat_router.py:

```diff
@@ -20,6 +20,7 @@ class ChatBody(BaseModel):
     messages: list[OpenAIMessage]
     use_context: bool = False
     context_filter: ContextFilter | None = None
+    include_sources: bool = True
     stream: bool = False
 
     model_config = {
@@ -34,6 +35,7 @@ class ChatBody(BaseModel):
                     ],
                     "stream": False,
                     "use_context": True,
+                    "include_sources": True,
                     "context_filter": {
                         "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
                     },
@@ -58,6 +60,9 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
     Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
     all ingested documents to be used, remove `context_filter` altogether.
 
+    When using `'include_sources': true`, the API will return the source Chunks used
+    to create the response, which come from the context provided.
+
     When using `'stream': true`, the API will return data chunks following [OpenAI's
     streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
     ```
@@ -71,12 +76,18 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
     if body.stream:
-        stream = service.stream_chat(
+        completion_gen = service.stream_chat(
             all_messages, body.use_context, body.context_filter
         )
         return StreamingResponse(
-            to_openai_sse_stream(stream), media_type="text/event-stream"
+            to_openai_sse_stream(
+                completion_gen.response,
+                completion_gen.sources if body.include_sources else None,
+            ),
+            media_type="text/event-stream",
         )
     else:
-        response = service.chat(all_messages, body.use_context, body.context_filter)
-        return to_openai_response(response)
+        completion = service.chat(all_messages, body.use_context, body.context_filter)
+        return to_openai_response(
+            completion.response, completion.sources if body.include_sources else None
+        )
```
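To exercise the new flag end to end, a request like the following should work (a sketch: the local address, port 8001, and the `/v1` prefix are assumptions about the default server setup):

```python
import json
import urllib.request

body = {
    "messages": [{"role": "user", "content": "What did the sales report say?"}],
    "use_context": True,
    "include_sources": True,  # new flag; defaults to True
    "stream": False,
}
request = urllib.request.Request(
    "http://localhost:8001/v1/chat/completions",  # assumed default address
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    completion = json.load(response)

# Each returned chunk carries the originating document and its metadata.
for chunk in completion["choices"][0].get("sources") or []:
    print(chunk["document"]["doc_metadata"], chunk["text"][:60])
```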
private_gpt/server/chat/chat_service.py:

```diff
@@ -1,13 +1,14 @@
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
-
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
 from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine.types import (
+    BaseChatEngine,
+)
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
 from llama_index.llms import ChatMessage
 from llama_index.types import TokenGen
+from pydantic import BaseModel
 
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
 from private_gpt.components.llm.llm_component import LLMComponent
@@ -16,12 +17,17 @@ from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
+from private_gpt.server.chunks.chunks_service import Chunk
 
-if TYPE_CHECKING:
-    from llama_index.chat_engine.types import (
-        AgentChatResponse,
-        StreamingAgentChatResponse,
-    )
+
+class Completion(BaseModel):
+    response: str
+    sources: list[Chunk] | None = None
+
+
+class CompletionGen(BaseModel):
+    response: TokenGen
+    sources: list[Chunk] | None = None
 
 
 @singleton
@@ -51,66 +57,64 @@ class ChatService:
             show_progress=True,
         )
 
-    def _chat_with_contex(
-        self,
-        message: str,
-        context_filter: ContextFilter | None = None,
-        chat_history: Sequence[ChatMessage] | None = None,
-        streaming: bool = False,
-    ) -> Any:
+    def _chat_engine(
+        self, context_filter: ContextFilter | None = None
+    ) -> BaseChatEngine:
         vector_index_retriever = self.vector_store_component.get_retriever(
             index=self.index, context_filter=context_filter
         )
-        chat_engine = ContextChatEngine.from_defaults(
+        return ContextChatEngine.from_defaults(
             retriever=vector_index_retriever,
             service_context=self.service_context,
             node_postprocessors=[
                 MetadataReplacementPostProcessor(target_metadata_key="window"),
             ],
         )
-        if streaming:
-            result = chat_engine.stream_chat(message, chat_history)
-        else:
-            result = chat_engine.chat(message, chat_history)
-        return result
 
     def stream_chat(
         self,
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-    ) -> TokenGen:
+    ) -> CompletionGen:
         if use_context:
             last_message = messages[-1].content
-            response: StreamingAgentChatResponse = self._chat_with_contex(
+            chat_engine = self._chat_engine(context_filter=context_filter)
+            streaming_response = chat_engine.stream_chat(
                 message=last_message if last_message is not None else "",
                 chat_history=messages[:-1],
-                context_filter=context_filter,
-                streaming=True,
             )
-            response_gen = response.response_gen
+            sources = [
+                Chunk.from_node(node) for node in streaming_response.source_nodes
+            ]
+            completion_gen = CompletionGen(
+                response=streaming_response.response_gen, sources=sources
+            )
         else:
             stream = self.llm_service.llm.stream_chat(messages)
-            response_gen = stream_chat_response_to_tokens(stream)
-        return response_gen
+            completion_gen = CompletionGen(
+                response=stream_chat_response_to_tokens(stream)
+            )
+        return completion_gen
 
     def chat(
         self,
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-    ) -> str:
+    ) -> Completion:
         if use_context:
             last_message = messages[-1].content
-            wrapped_response: AgentChatResponse = self._chat_with_contex(
+            chat_engine = self._chat_engine(context_filter=context_filter)
+            wrapped_response = chat_engine.chat(
                 message=last_message if last_message is not None else "",
                 chat_history=messages[:-1],
-                context_filter=context_filter,
-                streaming=False,
            )
-            response = wrapped_response.response
+            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+            completion = Completion(response=wrapped_response.response, sources=sources)
         else:
             chat_response = self.llm_service.llm.chat(messages)
             response_content = chat_response.message.content
             response = response_content if response_content is not None else ""
-        return response
+            completion = Completion(response=response)
+        return completion
```
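The service now returns `Completion` / `CompletionGen` wrappers instead of a bare string or token generator, so callers get the text and its sources from one object. A hedged usage sketch (the message content is invented; the injector pattern mirrors the UI code below):

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService

service = root_injector.get(ChatService)
completion = service.chat(
    messages=[ChatMessage(content="Summarize the sales report", role=MessageRole.USER)],
    use_context=True,
)
print(completion.response)
for chunk in completion.sources or []:  # sources stay None when use_context=False
    print(chunk.document.doc_id, chunk.text[:80])
```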
private_gpt/server/chunks/chunks_service.py:

```diff
@@ -24,17 +24,33 @@ class Chunk(BaseModel):
     document: IngestedDoc
     text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
     previous_texts: list[str] | None = Field(
-        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]]
+        default=None,
+        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
     )
     next_texts: list[str] | None = Field(
+        default=None,
         examples=[
             [
                 "New leads came from Google Ads campaign.",
                 "The campaign was run by the Marketing Department",
             ]
-        ]
+        ],
     )
 
+    @classmethod
+    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
+        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
+        return cls(
+            object="context.chunk",
+            score=node.score or 0.0,
+            document=IngestedDoc(
+                object="ingest.document",
+                doc_id=doc_id,
+                doc_metadata=node.metadata,
+            ),
+            text=node.get_content(),
+        )
+
 
 @singleton
 class ChunksService:
@@ -98,22 +114,11 @@ class ChunksService:
 
         retrieved_nodes = []
         for node in nodes:
-            doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
-            retrieved_nodes.append(
-                Chunk(
-                    object="context.chunk",
-                    score=node.score or 0.0,
-                    document=IngestedDoc(
-                        object="ingest.document",
-                        doc_id=doc_id,
-                        doc_metadata=node.metadata,
-                    ),
-                    text=node.get_content(),
-                    previous_texts=self._get_sibling_nodes_text(
-                        node, prev_next_chunks, False
-                    ),
-                    next_texts=self._get_sibling_nodes_text(node, prev_next_chunks),
-                )
-            )
+            chunk = Chunk.from_node(node)
+            chunk.previous_texts = self._get_sibling_nodes_text(
+                node, prev_next_chunks, False
+            )
+            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
+            retrieved_nodes.append(chunk)
 
         return retrieved_nodes
```
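`Chunk.from_node` centralizes the node-to-chunk mapping that `retrieve_relevant` previously inlined, which is what lets the chat service reuse it for sources. A small sketch with a synthetic node (real nodes come from the retriever; the `llama_index.schema` import path is an assumption about the llama_index version in use):

```python
from llama_index.schema import NodeWithScore, TextNode

from private_gpt.server.chunks.chunks_service import Chunk

node = NodeWithScore(
    node=TextNode(text="Outbound sales increased 20%, driven by new leads."),
    score=0.87,
)
chunk = Chunk.from_node(node)
print(chunk.score)            # 0.87
print(chunk.document.doc_id)  # "-" since a bare TextNode has no ref_doc_id
```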
private_gpt/server/completions/completions_router.py:

```diff
@@ -16,6 +16,7 @@ class CompletionsBody(BaseModel):
     prompt: str
     use_context: bool = False
     context_filter: ContextFilter | None = None
+    include_sources: bool = True
     stream: bool = False
 
     model_config = {
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
                     "prompt": "How do you fry an egg?",
                     "stream": False,
                     "use_context": False,
+                    "include_sources": False,
                 }
             ]
         }
@@ -48,6 +50,9 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
     can be found using `/ingest/list` endpoint. If you want all ingested documents to
     be used, remove `context_filter` altogether.
 
+    When using `'include_sources': true`, the API will return the source Chunks used
+    to create the response, which come from the context provided.
+
     When using `'stream': true`, the API will return data chunks following [OpenAI's
     streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
     ```
@@ -61,6 +66,7 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
         messages=[message],
         use_context=body.use_context,
         stream=body.stream,
+        include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
     return chat_completion(chat_body)
```
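Since `prompt_completion` just wraps the prompt into a `ChatBody`, `include_sources` behaves identically on this endpoint. A streaming sketch (the address and `/v1/completions` path are again assumptions about the default setup):

```python
import json
import urllib.request

body = {
    "prompt": "What did the sales report say?",
    "use_context": True,
    "include_sources": True,
    "stream": True,
}
request = urllib.request.Request(
    "http://localhost:8001/v1/completions",  # assumed default address
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    for raw_line in response:
        line = raw_line.decode().strip()
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        choice = json.loads(line[len("data: "):])["choices"][0]
        if choice.get("sources"):  # sources ride along on each data chunk
            print([chunk["text"][:40] for chunk in choice["sources"]])
```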
private_gpt/ui/ui.py:

```diff
@@ -11,7 +11,7 @@ from gradio.themes.utils.colors import slate  # type: ignore
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole
 
 from private_gpt.di import root_injector
-from private_gpt.server.chat.chat_service import ChatService
+from private_gpt.server.chat.chat_service import ChatService, CompletionGen
 from private_gpt.server.chunks.chunks_service import ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
@@ -33,8 +33,9 @@ class PrivateGptUi:
         self._ui_block = None
 
     def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
+        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
             full_response: str = ""
+            stream = completion_gen.response
             for delta in stream:
                 if isinstance(delta, str):
                     full_response += str(delta)
@@ -42,6 +43,26 @@ class PrivateGptUi:
                     full_response += delta.delta or ""
                 yield full_response
 
+            if completion_gen.sources:
+                full_response += "\n\n Sources: \n"
+                sources = (
+                    {
+                        "file": chunk.document.doc_metadata["file_name"]
+                        if chunk.document.doc_metadata
+                        else "",
+                        "page": chunk.document.doc_metadata["page_label"]
+                        if chunk.document.doc_metadata
+                        else "",
+                    }
+                    for chunk in completion_gen.sources
+                )
+                sources_text = "\n\n\n".join(
+                    f"{index}. {source['file']} (page {source['page']})"
+                    for index, source in enumerate(sources, start=1)
+                )
+                full_response += sources_text
+            yield full_response
+
         def build_history() -> list[ChatMessage]:
             history_messages: list[ChatMessage] = list(
                 itertools.chain(
```
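Once the stream is exhausted, the UI appends a numbered source list under the answer. Isolated, the formatting logic above produces output like this (file names and pages invented):

```python
sources = [
    {"file": "sales_report_2023.pdf", "page": "3"},
    {"file": "sales_report_2023.pdf", "page": "7"},
]
sources_text = "\n\n\n".join(
    f"{index}. {source['file']} (page {source['page']})"
    for index, source in enumerate(sources, start=1)
)
print("Sources: \n" + sources_text)
# Sources:
# 1. sales_report_2023.pdf (page 3)
#
#
# 2. sales_report_2023.pdf (page 7)
```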