Allow passing kwargs to all models (#222)
* Allow passing kwargs to all models
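In short: per-call generation parameters such as `max_tokens` are removed from every model's `__call__` signature; instead, arbitrary keyword arguments are accepted at construction, stored on the instance, and forwarded to the underlying completion call. A minimal sketch of the resulting usage (model id illustrative, taken from the docstring example below):

```python
from smolagents import HfApiModel

# Generation options are now captured once, at construction, and
# forwarded to every underlying completion call via **self.kwargs.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    max_tokens=1500,  # previously a fixed __call__ parameter
)
messages = [{"role": "user", "content": "Hello!"}]
response = model(messages, stop_sequences=["END"])
```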
parent a1d8f3c398
commit b4091cb5ce
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
 
 [[autodoc]] VisitWebpageTool
 
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
 ## ToolCollection
 
 [[autodoc]] ToolCollection

@@ -144,7 +144,7 @@ class UserInputTool(Tool):
     output_type = "string"
 
     def forward(self, question):
-        user_input = input(f"{question} => ")
+        user_input = input(f"{question} => Type your answer here: ")
         return user_input
 
 
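As a side note on the tool change above, the stdin prompt now carries an explicit instruction. A minimal sketch of the behavior, assuming `UserInputTool` is exported at package level as the docs change above suggests (question string illustrative):

```python
from smolagents import UserInputTool

tool = UserInputTool()
# Prints "Which city? => Type your answer here: " and blocks until
# the user answers on stdin; the raw string is returned.
answer = tool.forward("Which city?")
```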

@@ -222,7 +222,6 @@ class Model:
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
     ) -> ChatMessage:
         """Process the input messages and return the model's response.
 
@@ -233,8 +232,6 @@ class Model:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-            max_tokens (`int`, *optional*):
-                The maximum count of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
         """
@@ -244,7 +241,7 @@ class HfApiModel(Model):
 class HfApiModel(Model):
     """A class to interact with Hugging Face's Inference API for language model interaction.
 
-    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+    This model allows you to communicate with Hugging Face's models using the Inference API. It can be used either in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -265,9 +262,10 @@ class HfApiModel(Model):
     >>> engine = HfApiModel(
     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
     ...     token="your_hf_token_here",
+    ...     max_tokens=5000,
     ... )
     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-    >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+    >>> response = engine(messages, stop_sequences=["END"])
     >>> print(response)
     "Quantum mechanics is the branch of physics that studies..."
     ```
@@ -279,6 +277,7 @@ class HfApiModel(Model):
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
         temperature: float = 0.5,
+        **kwargs,
     ):
         super().__init__()
         self.model_id = model_id
@@ -286,13 +285,13 @@ class HfApiModel(Model):
             token = os.getenv("HF_TOKEN")
         self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
         self.temperature = temperature
+        self.kwargs = kwargs
 
     def __call__(
         self,
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         """
@@ -308,16 +307,16 @@ class HfApiModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         else:
             response = self.client.chat.completions.create(
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 temperature=self.temperature,
+                **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
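One consequence of the hunks above, assuming the stored kwargs stay fixed for the instance's lifetime: options passed at construction apply to every call, and the per-call `max_tokens` override is gone. A sketch:

```python
from smolagents import HfApiModel

# These kwargs are stored as self.kwargs and merged into every
# chat.completions.create() call this instance makes.
model = HfApiModel(max_tokens=1500)

# To use different generation settings, create a second instance
# rather than passing overrides per call.
short_model = HfApiModel(max_tokens=64)
```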
@@ -325,16 +324,44 @@ class HfApiModel(Model):
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`.
+    """A class to run a local Transformers model for language model interaction.
+
+    This model initializes a model and tokenizer from the given `model_id` and runs inference locally with `transformers`, supporting features like stop sequences and forwarding any extra keyword arguments to `model.generate()`.
 
     Parameters:
         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
-        device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
-            The device to load the model on (`"cpu"` or `"cuda"`).
+        device_map (`str`, *optional*):
+            The device_map to initialize your model with.
+        torch_dtype (`str`, *optional*):
+            The torch_dtype to initialize your model with.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to pass to `model.generate()`, for instance `max_new_tokens` or `do_sample`.
+    Raises:
+        ValueError:
+            If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = TransformersModel(
+    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ...     device_map="cuda",
+    ...     max_new_tokens=5000,
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
-    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device_map: Optional[str] = None,
+        torch_dtype: Optional[str] = None,
+        **kwargs,
+    ):
         super().__init__()
         if not is_torch_available():
             raise ImportError("Please install torch in order to use TransformersModel.")
@@ -347,14 +374,14 @@ class TransformersModel(Model):
                 f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        logger.info(f"Using device: {self.device}")
+        self.kwargs = kwargs
+        if device_map is None:
+            device_map = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device_map}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                model_id, device_map=device_map, torch_dtype=torch_dtype
             )
         except Exception as e:
             logger.warning(
@@ -363,7 +390,7 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=self.device
+                default_model_id, device_map=device_map, torch_dtype=torch_dtype
             )
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
@@ -397,7 +424,6 @@ class TransformersModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -422,10 +448,10 @@ class TransformersModel(Model):
 
         out = self.model.generate(
             **prompt_tensor,
-            max_new_tokens=max_tokens,
             stopping_criteria=(
                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None
             ),
+            **self.kwargs,
         )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
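For `TransformersModel`, the stored kwargs go to `model.generate()`, so generation length and sampling are now controlled at construction. A sketch mirroring the new test added further below:

```python
from smolagents import TransformersModel

# Extra kwargs (max_new_tokens, do_sample, ...) are forwarded
# verbatim to model.generate(); device_map and torch_dtype go to
# AutoModelForCausalLM.from_pretrained().
model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
    device_map="auto",
    max_new_tokens=256,
    do_sample=False,
)
response = model([{"role": "user", "content": "Hello!"}])
```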
@@ -458,6 +484,19 @@ class TransformersModel(Model):
 
 
 class LiteLLMModel(Model):
+    """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+    Parameters:
+        model_id (`str`, *optional*, defaults to `"anthropic/claude-3-5-sonnet-20240620"`):
+            The provider-prefixed model identifier to use with LiteLLM (e.g. "anthropic/claude-3-5-sonnet-20240620").
+        api_base (`str`, *optional*):
+            The base URL of the API server to route requests to.
+        api_key (`str`, *optional*):
+            The API key to use for authentication.
+        **kwargs:
+            Additional keyword arguments to pass to the LiteLLM completion call.
+    """
+
     def __init__(
         self,
         model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -482,7 +521,6 @@ class LiteLLMModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -495,7 +533,6 @@ class LiteLLMModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="required",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
@@ -505,7 +542,6 @@ class LiteLLMModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
                 api_base=self.api_base,
                 api_key=self.api_key,
                 **self.kwargs,
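`LiteLLMModel` follows the same pattern: its stored kwargs are merged into each `litellm.completion()` call. A sketch under those assumptions (API key handling omitted):

```python
from smolagents import LiteLLMModel

# model_id is provider-prefixed; max_tokens now travels via **kwargs
# instead of a dedicated __call__ parameter.
model = LiteLLMModel(
    model_id="anthropic/claude-3-5-sonnet-20240620",
    max_tokens=1024,
)
response = model([{"role": "user", "content": "Hello!"}])
```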
@@ -516,7 +552,7 @@ class LiteLLMModel(Model):
 
 
 class OpenAIServerModel(Model):
-    """This engine connects to an OpenAI-compatible API server.
+    """This model connects to an OpenAI-compatible API server.
 
     Parameters:
         model_id (`str`):
@@ -525,8 +561,6 @@ class OpenAIServerModel(Model):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`):
             The API key to use for authentication.
-        temperature (`float`, *optional*, defaults to 0.7):
-            Controls randomness in the model's responses. Values between 0 and 2.
         **kwargs:
             Additional keyword arguments to pass to the OpenAI API.
     """
@@ -536,7 +570,6 @@ class OpenAIServerModel(Model):
         model_id: str,
         api_base: str,
         api_key: str,
-        temperature: float = 0.7,
         **kwargs,
     ):
         super().__init__()
@@ -545,7 +578,6 @@ class OpenAIServerModel(Model):
             base_url=api_base,
             api_key=api_key,
         )
-        self.temperature = temperature
         self.kwargs = kwargs
 
     def __call__(
@@ -553,7 +585,6 @@ class OpenAIServerModel(Model):
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
         messages = get_clean_message_list(
@@ -566,8 +597,6 @@ class OpenAIServerModel(Model):
                 tools=[get_json_schema(tool) for tool in tools_to_call_from],
                 tool_choice="auto",
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         else:
@@ -575,8 +604,6 @@ class OpenAIServerModel(Model):
                 model=self.model_id,
                 messages=messages,
                 stop=stop_sequences,
-                max_tokens=max_tokens,
-                temperature=self.temperature,
                 **self.kwargs,
             )
         self.last_input_token_count = response.usage.prompt_tokens
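For `OpenAIServerModel`, `temperature` loses its dedicated parameter (and its implicit 0.7 default) and becomes just another kwarg. A sketch, with an illustrative endpoint, key, and model id:

```python
from smolagents import OpenAIServerModel

model = OpenAIServerModel(
    model_id="gpt-4o",                     # illustrative
    api_base="http://localhost:8000/v1",   # illustrative
    api_key="YOUR_API_KEY",                # illustrative
    temperature=0.7,  # set explicitly now; no implicit default
    max_tokens=1500,
)
response = model([{"role": "user", "content": "Hello!"}])
```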

@@ -16,7 +16,7 @@ import unittest
 import json
 from typing import Optional
 
-from smolagents import models, tool, ChatMessage, HfApiModel
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
 
 
 class ModelTests(unittest.TestCase):
@@ -46,6 +46,17 @@ class ModelTests(unittest.TestCase):
         assert data["content"] == "Hello!"
 
     def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel()
+        model = HfApiModel(max_tokens=10)
         messages = [{"role": "user", "content": "Hello!"}]
         model(messages, stop_sequences=["great"])
+
+    def test_transformers_message_no_tool(self):
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+            max_new_tokens=5,
+            device_map="auto",
+            do_sample=False,
+        )
+        messages = [{"role": "user", "content": "Hello!"}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output == "assistant\nHello"