parent 289c06df0f
commit 695d303401
@@ -300,8 +300,9 @@ class TransformersModel(Model):
             self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         except Exception as e:
             logger.warning(
-                f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
+                f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
+            self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
                 self.device
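The first hunk fixes two things in the fallback path: the warning now names the default model it actually falls back to (`{default_model_id=}` instead of repeating `{model_id=}`), and `self.model_id` is updated so the instance reports the model that really got loaded. A minimal standalone sketch of that pattern, assuming `default_model_id` points at a known-good checkpoint (the free function and `device` argument are illustrative, not part of the patch):

```python
import logging

from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


def load_with_fallback(model_id: str, default_model_id: str, device: str = "cpu"):
    """Try to load the requested model; fall back to a known-good default."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    except Exception as e:
        logger.warning(
            f"Failed to load tokenizer and model for {model_id=}: {e}. "
            f"Loading default tokenizer and model instead from {default_model_id=}."
        )
        # Record the model that was actually loaded (the fix in this hunk).
        model_id = default_model_id
        tokenizer = AutoTokenizer.from_pretrained(default_model_id)
        model = AutoModelForCausalLM.from_pretrained(default_model_id).to(device)
    return model_id, tokenizer, model
```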
@@ -340,11 +341,10 @@ class TransformersModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
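The second hunk corrects the declared return type: both branches of this method return a `ChatCompletionOutputMessage`, never a plain `str`. For readers without the dependency at hand, a minimal stand-in for that type (the real class ships with `huggingface_hub` and carries additional fields beyond the ones this diff touches):

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ChatCompletionOutputMessage:
    """Illustrative stand-in for the huggingface_hub response type."""

    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[Any]] = None


# The plain-text path returns content; the tool-calling path returns content="".
msg = ChatCompletionOutputMessage(role="assistant", content="Hello!")
```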
@@ -361,7 +361,7 @@ class TransformersModel(Model):
            )
        prompt_tensor = prompt_tensor.to(self.model.device)
        count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

        out = self.model.generate(
            **prompt_tensor,
            max_new_tokens=max_tokens,
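This hunk is context around the generation call: the chat template produces a tensor dict, the prompt length is recorded so it can be sliced off later, and `generate` runs with `max_new_tokens`. A runnable sketch of that flow; the model id is an arbitrary small chat model, and `return_dict=True` is assumed because the code indexes `prompt_tensor["input_ids"]`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"  # any small chat model works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [{"role": "user", "content": "Say hello."}]
# return_dict=True yields {"input_ids": ..., "attention_mask": ...},
# matching the prompt_tensor["input_ids"] access in the diff.
prompt_tensor = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
)
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

with torch.no_grad():
    out = model.generate(**prompt_tensor, max_new_tokens=64)

# generate() echoes the prompt tokens, so slice them off before decoding.
generated_tokens = out[0, count_prompt_tokens:]
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
```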
@@ -371,17 +371,19 @@ class TransformersModel(Model):
        )
        generated_tokens = out[0, count_prompt_tokens:]
        output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        self.last_input_token_count = count_prompt_tokens
        self.last_output_token_count = len(generated_tokens)

        if stop_sequences is not None:
            output = remove_stop_sequences(output, stop_sequences)

        if tools_to_call_from is None:
            return ChatCompletionOutputMessage(role="assistant", content=output)
        else:
-            tool_name, tool_arguments = json.load(output)
+            if "Action:" in output:
+                output = output.split("Action:", 1)[1].strip()
+            parsed_output = json.loads(output)
+            tool_name = parsed_output.get("tool_name")
+            tool_arguments = parsed_output.get("tool_arguments")
            return ChatCompletionOutputMessage(
                role="assistant",
                content="",
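The last hunk replaces a genuine bug: `json.load` expects a file-like object, so `json.load(output)` on a string raises `AttributeError: 'str' object has no attribute 'read'`. The fix switches to `json.loads`, strips an optional `Action:` prefix that models often emit before the JSON blob, and reads the fields defensively with `.get`. The same logic, isolated as a standalone helper (the function name is illustrative):

```python
import json
from typing import Any, Optional, Tuple


def parse_tool_call(output: str) -> Tuple[Optional[str], Optional[Any]]:
    """Extract (tool_name, tool_arguments) from a model's tool-call text."""
    # Models often prefix the JSON payload with an "Action:" marker; strip it.
    if "Action:" in output:
        output = output.split("Action:", 1)[1].strip()
    parsed_output = json.loads(output)  # json.loads, not json.load: input is a str
    return parsed_output.get("tool_name"), parsed_output.get("tool_arguments")


print(parse_tool_call('Action: {"tool_name": "search", "tool_arguments": {"q": "x"}}'))
# -> ('search', {'q': 'x'})
```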