refactor(models): restructure model parameter handling (#227)

* refactor(models): restructure model parameter handling

- Introduce base-class level default parameters (temperature, max_tokens)
- Optimize parameter handling: method args can override base config
- Unify parameter handling across model implementations
kingdomad 2025-01-22 18:27:36 +08:00 committed by GitHub
parent 117014d2e1
commit 398c932250
1 changed file with 133 additions and 82 deletions


@@ -196,9 +196,59 @@ def get_clean_message_list(
 class Model:
-    def __init__(self):
+    def __init__(self, **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        # Set default values for common parameters
+        kwargs.setdefault("temperature", 0.5)
+        kwargs.setdefault("max_tokens", 4096)
+        self.kwargs = kwargs
+
+    def _prepare_completion_kwargs(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        custom_role_conversions: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ) -> Dict:
+        """
+        Prepare parameters required for model invocation, handling parameter priorities.
+
+        Parameter priority from high to low:
+        1. Explicitly passed kwargs
+        2. Specific parameters (stop_sequences, grammar, etc.)
+        3. Default values in self.kwargs
+        """
+        # Clean and standardize the message list
+        messages = get_clean_message_list(messages, role_conversions=custom_role_conversions or tool_role_conversions)
+
+        # Use self.kwargs as the base configuration
+        completion_kwargs = {
+            **self.kwargs,
+            "messages": messages,
+        }
+
+        # Handle specific parameters
+        if stop_sequences is not None:
+            completion_kwargs["stop"] = stop_sequences
+        if grammar is not None:
+            completion_kwargs["grammar"] = grammar
+
+        # Handle tools parameter
+        if tools_to_call_from:
+            completion_kwargs.update(
+                {
+                    "tools": [get_json_schema(tool) for tool in tools_to_call_from],
+                    "tool_choice": "required",
+                }
+            )
+
+        # Finally, use the passed-in kwargs to override all settings
+        completion_kwargs.update(kwargs)
+        return completion_kwargs
+
     def get_token_counts(self) -> Dict[str, int]:
         return {
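The three-layer priority is easiest to verify in isolation. A minimal sketch, assuming the base `Model` can be instantiated directly and imported from `smolagents.models` as defined in this file (values illustrative):

```python
from smolagents.models import Model  # assumed import path for this module

model = Model(temperature=0.9)  # constructor kwargs become the stored defaults (layer 3)
kwargs = model._prepare_completion_kwargs(
    messages=[{"role": "user", "content": "hi"}],
    stop_sequences=["<end>"],  # specific parameter, mapped to "stop" (layer 2)
    temperature=0.1,           # explicit kwarg, overrides everything (layer 1)
)
assert kwargs["temperature"] == 0.1  # layer 1 beat the stored default
assert kwargs["stop"] == ["<end>"]   # layer 2 applied
assert kwargs["max_tokens"] == 4096  # untouched base default
```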
@@ -211,6 +261,8 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
@@ -221,8 +273,13 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
+            tools_to_call_from (`List[Tool]`, *optional*):
+                A list of tools that the model can use to generate responses.
+            **kwargs:
+                Additional keyword arguments to be passed to the underlying model.
+
         Returns:
-            `str`: The text content of the model's response.
+            `ChatMessage`: A chat message object containing the model's response.
         """
         pass  # To be implemented in child classes!
@@ -265,16 +322,13 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
-        temperature: float = 0.5,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
-        self.temperature = temperature
-        self.kwargs = kwargs

     def __call__(
         self,
@@ -282,29 +336,18 @@ class HfApiModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         """
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
         """
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="auto",
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        response = self.client.chat_completion(**completion_kwargs)
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
         message = ChatMessage.from_hf_api(response.choices[0].message)
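With the duplicated call sites collapsed into `_prepare_completion_kwargs`, per-request overrides now flow through `**kwargs` instead of requiring a rebuilt client. A usage sketch, assuming `HfApiModel` is exported at the package top level and a valid `HF_TOKEN` is set (prompt and values illustrative):

```python
from smolagents import HfApiModel  # assumed top-level export

model = HfApiModel(temperature=0.2)  # stored default, merged into every request

# One-off override: this call runs at temperature 0.8 without
# mutating the defaults stored in model.kwargs.
message = model(
    [{"role": "user", "content": "Write a haiku about refactoring."}],
    temperature=0.8,
)
print(message.content)
```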
@@ -358,7 +401,7 @@ class TransformersModel(Model):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if not is_torch_available() or not _is_package_available("transformers"):
             raise ModuleNotFoundError(
                 "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
@@ -418,12 +461,36 @@ class TransformersModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        messages = completion_kwargs.pop("messages")
+        stop_sequences = completion_kwargs.pop("stop", None)
+
+        max_new_tokens = (
+            kwargs.get("max_new_tokens")
+            or kwargs.get("max_tokens")
+            or self.kwargs.get("max_new_tokens")
+            or self.kwargs.get("max_tokens")
+        )
+
+        if max_new_tokens:
+            completion_kwargs["max_new_tokens"] = max_new_tokens
+
+        if stop_sequences:
+            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
+
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tools=completion_kwargs.pop("tools", []),
                 return_tensors="pt",
                 return_dict=True,
                 add_generation_prompt=True,
@@ -434,14 +501,11 @@ class TransformersModel(Model):
                 return_tensors="pt",
                 return_dict=True,
             )
+
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]

-        out = self.model.generate(
-            **prompt_tensor,
-            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
-            **self.kwargs,
-        )
+        out = self.model.generate(**prompt_tensor, **completion_kwargs)
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
@@ -449,6 +513,7 @@ class TransformersModel(Model):
         if stop_sequences is not None:
             output = remove_stop_sequences(output, stop_sequences)
+
         if tools_to_call_from is None:
             return ChatMessage(role="assistant", content=output)
         else:
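The `max_new_tokens` fallback in this class exists because `transformers`' `generate()` takes `max_new_tokens` rather than the chat-completion-style `max_tokens` stored in the base defaults. The precedence chain reads more clearly extracted into a helper (the function name is mine, not the commit's):

```python
from typing import Optional


def resolve_max_new_tokens(call_kwargs: dict, stored_kwargs: dict) -> Optional[int]:
    """Mirror the commit's precedence: call-time max_new_tokens, then
    call-time max_tokens, then the stored equivalents from self.kwargs."""
    return (
        call_kwargs.get("max_new_tokens")
        or call_kwargs.get("max_tokens")
        or stored_kwargs.get("max_new_tokens")
        or stored_kwargs.get("max_tokens")
    )


assert resolve_max_new_tokens({"max_tokens": 256}, {"max_new_tokens": 1024}) == 256
assert resolve_max_new_tokens({}, {"max_tokens": 4096}) == 4096
```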
@@ -498,13 +563,12 @@ class LiteLLMModel(Model):
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
         self.api_base = api_base
         self.api_key = api_key
-        self.kwargs = kwargs

     def __call__(
         self,
@@ -512,34 +576,28 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         import litellm

-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            api_base=self.api_base,
+            api_key=self.api_key,
+            **kwargs,
+        )
+        response = litellm.completion(**completion_kwargs)

-        if tools_to_call_from:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
-        else:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
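The new `model_dump(include=...)` line converts the provider's pydantic message into smolagents' own `ChatMessage`, dropping any provider-specific fields along the way. A minimal sketch of the filtering with a stand-in pydantic model (not litellm's actual response class):

```python
from typing import Optional

from pydantic import BaseModel


class ProviderMessage(BaseModel):  # stand-in, not litellm's real message type
    role: str
    content: Optional[str]
    tool_calls: Optional[list] = None
    provider_extra: str = "dropped"  # anything outside `include` is filtered out


raw = ProviderMessage(role="assistant", content="4")
payload = raw.model_dump(include={"role", "content", "tool_calls"})
assert payload == {"role": "assistant", "content": "4", "tool_calls": None}
```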
@@ -576,13 +634,13 @@ class OpenAIServerModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
             ) from None
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         self.client = openai.OpenAI(
             base_url=api_base,
             api_key=api_key,
         )
-        self.kwargs = kwargs
         self.custom_role_conversions = custom_role_conversions

     def __call__(
@@ -591,30 +649,23 @@ class OpenAIServerModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(
-            messages,
-            role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            custom_role_conversions=self.custom_role_conversions,
+            **kwargs,
         )
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                **self.kwargs,
-            )
+        response = self.client.chat.completions.create(**completion_kwargs)
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
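After this change every backend resolves parameters through the same three-layer merge, so switching providers does not change how overrides are spelled. A closing usage sketch, assuming `OpenAIServerModel` is exported at the package top level (endpoint, model id, and key are placeholders):

```python
from smolagents import OpenAIServerModel  # assumed top-level export

model = OpenAIServerModel(
    model_id="gpt-4o-mini",               # illustrative model id
    api_base="https://api.openai.com/v1",
    api_key="YOUR_API_KEY",               # placeholder
    temperature=0.3,                      # stored default (layer 3)
)

message = model(
    [{"role": "user", "content": "Summarize this diff in one line."}],
    stop_sequences=["\n\n"],              # specific parameter (layer 2)
    max_tokens=128,                       # explicit kwarg, overrides the 4096 default (layer 1)
)
```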