diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index 985d217..7715820 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -173,18 +173,22 @@ class TagPromptStyle(AbstractPromptStyle):
class MistralPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- prompt = ""
+ inst_buffer = []  # system/user texts waiting to be wrapped in one [INST] block
+ text = ""  # prompt assembled so far; this is the value returned
for message in messages:
- role = message.role
- content = message.content or ""
- if role.lower() == "system":
- message_from_user = f"[INST] {content.strip()} [/INST]"
- prompt += message_from_user
- elif role.lower() == "user":
- prompt += ""
- message_from_user = f"[INST] {content.strip()} [/INST]"
- prompt += message_from_user
- return prompt
+ if message.role in (MessageRole.SYSTEM, MessageRole.USER):
+ inst_buffer.append(str(message.content or "").strip())  # None -> "", not "None"
+ elif message.role == MessageRole.ASSISTANT:
+ text += "[INST] " + "\n".join(inst_buffer) + " [/INST]"  # close the pending turn
+ text += " " + str(message.content or "").strip()  # NOTE(review): a no-op + "" was dropped here - confirm no closing token was meant
+ inst_buffer.clear()
+ else:
+ raise ValueError(f"Unknown message role {message.role}")  # only system/user/assistant are handled
+
+ if inst_buffer:  # turns after the last assistant reply (or no reply at all)
+ text += "[INST] " + "\n".join(inst_buffer) + " [/INST]"
+
+ return text
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index 3b5af91..ef76437 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -69,17 +69,21 @@ def test_tag_prompt_style_format_with_system_prompt():
def test_mistral_prompt_style_format():
prompt_style = MistralPromptStyle()
messages = [
- ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
- ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+ ChatMessage(content="A", role=MessageRole.SYSTEM),
+ ChatMessage(content="B", role=MessageRole.USER),
]
-
- expected_prompt = (
- "[INST] You are an AI assistant. [/INST]"
- "[INST] Hello, how are you doing? [/INST]"
- )
-
+ expected_prompt = "[INST] A\nB [/INST]"  # system + user texts share a single [INST] block
assert prompt_style.messages_to_prompt(messages) == expected_prompt
+ messages2 = [  # multi-turn case: an assistant reply sits between user turns
+ ChatMessage(content="A", role=MessageRole.SYSTEM),
+ ChatMessage(content="B", role=MessageRole.USER),
+ ChatMessage(content="C", role=MessageRole.ASSISTANT),
+ ChatMessage(content="D", role=MessageRole.USER),
+ ]
+ expected_prompt2 = "[INST] A\nB [/INST] C[INST] D [/INST]"  # reply follows its block; the later user turn opens a new one
+ assert prompt_style.messages_to_prompt(messages2) == expected_prompt2
+
def test_chatml_prompt_style_format():
prompt_style = ChatMLPromptStyle()