From c7212ac7cc891f9e3c713cc206ae9807c5dfdeb6 Mon Sep 17 00:00:00 2001 From: Pablo Orgaz Date: Thu, 30 May 2024 15:41:16 +0200 Subject: [PATCH] fix(LLM): mistral ignoring assistant messages (#1954) * fix: mistral ignoring assistant messages * fix: typing * fix: fix tests --- private_gpt/components/llm/prompt_helper.py | 26 ++++++++++++--------- tests/test_prompt_helper.py | 20 +++++++++------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py index 985d217..7715820 100644 --- a/private_gpt/components/llm/prompt_helper.py +++ b/private_gpt/components/llm/prompt_helper.py @@ -173,18 +173,22 @@ class TagPromptStyle(AbstractPromptStyle): class MistralPromptStyle(AbstractPromptStyle): def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - prompt = "" + inst_buffer = [] + text = "" for message in messages: - role = message.role - content = message.content or "" - if role.lower() == "system": - message_from_user = f"[INST] {content.strip()} [/INST]" - prompt += message_from_user - elif role.lower() == "user": - prompt += "" - message_from_user = f"[INST] {content.strip()} [/INST]" - prompt += message_from_user - return prompt + if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER: + inst_buffer.append(str(message.content).strip()) + elif message.role == MessageRole.ASSISTANT: + text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" + text += " " + str(message.content).strip() + "" + inst_buffer.clear() + else: + raise ValueError(f"Unknown message role {message.role}") + + if len(inst_buffer) > 0: + text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" + + return text def _completion_to_prompt(self, completion: str) -> str: return self._messages_to_prompt( diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py index 3b5af91..ef76437 100644 --- a/tests/test_prompt_helper.py +++ b/tests/test_prompt_helper.py @@ -69,17 +69,21 @@ def test_tag_prompt_style_format_with_system_prompt(): def test_mistral_prompt_style_format(): prompt_style = MistralPromptStyle() messages = [ - ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM), - ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER), + ChatMessage(content="A", role=MessageRole.SYSTEM), + ChatMessage(content="B", role=MessageRole.USER), ] - - expected_prompt = ( - "[INST] You are an AI assistant. [/INST]" - "[INST] Hello, how are you doing? [/INST]" - ) - + expected_prompt = "[INST] A\nB [/INST]" assert prompt_style.messages_to_prompt(messages) == expected_prompt + messages2 = [ + ChatMessage(content="A", role=MessageRole.SYSTEM), + ChatMessage(content="B", role=MessageRole.USER), + ChatMessage(content="C", role=MessageRole.ASSISTANT), + ChatMessage(content="D", role=MessageRole.USER), + ] + expected_prompt2 = "[INST] A\nB [/INST] C[INST] D [/INST]" + assert prompt_style.messages_to_prompt(messages2) == expected_prompt2 + def test_chatml_prompt_style_format(): prompt_style = ChatMLPromptStyle()