From 695d3034016b828bf550aecc0e5dd035d10e156c Mon Sep 17 00:00:00 2001 From: Aggelos Kyriakoulis Date: Mon, 13 Jan 2025 17:20:45 +0200 Subject: [PATCH] Bug fixes on TransformersModel (#165) * TransformersModel bug fixes --- src/smolagents/models.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index fd68607..cc9aedc 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -300,8 +300,9 @@ class TransformersModel(Model): self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) except Exception as e: logger.warning( - f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." + f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." ) + self.model_id = default_model_id self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to( self.device @@ -340,11 +341,10 @@ class TransformersModel(Model): grammar: Optional[str] = None, max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, - ) -> str: + ) -> ChatCompletionOutputMessage: messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) - if tools_to_call_from is not None: prompt_tensor = self.tokenizer.apply_chat_template( messages, @@ -361,7 +361,7 @@ class TransformersModel(Model): ) prompt_tensor = prompt_tensor.to(self.model.device) count_prompt_tokens = prompt_tensor["input_ids"].shape[1] - + out = self.model.generate( **prompt_tensor, max_new_tokens=max_tokens, @@ -371,17 +371,19 @@ class TransformersModel(Model): ) generated_tokens = out[0, count_prompt_tokens:] output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) - self.last_input_token_count = count_prompt_tokens self.last_output_token_count = len(generated_tokens) if stop_sequences is not None: output = remove_stop_sequences(output, stop_sequences) - if tools_to_call_from is None: return ChatCompletionOutputMessage(role="assistant", content=output) else: - tool_name, tool_arguments = json.load(output) + if "Action:" in output: + output = output.split("Action:", 1)[1].strip() + parsed_output = json.loads(output) + tool_name = parsed_output.get("tool_name") + tool_arguments = parsed_output.get("tool_arguments") return ChatCompletionOutputMessage( role="assistant", content="",