Bug fixes on TransformersModel (#165)

* TransformersModel bug fixes
This commit is contained in:
Aggelos Kyriakoulis 2025-01-13 17:20:45 +02:00 committed by GitHub
parent 289c06df0f
commit 695d303401
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 7 deletions

View File

@ -300,8 +300,9 @@ class TransformersModel(Model):
self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
) )
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to( self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
self.device self.device
@ -340,11 +341,10 @@ class TransformersModel(Model):
grammar: Optional[str] = None, grammar: Optional[str] = None,
max_tokens: int = 1500, max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None, tools_to_call_from: Optional[List[Tool]] = None,
) -> str: ) -> ChatCompletionOutputMessage:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )
if tools_to_call_from is not None: if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template( prompt_tensor = self.tokenizer.apply_chat_template(
messages, messages,
@ -361,7 +361,7 @@ class TransformersModel(Model):
) )
prompt_tensor = prompt_tensor.to(self.model.device) prompt_tensor = prompt_tensor.to(self.model.device)
count_prompt_tokens = prompt_tensor["input_ids"].shape[1] count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
out = self.model.generate( out = self.model.generate(
**prompt_tensor, **prompt_tensor,
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
@ -371,17 +371,19 @@ class TransformersModel(Model):
) )
generated_tokens = out[0, count_prompt_tokens:] generated_tokens = out[0, count_prompt_tokens:]
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
self.last_input_token_count = count_prompt_tokens self.last_input_token_count = count_prompt_tokens
self.last_output_token_count = len(generated_tokens) self.last_output_token_count = len(generated_tokens)
if stop_sequences is not None: if stop_sequences is not None:
output = remove_stop_sequences(output, stop_sequences) output = remove_stop_sequences(output, stop_sequences)
if tools_to_call_from is None: if tools_to_call_from is None:
return ChatCompletionOutputMessage(role="assistant", content=output) return ChatCompletionOutputMessage(role="assistant", content=output)
else: else:
tool_name, tool_arguments = json.load(output) if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
parsed_output = json.loads(output)
tool_name = parsed_output.get("tool_name")
tool_arguments = parsed_output.get("tool_arguments")
return ChatCompletionOutputMessage( return ChatCompletionOutputMessage(
role="assistant", role="assistant",
content="", content="",