refactor(models): restructure model parameter handling (#227)
* refactor(models): restructure model parameter handling - Introduce base-class level default parameters (temperature, max_tokens) - Optimize parameter handling: method args can override base config - Unify parameter handling across model implementations
This commit is contained in:
		
							parent
							
								
									117014d2e1
								
							
						
					
					
						commit
						398c932250
					
				|  | @ -196,9 +196,59 @@ def get_clean_message_list( | |||
| 
 | ||||
| 
 | ||||
| class Model: | ||||
|     def __init__(self): | ||||
|     def __init__(self, **kwargs): | ||||
|         self.last_input_token_count = None | ||||
|         self.last_output_token_count = None | ||||
|         # Set default values for common parameters | ||||
|         kwargs.setdefault("temperature", 0.5) | ||||
|         kwargs.setdefault("max_tokens", 4096) | ||||
|         self.kwargs = kwargs | ||||
| 
 | ||||
|     def _prepare_completion_kwargs( | ||||
|         self, | ||||
|         messages: List[Dict[str, str]], | ||||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         custom_role_conversions: Optional[Dict[str, str]] = None, | ||||
|         **kwargs, | ||||
|     ) -> Dict: | ||||
|         """ | ||||
|         Prepare parameters required for model invocation, handling parameter priorities. | ||||
| 
 | ||||
|         Parameter priority from high to low: | ||||
|         1. Explicitly passed kwargs | ||||
|         2. Specific parameters (stop_sequences, grammar, etc.) | ||||
|         3. Default values in self.kwargs | ||||
|         """ | ||||
|         # Clean and standardize the message list | ||||
|         messages = get_clean_message_list(messages, role_conversions=custom_role_conversions or tool_role_conversions) | ||||
| 
 | ||||
|         # Use self.kwargs as the base configuration | ||||
|         completion_kwargs = { | ||||
|             **self.kwargs, | ||||
|             "messages": messages, | ||||
|         } | ||||
| 
 | ||||
|         # Handle specific parameters | ||||
|         if stop_sequences is not None: | ||||
|             completion_kwargs["stop"] = stop_sequences | ||||
|         if grammar is not None: | ||||
|             completion_kwargs["grammar"] = grammar | ||||
| 
 | ||||
|         # Handle tools parameter | ||||
|         if tools_to_call_from: | ||||
|             completion_kwargs.update( | ||||
|                 { | ||||
|                     "tools": [get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                     "tool_choice": "required", | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|         # Finally, use the passed-in kwargs to override all settings | ||||
|         completion_kwargs.update(kwargs) | ||||
| 
 | ||||
|         return completion_kwargs | ||||
| 
 | ||||
|     def get_token_counts(self) -> Dict[str, int]: | ||||
|         return { | ||||
|  | @ -211,6 +261,8 @@ class Model: | |||
|         messages: List[Dict[str, str]], | ||||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         **kwargs, | ||||
|     ) -> ChatMessage: | ||||
|         """Process the input messages and return the model's response. | ||||
| 
 | ||||
|  | @ -221,8 +273,13 @@ class Model: | |||
|                 A list of strings that will stop the generation if encountered in the model's output. | ||||
|             grammar (`str`, *optional*): | ||||
|                 The grammar or formatting structure to use in the model's response. | ||||
|             tools_to_call_from (`List[Tool]`, *optional*): | ||||
|                 A list of tools that the model can use to generate responses. | ||||
|             **kwargs: | ||||
|                 Additional keyword arguments to be passed to the underlying model. | ||||
| 
 | ||||
|         Returns: | ||||
|             `str`: The text content of the model's response. | ||||
|             `ChatMessage`: A chat message object containing the model's response. | ||||
|         """ | ||||
|         pass  # To be implemented in child classes! | ||||
| 
 | ||||
|  | @ -265,16 +322,13 @@ class HfApiModel(Model): | |||
|         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct", | ||||
|         token: Optional[str] = None, | ||||
|         timeout: Optional[int] = 120, | ||||
|         temperature: float = 0.5, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         super().__init__(**kwargs) | ||||
|         self.model_id = model_id | ||||
|         if token is None: | ||||
|             token = os.getenv("HF_TOKEN") | ||||
|         self.client = InferenceClient(self.model_id, token=token, timeout=timeout) | ||||
|         self.temperature = temperature | ||||
|         self.kwargs = kwargs | ||||
| 
 | ||||
|     def __call__( | ||||
|         self, | ||||
|  | @ -282,29 +336,18 @@ class HfApiModel(Model): | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         **kwargs, | ||||
|     ) -> ChatMessage: | ||||
|         """ | ||||
|         Gets an LLM output message for the given list of input messages. | ||||
|         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. | ||||
|         """ | ||||
|         messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) | ||||
|         if tools_to_call_from: | ||||
|             response = self.client.chat.completions.create( | ||||
|                 messages=messages, | ||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 tool_choice="auto", | ||||
|                 stop=stop_sequences, | ||||
|                 temperature=self.temperature, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|         else: | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 stop=stop_sequences, | ||||
|                 temperature=self.temperature, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|         completion_kwargs = self._prepare_completion_kwargs( | ||||
|             messages=messages, | ||||
|             stop_sequences=stop_sequences, | ||||
|             grammar=grammar, | ||||
|             tools_to_call_from=tools_to_call_from, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
|         response = self.client.chat_completion(**completion_kwargs) | ||||
| 
 | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         message = ChatMessage.from_hf_api(response.choices[0].message) | ||||
|  | @ -358,7 +401,7 @@ class TransformersModel(Model): | |||
|         trust_remote_code: bool = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         super().__init__(**kwargs) | ||||
|         if not is_torch_available() or not _is_package_available("transformers"): | ||||
|             raise ModuleNotFoundError( | ||||
|                 "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`" | ||||
|  | @ -418,12 +461,36 @@ class TransformersModel(Model): | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         **kwargs, | ||||
|     ) -> ChatMessage: | ||||
|         messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) | ||||
|         completion_kwargs = self._prepare_completion_kwargs( | ||||
|             messages=messages, | ||||
|             stop_sequences=stop_sequences, | ||||
|             grammar=grammar, | ||||
|             tools_to_call_from=tools_to_call_from, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
|         messages = completion_kwargs.pop("messages") | ||||
|         stop_sequences = completion_kwargs.pop("stop", None) | ||||
| 
 | ||||
|         max_new_tokens = ( | ||||
|             kwargs.get("max_new_tokens") | ||||
|             or kwargs.get("max_tokens") | ||||
|             or self.kwargs.get("max_new_tokens") | ||||
|             or self.kwargs.get("max_tokens") | ||||
|         ) | ||||
| 
 | ||||
|         if max_new_tokens: | ||||
|             completion_kwargs["max_new_tokens"] = max_new_tokens | ||||
| 
 | ||||
|         if stop_sequences: | ||||
|             completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences) | ||||
| 
 | ||||
|         if tools_to_call_from is not None: | ||||
|             prompt_tensor = self.tokenizer.apply_chat_template( | ||||
|                 messages, | ||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 tools=completion_kwargs.pop("tools", []), | ||||
|                 return_tensors="pt", | ||||
|                 return_dict=True, | ||||
|                 add_generation_prompt=True, | ||||
|  | @ -434,14 +501,11 @@ class TransformersModel(Model): | |||
|                 return_tensors="pt", | ||||
|                 return_dict=True, | ||||
|             ) | ||||
| 
 | ||||
|         prompt_tensor = prompt_tensor.to(self.model.device) | ||||
|         count_prompt_tokens = prompt_tensor["input_ids"].shape[1] | ||||
| 
 | ||||
|         out = self.model.generate( | ||||
|             **prompt_tensor, | ||||
|             stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None), | ||||
|             **self.kwargs, | ||||
|         ) | ||||
|         out = self.model.generate(**prompt_tensor, **completion_kwargs) | ||||
|         generated_tokens = out[0, count_prompt_tokens:] | ||||
|         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) | ||||
|         self.last_input_token_count = count_prompt_tokens | ||||
|  | @ -449,6 +513,7 @@ class TransformersModel(Model): | |||
| 
 | ||||
|         if stop_sequences is not None: | ||||
|             output = remove_stop_sequences(output, stop_sequences) | ||||
| 
 | ||||
|         if tools_to_call_from is None: | ||||
|             return ChatMessage(role="assistant", content=output) | ||||
|         else: | ||||
|  | @ -498,13 +563,12 @@ class LiteLLMModel(Model): | |||
|                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" | ||||
|             ) | ||||
| 
 | ||||
|         super().__init__() | ||||
|         super().__init__(**kwargs) | ||||
|         self.model_id = model_id | ||||
|         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs | ||||
|         litellm.add_function_to_prompt = True | ||||
|         self.api_base = api_base | ||||
|         self.api_key = api_key | ||||
|         self.kwargs = kwargs | ||||
| 
 | ||||
|     def __call__( | ||||
|         self, | ||||
|  | @ -512,34 +576,28 @@ class LiteLLMModel(Model): | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         **kwargs, | ||||
|     ) -> ChatMessage: | ||||
|         import litellm | ||||
| 
 | ||||
|         messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) | ||||
|         completion_kwargs = self._prepare_completion_kwargs( | ||||
|             messages=messages, | ||||
|             stop_sequences=stop_sequences, | ||||
|             grammar=grammar, | ||||
|             tools_to_call_from=tools_to_call_from, | ||||
|             model=self.model_id, | ||||
|             api_base=self.api_base, | ||||
|             api_key=self.api_key, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
|         response = litellm.completion(**completion_kwargs) | ||||
| 
 | ||||
|         if tools_to_call_from: | ||||
|             response = litellm.completion( | ||||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 tool_choice="required", | ||||
|                 stop=stop_sequences, | ||||
|                 api_base=self.api_base, | ||||
|                 api_key=self.api_key, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|         else: | ||||
|             response = litellm.completion( | ||||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 stop=stop_sequences, | ||||
|                 api_base=self.api_base, | ||||
|                 api_key=self.api_key, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         message = response.choices[0].message | ||||
| 
 | ||||
|         message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) | ||||
| 
 | ||||
|         if tools_to_call_from is not None: | ||||
|             return parse_tool_args_if_needed(message) | ||||
|         return message | ||||
|  | @ -576,13 +634,13 @@ class OpenAIServerModel(Model): | |||
|             raise ModuleNotFoundError( | ||||
|                 "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`" | ||||
|             ) from None | ||||
|         super().__init__() | ||||
| 
 | ||||
|         super().__init__(**kwargs) | ||||
|         self.model_id = model_id | ||||
|         self.client = openai.OpenAI( | ||||
|             base_url=api_base, | ||||
|             api_key=api_key, | ||||
|         ) | ||||
|         self.kwargs = kwargs | ||||
|         self.custom_role_conversions = custom_role_conversions | ||||
| 
 | ||||
|     def __call__( | ||||
|  | @ -591,30 +649,23 @@ class OpenAIServerModel(Model): | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|         **kwargs, | ||||
|     ) -> ChatMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, | ||||
|             role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions), | ||||
|         completion_kwargs = self._prepare_completion_kwargs( | ||||
|             messages=messages, | ||||
|             stop_sequences=stop_sequences, | ||||
|             grammar=grammar, | ||||
|             tools_to_call_from=tools_to_call_from, | ||||
|             model=self.model_id, | ||||
|             custom_role_conversions=self.custom_role_conversions, | ||||
|             **kwargs, | ||||
|         ) | ||||
|         if tools_to_call_from: | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 tool_choice="required", | ||||
|                 stop=stop_sequences, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|         else: | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 stop=stop_sequences, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
| 
 | ||||
|         response = self.client.chat.completions.create(**completion_kwargs) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         message = response.choices[0].message | ||||
| 
 | ||||
|         message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) | ||||
|         if tools_to_call_from is not None: | ||||
|             return parse_tool_args_if_needed(message) | ||||
|         return message | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue