Allow passing a system prompt (#1318)

Iván Martínez 2023-11-29 15:51:19 +01:00 committed by GitHub
parent 9c192ddd73
commit 64ed9cd872
6 changed files with 1129 additions and 1039 deletions
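Note: in client terms, this change adds an optional `system_prompt` field to the completions request body and documents a leading `role: system` message for the chat endpoint. A minimal sketch of exercising the new field over HTTP, assuming a locally running PrivateGPT instance and the `requests` library; the base URL and port are assumptions, not part of this commit:

    import requests

    BASE_URL = "http://localhost:8001"  # assumed local deployment; adjust as needed

    resp = requests.post(
        f"{BASE_URL}/v1/completions",
        json={
            "prompt": "How do you fry an egg?",
            "system_prompt": "You are a rapper. Always answer with a rap.",
            "stream": False,
            "use_context": False,
            "include_sources": False,
        },
    )
    print(resp.json())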

View File

@@ -23,7 +23,7 @@
          "Contextual Completions"
        ],
        "summary": "Completion",
-       "description": "We recommend most users use our Chat completions API.\n\nGiven a prompt, the model will return one predicted completion. If `use_context`\nis set to `true`, the model will use context coming from the ingested documents\nto create the response. The documents being used can be filtered using the\n`context_filter` and passing the document IDs to be used. Ingested documents IDs\ncan be found using `/ingest/list` endpoint. If you want all ingested documents to\nbe used, remove `context_filter` altogether.\n\nWhen using `'include_sources': true`, the API will return the source Chunks used\nto create the response, which come from the context provided.\n\nWhen using `'stream': true`, the API will return data chunks following [OpenAI's\nstreaming model](https://platform.openai.com/docs/api-reference/chat/streaming):\n```\n{\"id\":\"12345\",\"object\":\"completion.chunk\",\"created\":1694268190,\n\"model\":\"private-gpt\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\n\"finish_reason\":null}]}\n```",
+       "description": "We recommend most users use our Chat completions API.\n\nGiven a prompt, the model will return one predicted completion.\n\nOptionally include a `system_prompt` to influence the way the LLM answers.\n\nIf `use_context`\nis set to `true`, the model will use context coming from the ingested documents\nto create the response. The documents being used can be filtered using the\n`context_filter` and passing the document IDs to be used. Ingested documents IDs\ncan be found using `/ingest/list` endpoint. If you want all ingested documents to\nbe used, remove `context_filter` altogether.\n\nWhen using `'include_sources': true`, the API will return the source Chunks used\nto create the response, which come from the context provided.\n\nWhen using `'stream': true`, the API will return data chunks following [OpenAI's\nstreaming model](https://platform.openai.com/docs/api-reference/chat/streaming):\n```\n{\"id\":\"12345\",\"object\":\"completion.chunk\",\"created\":1694268190,\n\"model\":\"private-gpt\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\n\"finish_reason\":null}]}\n```",
        "operationId": "prompt_completion_v1_completions_post",
        "requestBody": {
          "content": {
@@ -65,7 +65,7 @@
          "Contextual Completions"
        ],
        "summary": "Chat Completion",
-       "description": "Given a list of messages comprising a conversation, return a response.\n\nIf `use_context` is set to `true`, the model will use context coming\nfrom the ingested documents to create the response. The documents being used can\nbe filtered using the `context_filter` and passing the document IDs to be used.\nIngested documents IDs can be found using `/ingest/list` endpoint. If you want\nall ingested documents to be used, remove `context_filter` altogether.\n\nWhen using `'include_sources': true`, the API will return the source Chunks used\nto create the response, which come from the context provided.\n\nWhen using `'stream': true`, the API will return data chunks following [OpenAI's\nstreaming model](https://platform.openai.com/docs/api-reference/chat/streaming):\n```\n{\"id\":\"12345\",\"object\":\"completion.chunk\",\"created\":1694268190,\n\"model\":\"private-gpt\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\n\"finish_reason\":null}]}\n```",
+       "description": "Given a list of messages comprising a conversation, return a response.\n\nOptionally include a `system_prompt` to influence the way the LLM answers.\n\nIf `use_context` is set to `true`, the model will use context coming\nfrom the ingested documents to create the response. The documents being used can\nbe filtered using the `context_filter` and passing the document IDs to be used.\nIngested documents IDs can be found using `/ingest/list` endpoint. If you want\nall ingested documents to be used, remove `context_filter` altogether.\n\nWhen using `'include_sources': true`, the API will return the source Chunks used\nto create the response, which come from the context provided.\n\nWhen using `'stream': true`, the API will return data chunks following [OpenAI's\nstreaming model](https://platform.openai.com/docs/api-reference/chat/streaming):\n```\n{\"id\":\"12345\",\"object\":\"completion.chunk\",\"created\":1694268190,\n\"model\":\"private-gpt\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\n\"finish_reason\":null}]}\n```",
        "operationId": "chat_completion_v1_chat_completions_post",
        "requestBody": {
          "content": {
@@ -338,6 +338,17 @@
            "type": "array",
            "title": "Messages"
          },
+         "system_prompt": {
+           "anyOf": [
+             {
+               "type": "string"
+             },
+             {
+               "type": "null"
+             }
+           ],
+           "title": "System Prompt"
+         },
          "use_context": {
            "type": "boolean",
            "title": "Use Context",
@@ -384,6 +395,7 @@
              }
            ],
            "stream": false,
+           "system_prompt": "You are a rapper. Always answer with a rap.",
            "use_context": true
          }
        ]
@@ -391,10 +403,7 @@
      "Chunk": {
        "properties": {
          "object": {
-           "type": "string",
-           "enum": [
-             "context.chunk"
-           ],
+           "const": "context.chunk",
            "title": "Object"
          },
          "score": {
@@ -506,17 +515,11 @@
      "ChunksResponse": {
        "properties": {
          "object": {
-           "type": "string",
-           "enum": [
-             "list"
-           ],
+           "const": "list",
            "title": "Object"
          },
          "model": {
-           "type": "string",
-           "enum": [
-             "private-gpt"
-           ],
+           "const": "private-gpt",
            "title": "Model"
          },
          "data": {
@@ -541,6 +544,17 @@
            "type": "string",
            "title": "Prompt"
          },
+         "system_prompt": {
+           "anyOf": [
+             {
+               "type": "string"
+             },
+             {
+               "type": "null"
+             }
+           ],
+           "title": "System Prompt"
+         },
          "use_context": {
            "type": "boolean",
            "title": "Use Context",
@@ -616,10 +630,7 @@
            "title": "Index"
          },
          "object": {
-           "type": "string",
-           "enum": [
-             "embedding"
-           ],
+           "const": "embedding",
            "title": "Object"
          },
          "embedding": {
@@ -670,17 +681,11 @@
      "EmbeddingsResponse": {
        "properties": {
          "object": {
-           "type": "string",
-           "enum": [
-             "list"
-           ],
+           "const": "list",
            "title": "Object"
          },
          "model": {
-           "type": "string",
-           "enum": [
-             "private-gpt"
-           ],
+           "const": "private-gpt",
            "title": "Model"
          },
          "data": {
@@ -715,33 +720,22 @@
      "HealthResponse": {
        "properties": {
          "status": {
-           "type": "string",
-           "enum": [
-             "ok"
-           ],
-           "title": "Status"
+           "const": "ok",
+           "title": "Status",
+           "default": "ok"
          }
        },
        "type": "object",
-       "required": [
-         "status"
-       ],
        "title": "HealthResponse"
      },
      "IngestResponse": {
        "properties": {
          "object": {
-           "type": "string",
-           "enum": [
-             "list"
-           ],
+           "const": "list",
            "title": "Object"
          },
          "model": {
-           "type": "string",
-           "enum": [
-             "private-gpt"
-           ],
+           "const": "private-gpt",
            "title": "Model"
          },
          "data": {
@@ -763,10 +757,7 @@
      "IngestedDoc": {
        "properties": {
          "object": {
-           "type": "string",
-           "enum": [
-             "ingest.document"
-           ],
+           "const": "ingest.document",
            "title": "Object"
          },
          "doc_id": {
@@ -888,10 +879,7 @@
            ]
          },
          "model": {
-           "type": "string",
-           "enum": [
-             "private-gpt"
-           ],
+           "const": "private-gpt",
            "title": "Model"
          },
          "choices": {

View File

@@ -28,10 +28,14 @@ class ChatBody(BaseModel):
            "examples": [
                {
                    "messages": [
+                       {
+                           "role": "system",
+                           "content": "You are a rapper. Always answer with a rap.",
+                       },
                        {
                            "role": "user",
                            "content": "How do you fry an egg?",
-                       }
+                       },
                    ],
                    "stream": False,
                    "use_context": True,
@@ -56,6 +60,9 @@ def chat_completion(
) -> OpenAICompletion | StreamingResponse:
    """Given a list of messages comprising a conversation, return a response.

+   Optionally include an initial `role: system` message to influence the way
+   the LLM answers.
+
    If `use_context` is set to `true`, the model will use context coming
    from the ingested documents to create the response. The documents being used can
    be filtered using the `context_filter` and passing the document IDs to be used.
@@ -79,7 +86,9 @@ def chat_completion(
    ]
    if body.stream:
        completion_gen = service.stream_chat(
-           all_messages, body.use_context, body.context_filter
+           messages=all_messages,
+           use_context=body.use_context,
+           context_filter=body.context_filter,
        )
        return StreamingResponse(
            to_openai_sse_stream(
@@ -89,7 +98,11 @@ def chat_completion(
            media_type="text/event-stream",
        )
    else:
-       completion = service.chat(all_messages, body.use_context, body.context_filter)
+       completion = service.chat(
+           messages=all_messages,
+           use_context=body.use_context,
+           context_filter=body.context_filter,
+       )
        return to_openai_response(
            completion.response, completion.sources if body.include_sources else None
        )
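Note: the router now forwards keyword arguments to the service and its documented example starts with a system message. A companion to the completions sketch above, targeting the chat endpoint; again the base URL is an assumption, not part of this commit:

    import requests

    resp = requests.post(
        "http://localhost:8001/v1/chat/completions",  # assumed local deployment
        json={
            "messages": [
                {"role": "system", "content": "You are a rapper. Always answer with a rap."},
                {"role": "user", "content": "How do you fry an egg?"},
            ],
            "stream": False,
            "use_context": True,
        },
    )
    print(resp.json())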

View File

@@ -1,12 +1,13 @@
+from dataclasses import dataclass
+
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.chat_engine.types import (
     BaseChatEngine,
 )
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
-from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
-from llama_index.llms import ChatMessage
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.types import TokenGen
 from pydantic import BaseModel
@@ -30,6 +31,40 @@ class CompletionGen(BaseModel):
     sources: list[Chunk] | None = None


+@dataclass
+class ChatEngineInput:
+    system_message: ChatMessage | None = None
+    last_message: ChatMessage | None = None
+    chat_history: list[ChatMessage] | None = None
+
+    @classmethod
+    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+        # Detect if there is a system message, extract the last message and chat history
+        system_message = (
+            messages[0]
+            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+            else None
+        )
+        last_message = (
+            messages[-1]
+            if len(messages) > 0 and messages[-1].role == MessageRole.USER
+            else None
+        )
+        # Remove from messages list the system message and last message,
+        # if they exist. The rest is the chat history.
+        if system_message:
+            messages.pop(0)
+        if last_message:
+            messages.pop(-1)
+        chat_history = messages if len(messages) > 0 else None
+
+        return cls(
+            system_message=system_message,
+            last_message=last_message,
+            chat_history=chat_history,
+        )
+
+
 @singleton
 class ChatService:
     @inject
@@ -58,18 +93,28 @@ class ChatService:
         )

     def _chat_engine(
-        self, context_filter: ContextFilter | None = None
+        self,
+        system_prompt: str | None = None,
+        use_context: bool = False,
+        context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
+        if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index, context_filter=context_filter
             )
             return ContextChatEngine.from_defaults(
+                system_prompt=system_prompt,
                 retriever=vector_index_retriever,
                 service_context=self.service_context,
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
                 ],
             )
+        else:
+            return SimpleChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                service_context=self.service_context,
+            )

     def stream_chat(
         self,
@@ -77,24 +122,34 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> CompletionGen:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
         streaming_response = chat_engine.stream_chat(
             message=last_message if last_message is not None else "",
-            chat_history=messages[:-1],
+            chat_history=chat_history,
         )
-        sources = [
-            Chunk.from_node(node) for node in streaming_response.source_nodes
-        ]
+        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
         completion_gen = CompletionGen(
             response=streaming_response.response_gen, sources=sources
         )
-        else:
-            stream = self.llm_service.llm.stream_chat(messages)
-            completion_gen = CompletionGen(
-                response=stream_chat_response_to_tokens(stream)
-            )
         return completion_gen

     def chat(
@@ -103,18 +158,30 @@
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> Completion:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
         wrapped_response = chat_engine.chat(
             message=last_message if last_message is not None else "",
-            chat_history=messages[:-1],
+            chat_history=chat_history,
         )
         sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
         completion = Completion(response=wrapped_response.response, sources=sources)
-        else:
-            chat_response = self.llm_service.llm.chat(messages)
-            response_content = chat_response.message.content
-            response = response_content if response_content is not None else ""
-            completion = Completion(response=response)
         return completion
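Note: the service no longer branches on `use_context` at the top of `chat`/`stream_chat`; `ChatEngineInput.from_messages` splits the incoming conversation, and `_chat_engine` picks a `ContextChatEngine` or `SimpleChatEngine`. A small sketch of the splitting behaviour, using the import path from this diff (valid for the llama_index release pinned here); the `private_gpt` module path is assumed from the project layout:

    from llama_index.llms import ChatMessage, MessageRole

    from private_gpt.server.chat.chat_service import ChatEngineInput  # assumed path

    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content="You are a rapper."),
        ChatMessage(role=MessageRole.USER, content="How do you fry an egg?"),
        ChatMessage(role=MessageRole.ASSISTANT, content="Crack it low, let it flow."),
        ChatMessage(role=MessageRole.USER, content="Now poach one."),
    ]
    engine_input = ChatEngineInput.from_messages(messages)
    # engine_input.system_message -> the leading system message
    # engine_input.chat_history   -> the middle user/assistant turns
    # engine_input.last_message   -> the final user message
    # Note: from_messages pops from the list it receives, so `messages` is mutated.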

View File

@@ -15,6 +15,7 @@ completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated
 class CompletionsBody(BaseModel):
     prompt: str
+    system_prompt: str | None = None
     use_context: bool = False
     context_filter: ContextFilter | None = None
     include_sources: bool = True
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
             "examples": [
                 {
                     "prompt": "How do you fry an egg?",
+                    "system_prompt": "You are a rapper. Always answer with a rap.",
                     "stream": False,
                     "use_context": False,
                     "include_sources": False,
@@ -46,7 +48,11 @@ def prompt_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.

-    Given a prompt, the model will return one predicted completion. If `use_context`
+    Given a prompt, the model will return one predicted completion.
+
+    Optionally include a `system_prompt` to influence the way the LLM answers.
+
+    If `use_context`
     is set to `true`, the model will use context coming from the ingested documents
     to create the response. The documents being used can be filtered using the
     `context_filter` and passing the document IDs to be used. Ingested documents IDs
@@ -64,9 +70,13 @@ def prompt_completion(
     "finish_reason":null}]}
     ```
     """
-    message = OpenAIMessage(content=body.prompt, role="user")
+    messages = [OpenAIMessage(content=body.prompt, role="user")]
+    # If system prompt is passed, create a fake message with the system prompt.
+    if body.system_prompt:
+        messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
+
     chat_body = ChatBody(
-        messages=[message],
+        messages=messages,
         use_context=body.use_context,
         stream=body.stream,
         include_sources=body.include_sources,
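Note: the completions endpoint stays a thin shim over chat completions: the bare `prompt` becomes a single user message, and `system_prompt`, when present, is inserted ahead of it. The same translation in isolation, with a hypothetical helper name used purely for illustration:

    def to_chat_messages(prompt: str, system_prompt: str | None = None) -> list[dict]:
        """Hypothetical helper mirroring what prompt_completion does internally."""
        messages = [{"role": "user", "content": prompt}]
        if system_prompt:
            # The system prompt must come first so the chat service
            # recognises it as the system message.
            messages.insert(0, {"role": "system", "content": system_prompt})
        return messages

    print(to_chat_messages("How do you fry an egg?", "You are a rapper."))
    # -> [{'role': 'system', ...}, {'role': 'user', ...}]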

View File

@@ -116,6 +116,17 @@ class PrivateGptUi:
         all_messages = [*build_history(), new_message]
         match mode:
             case "Query Docs":
+                # Add a system message to force the behaviour of the LLM
+                # to answer only questions about the provided context.
+                all_messages.insert(
+                    0,
+                    ChatMessage(
+                        content="You can only answer questions about the provided context. If you know the answer "
+                        "but it is not based in the provided context, don't provide the answer, just state "
+                        "the answer is not in the context provided.",
+                        role=MessageRole.SYSTEM,
+                    ),
+                )
                 query_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=True,

View File

@@ -22,6 +22,7 @@ ui:
 llm:
   mode: local

 embedding:
+  # Should be matching the value above in most cases
   mode: local