refactor(models): restructure model parameter handling (#227)
* refactor(models): restructure model parameter handling

  - Introduce base-class-level default parameters (temperature, max_tokens)
  - Let method arguments override the base configuration
  - Unify parameter handling across model implementations
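In practice the change works as sketched below (usage illustration, not part of the diff; assumes a smolagents build with this commit applied and a valid HF_TOKEN in the environment, and the messages are made up): constructor kwargs become the persistent base config, call-site kwargs override it for a single call.

from smolagents import HfApiModel

# Constructor kwargs become the base config; unset keys fall back to the
# new class-level defaults (temperature=0.5, max_tokens=4096).
model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0.2)

messages = [{"role": "user", "content": "Say hello in one word."}]

reply = model(messages)  # uses base config: temperature=0.2, max_tokens=4096
creative = model(messages, temperature=0.9, max_tokens=64)  # call-site values win for this call only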
This commit is contained in:
parent 117014d2e1
commit 398c932250
@@ -196,9 +196,59 @@ def get_clean_message_list(
 
 
 class Model:
-    def __init__(self):
+    def __init__(self, **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        # Set default values for common parameters
+        kwargs.setdefault("temperature", 0.5)
+        kwargs.setdefault("max_tokens", 4096)
+        self.kwargs = kwargs
+
+    def _prepare_completion_kwargs(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        custom_role_conversions: Optional[Dict[str, str]] = None,
+        **kwargs,
+    ) -> Dict:
+        """
+        Prepare parameters required for model invocation, handling parameter priorities.
+
+        Parameter priority from high to low:
+        1. Explicitly passed kwargs
+        2. Specific parameters (stop_sequences, grammar, etc.)
+        3. Default values in self.kwargs
+        """
+        # Clean and standardize the message list
+        messages = get_clean_message_list(messages, role_conversions=custom_role_conversions or tool_role_conversions)
+
+        # Use self.kwargs as the base configuration
+        completion_kwargs = {
+            **self.kwargs,
+            "messages": messages,
+        }
+
+        # Handle specific parameters
+        if stop_sequences is not None:
+            completion_kwargs["stop"] = stop_sequences
+        if grammar is not None:
+            completion_kwargs["grammar"] = grammar
+
+        # Handle tools parameter
+        if tools_to_call_from:
+            completion_kwargs.update(
+                {
+                    "tools": [get_json_schema(tool) for tool in tools_to_call_from],
+                    "tool_choice": "required",
+                }
+            )
+
+        # Finally, use the passed-in kwargs to override all settings
+        completion_kwargs.update(kwargs)
+
+        return completion_kwargs
 
     def get_token_counts(self) -> Dict[str, int]:
         return {
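The precedence documented in `_prepare_completion_kwargs` is just dict-merge order. A minimal self-contained sketch of the same rule (illustrative names, not the library's API):

from typing import Dict, List, Optional


class SketchModel:
    def __init__(self, **kwargs):
        # setdefault() only fills missing keys, so caller-supplied values survive
        kwargs.setdefault("temperature", 0.5)
        kwargs.setdefault("max_tokens", 4096)
        self.kwargs = kwargs

    def prepare(self, stop_sequences: Optional[List[str]] = None, **kwargs) -> Dict:
        merged = {**self.kwargs}             # 3. base defaults (lowest priority)
        if stop_sequences is not None:
            merged["stop"] = stop_sequences  # 2. specific parameters
        merged.update(kwargs)                # 1. explicit kwargs, applied last, win
        return merged


print(SketchModel(temperature=0.2).prepare(stop_sequences=["\n"], temperature=0.9))
# {'temperature': 0.9, 'max_tokens': 4096, 'stop': ['\n']}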
@@ -211,6 +261,8 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
 
@@ -221,8 +273,13 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
+            tools_to_call_from (`List[Tool]`, *optional*):
+                A list of tools that the model can use to generate responses.
+            **kwargs:
+                Additional keyword arguments to be passed to the underlying model.
+
         Returns:
-            `str`: The text content of the model's response.
+            `ChatMessage`: A chat message object containing the model's response.
         """
         pass  # To be implemented in child classes!
 
@@ -265,16 +322,13 @@ class HfApiModel(Model):
         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
-        temperature: float = 0.5,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
-        self.temperature = temperature
-        self.kwargs = kwargs
 
     def __call__(
         self,
@@ -282,29 +336,18 @@ class HfApiModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        """
-        Gets an LLM output message for the given list of input messages.
-        If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
-        """
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="auto",
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                temperature=self.temperature,
-                **self.kwargs,
-            )
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        response = self.client.chat_completion(**completion_kwargs)
+
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
         message = ChatMessage.from_hf_api(response.choices[0].message)
@@ -358,7 +401,7 @@ class TransformersModel(Model):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if not is_torch_available() or not _is_package_available("transformers"):
             raise ModuleNotFoundError(
                 "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
@@ -418,12 +461,36 @@ class TransformersModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+
+        messages = completion_kwargs.pop("messages")
+        stop_sequences = completion_kwargs.pop("stop", None)
+
+        max_new_tokens = (
+            kwargs.get("max_new_tokens")
+            or kwargs.get("max_tokens")
+            or self.kwargs.get("max_new_tokens")
+            or self.kwargs.get("max_tokens")
+        )
+
+        if max_new_tokens:
+            completion_kwargs["max_new_tokens"] = max_new_tokens
+
+        if stop_sequences:
+            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
+
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
+                tools=completion_kwargs.pop("tools", []),
                 return_tensors="pt",
                 return_dict=True,
                 add_generation_prompt=True,
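The `max_new_tokens` resolution above is an `or`-fallback chain: call-site values beat constructor values, and `max_new_tokens` beats `max_tokens` at each level (note that `or` also skips falsy values such as 0). A worked example, with plain dicts standing in for the call-site kwargs and `self.kwargs`:

call_kwargs = {"max_tokens": 256}
base_kwargs = {"max_new_tokens": 1024, "max_tokens": 4096}

max_new_tokens = (
    call_kwargs.get("max_new_tokens")    # None -> keep looking
    or call_kwargs.get("max_tokens")     # 256 -> wins here
    or base_kwargs.get("max_new_tokens")
    or base_kwargs.get("max_tokens")
)
assert max_new_tokens == 256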
@@ -434,14 +501,11 @@ class TransformersModel(Model):
                 return_tensors="pt",
                 return_dict=True,
             )
+
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
 
-        out = self.model.generate(
-            **prompt_tensor,
-            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
-            **self.kwargs,
-        )
+        out = self.model.generate(**prompt_tensor, **completion_kwargs)
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
@@ -449,6 +513,7 @@ class TransformersModel(Model):
 
         if stop_sequences is not None:
             output = remove_stop_sequences(output, stop_sequences)
+
         if tools_to_call_from is None:
             return ChatMessage(role="assistant", content=output)
         else:
@@ -498,13 +563,12 @@ class LiteLLMModel(Model):
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )
 
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
         self.api_base = api_base
         self.api_key = api_key
-        self.kwargs = kwargs
 
     def __call__(
         self,
@@ -512,34 +576,28 @@ class LiteLLMModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
         import litellm
 
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            api_base=self.api_base,
+            api_key=self.api_key,
+            **kwargs,
+        )
+
+        response = litellm.completion(**completion_kwargs)
 
-        if tools_to_call_from:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
-        else:
-            response = litellm.completion(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                **self.kwargs,
-            )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
+
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
@@ -576,13 +634,13 @@ class OpenAIServerModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
             ) from None
-        super().__init__()
+
+        super().__init__(**kwargs)
         self.model_id = model_id
         self.client = openai.OpenAI(
             base_url=api_base,
             api_key=api_key,
         )
-        self.kwargs = kwargs
         self.custom_role_conversions = custom_role_conversions
 
     def __call__(
@@ -591,30 +649,23 @@ class OpenAIServerModel(Model):
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
     ) -> ChatMessage:
-        messages = get_clean_message_list(
-            messages,
-            role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            model=self.model_id,
+            custom_role_conversions=self.custom_role_conversions,
+            **kwargs,
         )
-        if tools_to_call_from:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                tools=[get_json_schema(tool) for tool in tools_to_call_from],
-                tool_choice="required",
-                stop=stop_sequences,
-                **self.kwargs,
-            )
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model_id,
-                messages=messages,
-                stop=stop_sequences,
-                **self.kwargs,
-            )
+
+        response = self.client.chat.completions.create(**completion_kwargs)
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = response.choices[0].message
+
+        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
         if tools_to_call_from is not None:
            return parse_tool_args_if_needed(message)
         return message
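A note on the last two hunks: LiteLLMModel and OpenAIServerModel now rebuild the provider's message object as the library's own ChatMessage, and `model_dump(include=...)` keeps only the fields ChatMessage expects, so provider-specific extras cannot leak into the constructor. A sketch of that filtering (assumes a pydantic-v2-style response message; `ProviderMessage` is hypothetical, standing in for the SDK's message class):

from typing import List, Optional

from pydantic import BaseModel


class ProviderMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    provider_extra: Optional[str] = None  # a field ChatMessage doesn't know about


raw = ProviderMessage(role="assistant", content="hi", provider_extra="trace-id")
payload = raw.model_dump(include={"role", "content", "tool_calls"})
assert payload == {"role": "assistant", "content": "hi", "tool_calls": None}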