Allow passing a system prompt (#1318)

Iván Martínez 2023-11-29 15:51:19 +01:00 committed by GitHub
parent 9c192ddd73
commit 64ed9cd872
6 changed files with 1129 additions and 1039 deletions

File diff suppressed because it is too large

View File

@@ -28,10 +28,14 @@ class ChatBody(BaseModel):
             "examples": [
                 {
                     "messages": [
+                        {
+                            "role": "system",
+                            "content": "You are a rapper. Always answer with a rap.",
+                        },
                         {
                             "role": "user",
                             "content": "How do you fry an egg?",
-                        }
+                        },
                     ],
                     "stream": False,
                     "use_context": True,
@@ -56,6 +60,9 @@ def chat_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
+
+    Optionally include an initial `role: system` message to influence the way
+    the LLM answers.
 
     If `use_context` is set to `true`, the model will use context coming
     from the ingested documents to create the response. The documents being used can
     be filtered using the `context_filter` and passing the document IDs to be used.
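
For reference, a minimal sketch of a request exercising the new system message (not part of this commit; it assumes a local PrivateGPT server on the default port 8001 and the OpenAI-compatible response shape used by this router):

```python
import requests

# Minimal sketch (assumption: PrivateGPT running locally on the default port 8001).
# The first message uses `role: system` to steer how the LLM answers.
response = requests.post(
    "http://localhost:8001/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a rapper. Always answer with a rap."},
            {"role": "user", "content": "How do you fry an egg?"},
        ],
        "stream": False,
        "use_context": False,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```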
@@ -79,7 +86,9 @@ def chat_completion(
     ]
     if body.stream:
         completion_gen = service.stream_chat(
-            all_messages, body.use_context, body.context_filter
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
         )
         return StreamingResponse(
             to_openai_sse_stream(
@@ -89,7 +98,11 @@ def chat_completion(
             media_type="text/event-stream",
         )
     else:
-        completion = service.chat(all_messages, body.use_context, body.context_filter)
+        completion = service.chat(
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
+        )
         return to_openai_response(
             completion.response, completion.sources if body.include_sources else None
         )

View File

@@ -1,12 +1,13 @@
+from dataclasses import dataclass
+
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.chat_engine.types import (
     BaseChatEngine,
 )
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
-from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
-from llama_index.llms import ChatMessage
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.types import TokenGen
 from pydantic import BaseModel
@@ -30,6 +31,40 @@ class CompletionGen(BaseModel):
     sources: list[Chunk] | None = None
 
 
+@dataclass
+class ChatEngineInput:
+    system_message: ChatMessage | None = None
+    last_message: ChatMessage | None = None
+    chat_history: list[ChatMessage] | None = None
+
+    @classmethod
+    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+        # Detect if there is a system message, extract the last message and chat history
+        system_message = (
+            messages[0]
+            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+            else None
+        )
+        last_message = (
+            messages[-1]
+            if len(messages) > 0 and messages[-1].role == MessageRole.USER
+            else None
+        )
+        # Remove from messages list the system message and last message,
+        # if they exist. The rest is the chat history.
+        if system_message:
+            messages.pop(0)
+        if last_message:
+            messages.pop(-1)
+        chat_history = messages if len(messages) > 0 else None
+
+        return cls(
+            system_message=system_message,
+            last_message=last_message,
+            chat_history=chat_history,
+        )
+
+
 @singleton
 class ChatService:
     @inject
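
As a usage sketch (not part of the commit), here is how `ChatEngineInput.from_messages` splits an incoming message list; the concrete messages are illustrative, and `ChatEngineInput` is assumed to be importable from the chat service module shown above:

```python
from llama_index.llms import ChatMessage, MessageRole

# `ChatEngineInput` is the dataclass added in the diff above.
# Illustrative list: a system message first, a user message last,
# and anything in between becomes the chat history.
messages = [
    ChatMessage(content="You are a rapper. Always answer with a rap.", role=MessageRole.SYSTEM),
    ChatMessage(content="Yo! Ask me anything.", role=MessageRole.ASSISTANT),
    ChatMessage(content="How do you fry an egg?", role=MessageRole.USER),
]

parts = ChatEngineInput.from_messages(messages)
# parts.system_message -> the system message (popped from the front)
# parts.last_message   -> the final user message (popped from the back)
# parts.chat_history   -> the remaining messages (here, the assistant reply)
```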
@@ -58,18 +93,28 @@ class ChatService:
         )
 
     def _chat_engine(
-        self, context_filter: ContextFilter | None = None
+        self,
+        system_prompt: str | None = None,
+        use_context: bool = False,
+        context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
-        vector_index_retriever = self.vector_store_component.get_retriever(
-            index=self.index, context_filter=context_filter
-        )
-        return ContextChatEngine.from_defaults(
-            retriever=vector_index_retriever,
-            service_context=self.service_context,
-            node_postprocessors=[
-                MetadataReplacementPostProcessor(target_metadata_key="window"),
-            ],
-        )
+        if use_context:
+            vector_index_retriever = self.vector_store_component.get_retriever(
+                index=self.index, context_filter=context_filter
+            )
+            return ContextChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                retriever=vector_index_retriever,
+                service_context=self.service_context,
+                node_postprocessors=[
+                    MetadataReplacementPostProcessor(target_metadata_key="window"),
+                ],
+            )
+        else:
+            return SimpleChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                service_context=self.service_context,
+            )
 
     def stream_chat(
         self,
@@ -77,24 +122,34 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> CompletionGen:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            streaming_response = chat_engine.stream_chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [
-                Chunk.from_node(node) for node in streaming_response.source_nodes
-            ]
-            completion_gen = CompletionGen(
-                response=streaming_response.response_gen, sources=sources
-            )
-        else:
-            stream = self.llm_service.llm.stream_chat(messages)
-            completion_gen = CompletionGen(
-                response=stream_chat_response_to_tokens(stream)
-            )
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        streaming_response = chat_engine.stream_chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
+        completion_gen = CompletionGen(
+            response=streaming_response.response_gen, sources=sources
+        )
         return completion_gen
 
     def chat(
@@ -103,18 +158,30 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> Completion:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            wrapped_response = chat_engine.chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
-            completion = Completion(response=wrapped_response.response, sources=sources)
-        else:
-            chat_response = self.llm_service.llm.chat(messages)
-            response_content = chat_response.message.content
-            response = response_content if response_content is not None else ""
-            completion = Completion(response=response)
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        wrapped_response = chat_engine.chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+        completion = Completion(response=wrapped_response.response, sources=sources)
         return completion

View File

@@ -15,6 +15,7 @@ completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated
 
 class CompletionsBody(BaseModel):
     prompt: str
+    system_prompt: str | None = None
     use_context: bool = False
     context_filter: ContextFilter | None = None
     include_sources: bool = True
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
             "examples": [
                 {
                     "prompt": "How do you fry an egg?",
+                    "system_prompt": "You are a rapper. Always answer with a rap.",
                     "stream": False,
                     "use_context": False,
                     "include_sources": False,
@@ -46,7 +48,11 @@ def prompt_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.
 
-    Given a prompt, the model will return one predicted completion. If `use_context`
+    Given a prompt, the model will return one predicted completion.
+
+    Optionally include a `system_prompt` to influence the way the LLM answers.
+
+    If `use_context`
     is set to `true`, the model will use context coming from the ingested documents
     to create the response. The documents being used can be filtered using the
     `context_filter` and passing the document IDs to be used. Ingested documents IDs
@@ -64,9 +70,13 @@ def prompt_completion(
     "finish_reason":null}]}
     ```
     """
-    message = OpenAIMessage(content=body.prompt, role="user")
+    messages = [OpenAIMessage(content=body.prompt, role="user")]
+    # If system prompt is passed, create a fake message with the system prompt.
+    if body.system_prompt:
+        messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
+
     chat_body = ChatBody(
-        messages=[message],
+        messages=messages,
         use_context=body.use_context,
         stream=body.stream,
         include_sources=body.include_sources,
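
Similarly, a minimal sketch of the new `system_prompt` field on the completions endpoint (same assumptions as before: local server on the default port 8001, OpenAI-compatible response shape):

```python
import requests

# Minimal sketch: `system_prompt` rides alongside `prompt` and is turned into a
# leading `role: system` message internally (see the router change above).
response = requests.post(
    "http://localhost:8001/v1/completions",
    json={
        "prompt": "How do you fry an egg?",
        "system_prompt": "You are a rapper. Always answer with a rap.",
        "stream": False,
        "use_context": False,
        "include_sources": False,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```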

View File

@@ -116,6 +116,17 @@ class PrivateGptUi:
         all_messages = [*build_history(), new_message]
         match mode:
             case "Query Docs":
+                # Add a system message to force the behaviour of the LLM
+                # to answer only questions about the provided context.
+                all_messages.insert(
+                    0,
+                    ChatMessage(
+                        content="You can only answer questions about the provided context. If you know the answer "
+                        "but it is not based in the provided context, don't provide the answer, just state "
+                        "the answer is not in the context provided.",
+                        role=MessageRole.SYSTEM,
+                    ),
+                )
                 query_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=True,

View File

@@ -22,6 +22,7 @@ ui:
 
 llm:
   mode: local
 
 embedding:
+  # Should be matching the value above in most cases
   mode: local