Allow passing a system prompt (#1318)
parent 9c192ddd73
commit 64ed9cd872

File diff suppressed because it is too large.
@@ -28,10 +28,14 @@ class ChatBody(BaseModel):
             "examples": [
                 {
                     "messages": [
+                        {
+                            "role": "system",
+                            "content": "You are a rapper. Always answer with a rap.",
+                        },
                         {
                             "role": "user",
                             "content": "How do you fry an egg?",
-                        }
+                        },
                     ],
                     "stream": False,
                     "use_context": True,
@@ -56,6 +60,9 @@ def chat_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
 
+    Optionally include an initial `role: system` message to influence the way
+    the LLM answers.
+
     If `use_context` is set to `true`, the model will use context coming
     from the ingested documents to create the response. The documents being used can
     be filtered using the `context_filter` and passing the document IDs to be used.
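For reference, a minimal client-side sketch of a request that uses the new system message (not part of this diff; the server URL, port, and the `requests` dependency are assumptions):

import requests

# Hypothetical call to the /v1/chat/completions endpoint documented above.
# Adjust the URL to wherever the PrivateGPT server is actually listening.
response = requests.post(
    "http://localhost:8001/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a rapper. Always answer with a rap."},
            {"role": "user", "content": "How do you fry an egg?"},
        ],
        "stream": False,
        "use_context": False,
    },
    timeout=60,
)
# Assuming the response follows the OpenAI chat-completion shape.
print(response.json()["choices"][0]["message"]["content"])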
@@ -79,7 +86,9 @@ def chat_completion(
     ]
     if body.stream:
         completion_gen = service.stream_chat(
-            all_messages, body.use_context, body.context_filter
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
         )
         return StreamingResponse(
             to_openai_sse_stream(
@@ -89,7 +98,11 @@ def chat_completion(
             media_type="text/event-stream",
         )
     else:
-        completion = service.chat(all_messages, body.use_context, body.context_filter)
+        completion = service.chat(
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
+        )
         return to_openai_response(
             completion.response, completion.sources if body.include_sources else None
         )
@@ -1,12 +1,13 @@
+from dataclasses import dataclass
+
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.chat_engine.types import (
     BaseChatEngine,
 )
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
-from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
-from llama_index.llms import ChatMessage
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.types import TokenGen
 from pydantic import BaseModel
 
@@ -30,6 +31,40 @@ class CompletionGen(BaseModel):
     sources: list[Chunk] | None = None
 
 
+@dataclass
+class ChatEngineInput:
+    system_message: ChatMessage | None = None
+    last_message: ChatMessage | None = None
+    chat_history: list[ChatMessage] | None = None
+
+    @classmethod
+    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+        # Detect if there is a system message, extract the last message and chat history
+        system_message = (
+            messages[0]
+            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+            else None
+        )
+        last_message = (
+            messages[-1]
+            if len(messages) > 0 and messages[-1].role == MessageRole.USER
+            else None
+        )
+        # Remove from messages list the system message and last message,
+        # if they exist. The rest is the chat history.
+        if system_message:
+            messages.pop(0)
+        if last_message:
+            messages.pop(-1)
+        chat_history = messages if len(messages) > 0 else None
+
+        return cls(
+            system_message=system_message,
+            last_message=last_message,
+            chat_history=chat_history,
+        )
+
+
 @singleton
 class ChatService:
     @inject
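To illustrate how `ChatEngineInput.from_messages` splits an incoming conversation, here is a small sketch (not part of the diff; it only assumes the `llama_index` imports already used above):

from llama_index.llms import ChatMessage, MessageRole

# System message first, chat history in the middle, user message last.
msgs = [
    ChatMessage(content="You are a rapper.", role=MessageRole.SYSTEM),
    ChatMessage(content="Hi!", role=MessageRole.USER),
    ChatMessage(content="Hello! How can I help?", role=MessageRole.ASSISTANT),
    ChatMessage(content="How do you fry an egg?", role=MessageRole.USER),
]

parsed = ChatEngineInput.from_messages(msgs)
# parsed.system_message is the SYSTEM message,
# parsed.last_message is the final USER message,
# parsed.chat_history is the two messages in between.

Note that `from_messages` mutates the list it receives (it pops the system and last messages), so pass a copy if the original list is still needed.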
@@ -58,18 +93,28 @@ class ChatService:
         )
 
     def _chat_engine(
-        self, context_filter: ContextFilter | None = None
+        self,
+        system_prompt: str | None = None,
+        use_context: bool = False,
+        context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
-        vector_index_retriever = self.vector_store_component.get_retriever(
-            index=self.index, context_filter=context_filter
-        )
-        return ContextChatEngine.from_defaults(
-            retriever=vector_index_retriever,
-            service_context=self.service_context,
-            node_postprocessors=[
-                MetadataReplacementPostProcessor(target_metadata_key="window"),
-            ],
-        )
+        if use_context:
+            vector_index_retriever = self.vector_store_component.get_retriever(
+                index=self.index, context_filter=context_filter
+            )
+            return ContextChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                retriever=vector_index_retriever,
+                service_context=self.service_context,
+                node_postprocessors=[
+                    MetadataReplacementPostProcessor(target_metadata_key="window"),
+                ],
+            )
+        else:
+            return SimpleChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                service_context=self.service_context,
+            )
 
     def stream_chat(
         self,
@@ -77,24 +122,34 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> CompletionGen:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            streaming_response = chat_engine.stream_chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [
-                Chunk.from_node(node) for node in streaming_response.source_nodes
-            ]
-            completion_gen = CompletionGen(
-                response=streaming_response.response_gen, sources=sources
-            )
-        else:
-            stream = self.llm_service.llm.stream_chat(messages)
-            completion_gen = CompletionGen(
-                response=stream_chat_response_to_tokens(stream)
-            )
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        streaming_response = chat_engine.stream_chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
+        completion_gen = CompletionGen(
+            response=streaming_response.response_gen, sources=sources
+        )
         return completion_gen
 
     def chat(
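Since `stream_chat` returns a token generator that the router wraps into an SSE stream, a rough sketch of consuming it from a client could look like the following (the URL and the exact `data: ...` chunk format follow the OpenAI streaming convention and are assumptions, not part of this diff):

import json
import requests

# Stream the answer token by token over SSE. Lines are expected to look like
# "data: {json chunk}" and end with "data: [DONE]", as in the OpenAI API.
with requests.post(
    "http://localhost:8001/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a rapper. Always answer with a rap."},
            {"role": "user", "content": "How do you fry an egg?"},
        ],
        "stream": True,
    },
    stream=True,
    timeout=60,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content") or "", end="", flush=True)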
@@ -103,18 +158,30 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> Completion:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            wrapped_response = chat_engine.chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
-            completion = Completion(response=wrapped_response.response, sources=sources)
-        else:
-            chat_response = self.llm_service.llm.chat(messages)
-            response_content = chat_response.message.content
-            response = response_content if response_content is not None else ""
-            completion = Completion(response=response)
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        wrapped_response = chat_engine.chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+        completion = Completion(response=wrapped_response.response, sources=sources)
         return completion
@@ -15,6 +15,7 @@ completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated
 
 class CompletionsBody(BaseModel):
     prompt: str
+    system_prompt: str | None = None
     use_context: bool = False
     context_filter: ContextFilter | None = None
     include_sources: bool = True
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
             "examples": [
                 {
                     "prompt": "How do you fry an egg?",
+                    "system_prompt": "You are a rapper. Always answer with a rap.",
                     "stream": False,
                     "use_context": False,
                     "include_sources": False,
@@ -46,7 +48,11 @@ def prompt_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.
 
-    Given a prompt, the model will return one predicted completion. If `use_context`
+    Given a prompt, the model will return one predicted completion.
+
+    Optionally include a `system_prompt` to influence the way the LLM answers.
+
+    If `use_context`
     is set to `true`, the model will use context coming from the ingested documents
     to create the response. The documents being used can be filtered using the
     `context_filter` and passing the document IDs to be used. Ingested documents IDs
@@ -64,9 +70,13 @@ def prompt_completion(
     "finish_reason":null}]}
     ```
     """
-    message = OpenAIMessage(content=body.prompt, role="user")
+    messages = [OpenAIMessage(content=body.prompt, role="user")]
+    # If system prompt is passed, create a fake message with the system prompt.
+    if body.system_prompt:
+        messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
+
     chat_body = ChatBody(
-        messages=[message],
+        messages=messages,
         use_context=body.use_context,
         stream=body.stream,
         include_sources=body.include_sources,
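A similar sketch for the completions endpoint, using the new `system_prompt` field that `prompt_completion` turns into a leading system message (the URL, port, and `requests` dependency are assumptions, not part of this diff):

import requests

# The server prepends system_prompt as a system message before delegating
# to the chat completion logic shown earlier.
response = requests.post(
    "http://localhost:8001/v1/completions",
    json={
        "prompt": "How do you fry an egg?",
        "system_prompt": "You are a rapper. Always answer with a rap.",
        "stream": False,
        "use_context": False,
        "include_sources": False,
    },
    timeout=60,
)
print(response.json())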
@@ -116,6 +116,17 @@ class PrivateGptUi:
         all_messages = [*build_history(), new_message]
         match mode:
             case "Query Docs":
+                # Add a system message to force the behaviour of the LLM
+                # to answer only questions about the provided context.
+                all_messages.insert(
+                    0,
+                    ChatMessage(
+                        content="You can only answer questions about the provided context. If you know the answer "
+                        "but it is not based in the provided context, don't provide the answer, just state "
+                        "the answer is not in the context provided.",
+                        role=MessageRole.SYSTEM,
+                    ),
+                )
                 query_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=True,
@@ -22,6 +22,7 @@ ui:
 
 llm:
   mode: local
+
 embedding:
   # Should be matching the value above in most cases
   mode: local