Allow passing kwargs to all models (#222)

* Allow passing kwargs to all models

parent a1d8f3c398
commit b4091cb5ce
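Before the hunks, a minimal sketch of the usage pattern this commit introduces, assuming only what the diff below shows (the `HfApiModel` constructor gaining `**kwargs` that are stored and later unpacked as `**self.kwargs`); the token string is a placeholder. Generation options such as `max_tokens` move out of every `__call__` signature and become per-model constructor kwargs:

```python
from smolagents import HfApiModel

# Options are now given once at construction; the model stores them in
# self.kwargs and unpacks them into client.chat.completions.create(...).
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    token="your_hf_token_here",  # placeholder
    max_tokens=1500,             # previously a hard-coded __call__ parameter
)
messages = [{"role": "user", "content": "Hello!"}]
response = model(messages, stop_sequences=["END"])
```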
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection
@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here:")
         return user_input
 
 
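For context on the hunk above: the tool's question and the new suffix are concatenated into a single `input()` prompt. A tiny sketch with a hypothetical question string:

```python
# Hypothetical question string, showing the prompt the user now sees.
question = "Continue? (y/n)"
user_input = input(f"{question} => Type your answer here:")
# Terminal prompt: Continue? (y/n) => Type your answer here:
```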
@@ -222,7 +222,6 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
 
@@ -233,8 +232,6 @@
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
         """
@@ -244,7 +241,7 @@
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
     >>> engine = HfApiModel(
     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ...     token="your_hf_token_here",
+    ...     max_tokens=5000,
     ... )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
@@ -325,16 +324,44 @@
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class to run Hugging Face models locally with the Transformers library.
+
+    This model initializes a model and tokenizer from the given `model_id` and runs generation locally, supporting features like stop sequences and extra generation kwargs.
+
     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with.
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens`.
     Raises:
         ValueError:
             If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ...     model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    ...     device_map="cuda",
+    ...     max_new_tokens=5000,
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                default_model_id, device_map=device_map, torch_dtype=torch_dtype
             )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@
 
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -458,6 +484,19 @@ class TransformersModel(Model):
 
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "anthropic/claude-3-5-sonnet-20240620").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the underlying API call.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -516,7 +552,7 @@
 
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
         self.kwargs = kwargs
 
     def __call__(
@@ -553,7 +585,6 @@ class OpenAIServerModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
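To summarize the model changes above: each backend now stores arbitrary constructor kwargs and unpacks them into its own generation call, so options removed as named parameters (`max_tokens` on `__call__`, `temperature` on `OpenAIServerModel.__init__`) travel through `**kwargs` instead. A hedged sketch, assuming `OpenAIServerModel` is importable from the package top level like the classes in the test file below; the endpoint and key are placeholders:

```python
from smolagents import OpenAIServerModel, TransformersModel

# temperature is no longer a dedicated __init__ parameter; it reaches
# client.chat.completions.create(...) through **self.kwargs.
api_model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",
    api_base="https://api.openai.com/v1",  # placeholder endpoint
    api_key="sk-...",                      # placeholder key
    temperature=0.7,
)

# For TransformersModel the stored kwargs feed model.generate(), replacing
# the old hard-coded max_new_tokens=max_tokens.
local_model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",
    max_new_tokens=256,
)
```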
@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
 
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"
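A closing note on the new test: `do_sample=False` forces greedy decoding, which makes the exact-match string assertion reproducible, and `max_new_tokens=5` keeps generation short and fast; both options reach `model.generate()` through the new kwargs path exercised above.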