fix(LLM): mistral ignoring assistant messages (#1954)
* fix: mistral ignoring assistant messages * fix: typing * fix: fix tests
This commit is contained in:
parent
3b3e96ad6c
commit
c7212ac7cc
|
@ -173,18 +173,22 @@ class TagPromptStyle(AbstractPromptStyle):
|
||||||
|
|
||||||
class MistralPromptStyle(AbstractPromptStyle):
|
class MistralPromptStyle(AbstractPromptStyle):
|
||||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||||
prompt = "<s>"
|
inst_buffer = []
|
||||||
|
text = ""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.role
|
if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
|
||||||
content = message.content or ""
|
inst_buffer.append(str(message.content).strip())
|
||||||
if role.lower() == "system":
|
elif message.role == MessageRole.ASSISTANT:
|
||||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||||
prompt += message_from_user
|
text += " " + str(message.content).strip() + "</s>"
|
||||||
elif role.lower() == "user":
|
inst_buffer.clear()
|
||||||
prompt += "</s>"
|
else:
|
||||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
raise ValueError(f"Unknown message role {message.role}")
|
||||||
prompt += message_from_user
|
|
||||||
return prompt
|
if len(inst_buffer) > 0:
|
||||||
|
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
def _completion_to_prompt(self, completion: str) -> str:
|
def _completion_to_prompt(self, completion: str) -> str:
|
||||||
return self._messages_to_prompt(
|
return self._messages_to_prompt(
|
||||||
|
|
|
@ -69,17 +69,21 @@ def test_tag_prompt_style_format_with_system_prompt():
|
||||||
def test_mistral_prompt_style_format():
|
def test_mistral_prompt_style_format():
|
||||||
prompt_style = MistralPromptStyle()
|
prompt_style = MistralPromptStyle()
|
||||||
messages = [
|
messages = [
|
||||||
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
|
ChatMessage(content="A", role=MessageRole.SYSTEM),
|
||||||
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
|
ChatMessage(content="B", role=MessageRole.USER),
|
||||||
]
|
]
|
||||||
|
expected_prompt = "<s>[INST] A\nB [/INST]"
|
||||||
expected_prompt = (
|
|
||||||
"<s>[INST] You are an AI assistant. [/INST]</s>"
|
|
||||||
"[INST] Hello, how are you doing? [/INST]"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert prompt_style.messages_to_prompt(messages) == expected_prompt
|
assert prompt_style.messages_to_prompt(messages) == expected_prompt
|
||||||
|
|
||||||
|
messages2 = [
|
||||||
|
ChatMessage(content="A", role=MessageRole.SYSTEM),
|
||||||
|
ChatMessage(content="B", role=MessageRole.USER),
|
||||||
|
ChatMessage(content="C", role=MessageRole.ASSISTANT),
|
||||||
|
ChatMessage(content="D", role=MessageRole.USER),
|
||||||
|
]
|
||||||
|
expected_prompt2 = "<s>[INST] A\nB [/INST] C</s><s>[INST] D [/INST]"
|
||||||
|
assert prompt_style.messages_to_prompt(messages2) == expected_prompt2
|
||||||
|
|
||||||
|
|
||||||
def test_chatml_prompt_style_format():
|
def test_chatml_prompt_style_format():
|
||||||
prompt_style = ChatMLPromptStyle()
|
prompt_style = ChatMLPromptStyle()
|
||||||
|
|
Loading…
Reference in New Issue