Fix for MLX Model when stop sequence is followed by any chars (#623)
This commit is contained in:
parent
2347631a55
commit
cb2218a86f
|
@ -525,7 +525,7 @@ class MLXModel(Model):
|
|||
|
||||
def _to_message(self, text, tools_to_call_from):
|
||||
if tools_to_call_from:
|
||||
# tmp solution for extracting tool JSON without assuming a specific model output format
|
||||
# solution for extracting tool JSON without assuming a specific model output format
|
||||
maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
|
||||
parsed_text = json.loads(maybe_json)
|
||||
tool_name = parsed_text.get(self.tool_name_key, None)
|
||||
|
@ -579,8 +579,9 @@ class MLXModel(Model):
|
|||
self.last_output_token_count += 1
|
||||
text += _.text
|
||||
for stop_sequence in prepared_stop_sequences:
|
||||
if text.strip().endswith(stop_sequence):
|
||||
text = text[: -len(stop_sequence)]
|
||||
stop_sequence_start = text.rfind(stop_sequence)
|
||||
if stop_sequence_start != -1:
|
||||
text = text[:stop_sequence_start]
|
||||
return self._to_message(text, tools_to_call_from)
|
||||
|
||||
return self._to_message(text, tools_to_call_from)
|
||||
|
|
|
@ -69,6 +69,19 @@ class ModelTests(unittest.TestCase):
|
|||
output = model(messages, stop_sequences=["great"]).content
|
||||
assert output.startswith("Hello")
|
||||
|
||||
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
||||
def test_get_mlx_message_tricky_stop_sequence(self):
|
||||
# In this test HuggingFaceTB/SmolLM2-135M-Instruct generates the token ">'"
|
||||
# which is required to test capturing stop_sequences that have extra chars at the end.
|
||||
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=100)
|
||||
stop_sequence = " print '>"
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": f"Please{stop_sequence}'"}]}]
|
||||
# check our assumption that that ">" is followed by "'"
|
||||
assert model.tokenizer.vocab[">'"]
|
||||
assert model(messages, stop_sequences=[]).content == f"I'm ready to help you{stop_sequence}'"
|
||||
# check stop_sequence capture when output has trailing chars
|
||||
assert model(messages, stop_sequences=[stop_sequence]).content == "I'm ready to help you"
|
||||
|
||||
def test_transformers_message_no_tool(self):
|
||||
model = TransformersModel(
|
||||
model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
|
||||
|
|
Loading…
Reference in New Issue