Multi language support - fern debug (#1307)
---------

Co-authored-by: Louis <lpglm@orange.fr>
Co-authored-by: LeMoussel <cnhx27@gmail.com>

parent e8d88f8952
commit 944c43bfa8

@@ -28,7 +28,7 @@ jobs:
         env:
           FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
         run: |
-          output=$(fern generate --docs --preview)
+          output=$(fern generate --docs --preview --log-level debug)
           echo "$output"
           # Extract the URL
           preview_url=$(echo "$output" | grep -oP '(?<=Published docs to )https://[^\s]*')

@@ -44,6 +44,10 @@ navigation:
   # Manual of privateGPT: how to use it and configure it
   - tab: manual
     layout:
+      - section: General configuration
+        contents:
+          - page: Configuration
+            path: ./docs/pages/manual/settings.mdx
       - section: Document management
         contents:
           - page: Ingestion

@@ -1,4 +1,4 @@
-# Configure your private GPT
+# Settings and profiles for your private GPT
 
 The configuration of your private GPT server is done thanks to `settings` files (more precisely `settings.yaml`).
 These text files are written using the [YAML](https://en.wikipedia.org/wiki/YAML) syntax.

@@ -1,8 +1,41 @@
-## List of working LLM
+# List of working LLM
 
 **Do you have any working combination of LLM and embeddings?**
 Please open a PR to add it to the list, and come on our Discord to tell us about it!
 
+## Prompt style
+
+LLMs might have been trained with different prompt styles.
+The prompt style is the way the prompt is written, and how the system message is injected in the prompt.
+
+For example, `llama2` looks like this:
+```text
+<s>[INST] <<SYS>>
+{{ system_prompt }}
+<</SYS>>
+
+{{ user_message }} [/INST]
+```
+
+While `default` (the `llama_index` default) looks like this:
+```text
+system: {{ system_prompt }}
+user: {{ user_message }}
+assistant: {{ assistant_message }}
+```
+
+And the "`tag`" style looks like this:
+
+```text
+<|system|>: {{ system_prompt }}
+<|user|>: {{ user_message }}
+<|assistant|>: {{ assistant_message }}
+```
+
+Some LLMs will not understand this prompt style, and will not work (returning nothing).
+You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
+change the way the messages are formatted to be passed to the LLM.
+
 ## Example of configuration
 
 You might want to change the prompt depending on the language and model you are using.
@@ -15,6 +48,7 @@ local:
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.1-GGUF
   llm_hf_model_file: mistral-7b-instruct-v0.1.Q4_K_M.gguf
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
+  prompt_style: "llama2"
 ```
 
 ### French, with instructions
@@ -25,4 +59,26 @@ local:
   llm_hf_repo_id: TheBloke/Vigogne-2-7B-Instruct-GGUF
   llm_hf_model_file: vigogne-2-7b-instruct.Q4_K_M.gguf
   embedding_hf_model_name: dangvantuan/sentence-camembert-base
+  prompt_style: "default"
+  # prompt_style: "tag" # also works
+  # The default system prompt is injected only when `prompt_style` != default, and there is no system message in the discussion
+  # default_system_prompt: Vous êtes un assistant IA qui répond à la question posée à la fin en utilisant le contexte suivant. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse. Veuillez répondre exclusivement en français.
+```
+
+You might want to change the prompt, as the one above might not directly answer your question.
+You can read online about how to write a good prompt, but in a nutshell, make it (extremely) directive.
+
+You can troubleshoot your prompt by writing multiline requests in the UI, spelling out
+your interaction with the model, for example:
+
+```text
+Tu es un programmeur senior qui programme en python et utilise le framework fastapi. Ecrit moi un serveur qui retourne "hello world".
+```
+
+Another example:
+```text
+Context: None
+Situation: tu es au milieu d'un champ.
+Tache: va a la rivière, en bas du champ.
+Décrit comment aller a la rivière.
 ```

@@ -1,4 +1,4 @@
 {
   "organization": "privategpt",
-  "version": "0.15.0-rc80"
+  "version": "0.15.3"
 }

@@ -1,8 +1,8 @@
 from injector import inject, singleton
 from llama_index.llms import MockLLM
 from llama_index.llms.base import LLM
-from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt
 
+from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings
 
@@ -17,6 +17,11 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP
 
+                prompt_style_cls = get_prompt_style(settings.local.prompt_style)
+                prompt_style = prompt_style_cls(
+                    default_system_prompt=settings.local.default_system_prompt
+                )
+
                 self.llm = LlamaCPP(
                     model_path=str(models_path / settings.local.llm_hf_model_file),
                     temperature=0.1,
@@ -27,8 +32,8 @@ class LLMComponent:
                     # All to GPU
                     model_kwargs={"n_gpu_layers": -1},
                     # transform inputs into Llama2 format
-                    messages_to_prompt=messages_to_prompt,
-                    completion_to_prompt=completion_to_prompt,
+                    messages_to_prompt=prompt_style.messages_to_prompt,
+                    completion_to_prompt=prompt_style.completion_to_prompt,
                     verbose=True,
                 )
 
@@ -0,0 +1,179 @@
+import abc
+import logging
+from collections.abc import Sequence
+from typing import Any, Literal
+
+from llama_index.llms import ChatMessage, MessageRole
+from llama_index.llms.llama_utils import (
+    DEFAULT_SYSTEM_PROMPT,
+    completion_to_prompt,
+    messages_to_prompt,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractPromptStyle(abc.ABC):
+    """Abstract class for prompt styles.
+
+    This class is used to format a series of messages into a prompt that can be
+    understood by the models. A series of messages represents the interaction(s)
+    between a user and an assistant. This series of messages can be considered as a
+    session between a user X and an assistant Y. This session holds, through the
+    messages, the state of the conversation. This session, to be understood by the
+    model, needs to be formatted into a prompt (i.e. a string that the models
+    can understand). Prompts can be formatted in different ways,
+    depending on the model.
+
+    The implementations of this class represent the different ways to format a
+    series of messages into a prompt.
+    """
+
+    @abc.abstractmethod
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
+
+    @abc.abstractmethod
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        pass
+
+    @abc.abstractmethod
+    def _completion_to_prompt(self, completion: str) -> str:
+        pass
+
+    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = self._messages_to_prompt(messages)
+        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
+        return prompt
+
+    def completion_to_prompt(self, completion: str) -> str:
+        prompt = self._completion_to_prompt(completion)
+        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
+        return prompt
+
+
+class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
+    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
+
+    def __init__(self, default_system_prompt: str | None) -> None:
+        super().__init__()
+        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
+        self.default_system_prompt = default_system_prompt
+
+
+class DefaultPromptStyle(AbstractPromptStyle):
+    """Default prompt style that uses the defaults from llama_utils.
+
+    It basically passes None to the LLM, indicating it should use
+    the default functions.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+        # Hacky way to override the functions
+        # Override the functions to be None, and pass None to the LLM.
+        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
+        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]
+
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        return ""
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return ""
+
+
+class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
+    """Simple prompt style that just uses the default llama_utils functions.
+
+    It transforms the sequence of messages into a prompt that should look like:
+    ```text
+    <s> [INST] <<SYS>> your system prompt here. <</SYS>>
+
+    user message here [/INST] assistant (model) response here </s>
+    ```
+    """
+
+    def __init__(self, default_system_prompt: str | None = None) -> None:
+        # If no system prompt is given, the default one of the implementation is used.
+        super().__init__(default_system_prompt=default_system_prompt)
+
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        return messages_to_prompt(messages, self.default_system_prompt)
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return completion_to_prompt(completion, self.default_system_prompt)
+
+
+class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
+    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
+
+    It transforms the sequence of messages into a prompt that should look like:
+    ```text
+    <|system|>: your system prompt here.
+    <|user|>: user message here
+    (possibly with context and question)
+    <|assistant|>: assistant (model) response here.
+    ```
+
+    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
+    """
+
+    def __init__(self, default_system_prompt: str | None = None) -> None:
+        # We have to define a default system prompt here as the LLM will not
+        # use the default llama_utils functions.
+        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
+        super().__init__(default_system_prompt)
+        self.system_prompt: str = default_system_prompt
+
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        messages = list(messages)
+        if messages[0].role != MessageRole.SYSTEM:
+            logger.info(
+                "Adding system_prompt='%s' to the given messages as there are none given in the session",
+                self.system_prompt,
+            )
+            messages = [
+                ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
+                *messages,
+            ]
+        return self._format_messages_to_prompt(messages)
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return (
+            f"<|system|>: {self.system_prompt.strip()}\n"
+            f"<|user|>: {completion.strip()}\n"
+            "<|assistant|>: "
+        )
+
+    @staticmethod
+    def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
+        """Format message to prompt with `<|ROLE|>: MSG` style."""
+        assert messages[0].role == MessageRole.SYSTEM
+        prompt = ""
+        for message in messages:
+            role = message.role
+            content = message.content or ""
+            message_from_user = f"<|{role.lower()}|>: {content.strip()}"
+            message_from_user += "\n"
+            prompt += message_from_user
+        # we are missing the last <|assistant|> tag that will trigger a completion
+        prompt += "<|assistant|>: "
+        return prompt
+
+
+def get_prompt_style(
+    prompt_style: Literal["default", "llama2", "tag"] | None
+) -> type[AbstractPromptStyle]:
+    """Get the prompt style to use from the given string.
+
+    :param prompt_style: The prompt style to use.
+    :return: The prompt style class to use.
+    """
+    if prompt_style is None or prompt_style == "default":
+        return DefaultPromptStyle
+    elif prompt_style == "llama2":
+        return Llama2PromptStyle
+    elif prompt_style == "tag":
+        return TagPromptStyle
+    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
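For orientation, here is a minimal usage sketch of the prompt-style API introduced above. It is illustrative only (not part of the change set) and simply mirrors what `LLMComponent` and the tests below do:

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import get_prompt_style

# Resolve the style class from its configured name, then instantiate it with an
# optional default system prompt (LLMComponent does the same with values taken
# from settings.local).
prompt_style_cls = get_prompt_style("tag")
prompt_style = prompt_style_cls(default_system_prompt="You are an AI assistant.")

messages = [
    ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

# The session has no system message, so TagPromptStyle injects the default one
# and formats everything with the `<|role|>: message` layout, ending with the
# trailing `<|assistant|>: ` tag that triggers the completion.
print(prompt_style.messages_to_prompt(messages))
```

With the `tag` style this prints `<|system|>`, `<|user|>` and a trailing `<|assistant|>: ` line, which is exactly what the tests below assert.
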
@@ -91,7 +91,28 @@ class VectorstoreSettings(BaseModel):
 class LocalSettings(BaseModel):
     llm_hf_repo_id: str
     llm_hf_model_file: str
-    embedding_hf_model_name: str
+    embedding_hf_model_name: str = Field(
+        description="Name of the HuggingFace model to use for embeddings"
+    )
+    prompt_style: Literal["default", "llama2", "tag"] = Field(
+        "llama2",
+        description=(
+            "The prompt style to use for the chat engine. "
+            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
+            "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
+            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
+            "`llama2` is the historic behaviour. `default` might work better with your custom models."
+        ),
+    )
+    default_system_prompt: str | None = Field(
+        None,
+        description=(
+            "The default system prompt to use for the chat engine. "
+            "If none is given - use the default system prompt (from the llama_index). "
+            "Please note that the default prompt might not be the same for all prompt styles. "
+            "Also note that this is only used if the first message is not a system message."
+        ),
+    )
 
 
 class SagemakerSettings(BaseModel):

@@ -30,6 +30,7 @@ qdrant:
   path: local_data/private_gpt/qdrant
 
 local:
+  prompt_style: "llama2"
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.1-GGUF
   llm_hf_model_file: mistral-7b-instruct-v0.1.Q4_K_M.gguf
   embedding_hf_model_name: BAAI/bge-small-en-v1.5

@@ -0,0 +1,128 @@
+import pytest
+from llama_index.llms import ChatMessage, MessageRole
+
+from private_gpt.components.llm.prompt_helper import (
+    DefaultPromptStyle,
+    Llama2PromptStyle,
+    TagPromptStyle,
+    get_prompt_style,
+)
+
+
+@pytest.mark.parametrize(
+    ("prompt_style", "expected_prompt_style"),
+    [
+        ("default", DefaultPromptStyle),
+        ("llama2", Llama2PromptStyle),
+        ("tag", TagPromptStyle),
+    ],
+)
+def test_get_prompt_style_success(prompt_style, expected_prompt_style):
+    assert get_prompt_style(prompt_style) == expected_prompt_style
+
+
+def test_get_prompt_style_failure():
+    prompt_style = "unknown"
+    with pytest.raises(ValueError) as exc_info:
+        get_prompt_style(prompt_style)
+    assert str(exc_info.value) == f"Unknown prompt_style='{prompt_style}'"
+
+
+def test_tag_prompt_style_format():
+    prompt_style = TagPromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|system|>: You are an AI assistant.\n"
+        "<|user|>: Hello, how are you doing?\n"
+        "<|assistant|>: "
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_tag_prompt_style_format_with_system_prompt():
+    system_prompt = "This is a system prompt from configuration."
+    prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
+    messages = [
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        f"<|system|>: {system_prompt}\n"
+        "<|user|>: Hello, how are you doing?\n"
+        "<|assistant|>: "
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+    messages = [
+        ChatMessage(
+            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
+        ),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|system|>: FOO BAR Custom sys prompt from messages.\n"
+        "<|user|>: Hello, how are you doing?\n"
+        "<|assistant|>: "
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama2_prompt_style_format():
+    prompt_style = Llama2PromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<s> [INST] <<SYS>>\n"
+        " You are an AI assistant. \n"
+        "<</SYS>>\n"
+        "\n"
+        " Hello, how are you doing? [/INST]"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama2_prompt_style_with_system_prompt():
+    system_prompt = "This is a system prompt from configuration."
+    prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
+    messages = [
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<s> [INST] <<SYS>>\n"
+        f" {system_prompt} \n"
+        "<</SYS>>\n"
+        "\n"
+        " Hello, how are you doing? [/INST]"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+    messages = [
+        ChatMessage(
+            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
+        ),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<s> [INST] <<SYS>>\n"
+        " FOO BAR Custom sys prompt from messages. \n"
+        "<</SYS>>\n"
+        "\n"
+        " Hello, how are you doing? [/INST]"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt