parent
289c06df0f
commit
695d303401
|
@ -300,8 +300,9 @@ class TransformersModel(Model):
|
|||
self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
|
||||
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
|
||||
)
|
||||
self.model_id = default_model_id
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
|
||||
self.device
|
||||
|
@ -340,11 +341,10 @@ class TransformersModel(Model):
|
|||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> str:
|
||||
) -> ChatCompletionOutputMessage:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
||||
if tools_to_call_from is not None:
|
||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
|
@ -361,7 +361,7 @@ class TransformersModel(Model):
|
|||
)
|
||||
prompt_tensor = prompt_tensor.to(self.model.device)
|
||||
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
|
||||
|
||||
|
||||
out = self.model.generate(
|
||||
**prompt_tensor,
|
||||
max_new_tokens=max_tokens,
|
||||
|
@ -371,17 +371,19 @@ class TransformersModel(Model):
|
|||
)
|
||||
generated_tokens = out[0, count_prompt_tokens:]
|
||||
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
self.last_input_token_count = count_prompt_tokens
|
||||
self.last_output_token_count = len(generated_tokens)
|
||||
|
||||
if stop_sequences is not None:
|
||||
output = remove_stop_sequences(output, stop_sequences)
|
||||
|
||||
if tools_to_call_from is None:
|
||||
return ChatCompletionOutputMessage(role="assistant", content=output)
|
||||
else:
|
||||
tool_name, tool_arguments = json.load(output)
|
||||
if "Action:" in output:
|
||||
output = output.split("Action:", 1)[1].strip()
|
||||
parsed_output = json.loads(output)
|
||||
tool_name = parsed_output.get("tool_name")
|
||||
tool_arguments = parsed_output.get("tool_arguments")
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
|
|
Loading…
Reference in New Issue