TransformersModel auto-detects VLMs (#378)
parent a5290590c8
commit 4579a6f7cc
@@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
 available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
 chosen_inference = "transformers"
 
-print(f"Chose model {chosen_inference}")
+print(f"Chose model: '{chosen_inference}'")
 
 if chosen_inference == "hf_api":
     model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
@@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
             tool_arguments = tool_call.function.arguments
 
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
+            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
 
         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
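A quick note on the `from e` added above: chaining keeps the original exception as `__cause__`, so the traceback shows both the low-level failure and the agent-level error instead of swallowing the first. A minimal illustration (ours, not from the PR):

def generate_tool_call(payload: str):
    try:
        return int(payload)  # stand-in for the real model call and parsing
    except Exception as e:
        # Without `from e`, the error below would be reported as raised
        # "during handling of another exception" rather than as its cause.
        raise RuntimeError(f"Error in generating tool call with model:\n{e}") from e

try:
    generate_tool_call("not-a-number")
except RuntimeError as err:
    assert isinstance(err.__cause__, ValueError)  # original error is kept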
@@ -418,9 +418,6 @@ class TransformersModel(Model):
             The torch_dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
-        flatten_messages_as_text (`bool`, default `True`):
-            Whether to flatten messages as text: this must be sent to False to use VLMs (as opposed to LLMs for which this flag can be ignored).
-            Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
         **kwargs:
@@ -449,7 +446,6 @@ class TransformersModel(Model):
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
-        flatten_messages_as_text: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -469,6 +465,7 @@ class TransformersModel(Model):
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
+        self._is_vlm = False
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -481,6 +478,7 @@ class TransformersModel(Model):
             if "Unrecognized configuration class" in str(e):
                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
                 self.processor = AutoProcessor.from_pretrained(model_id)
+                self._is_vlm = True
             else:
                 raise e
         except Exception as e:
@@ -490,7 +488,6 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
-        self.flatten_messages_as_text = flatten_messages_as_text
 
     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
         from transformers import StoppingCriteria, StoppingCriteriaList
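Taken together, the constructor hunks above implement the auto-detection: try the text-only auto class first, and treat an "Unrecognized configuration class" failure as the signal that the checkpoint is a VLM. A standalone sketch of the same pattern (function name and defaults are ours, not the library's):

from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

def load_text_or_vision_model(model_id: str, device_map: str = "cpu"):
    # Text-only checkpoints load fine with the causal-LM auto class.
    is_vlm = False
    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except ValueError as e:
        # VLM configs are not registered for AutoModelForCausalLM, so they
        # fail with "Unrecognized configuration class"; fall back to the
        # image-text-to-text auto class and its processor.
        if "Unrecognized configuration class" not in str(e):
            raise
        model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
        tokenizer = AutoProcessor.from_pretrained(model_id)
        is_vlm = True
    return model, tokenizer, is_vlm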
@@ -526,8 +523,7 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
-            flatten_messages_as_text=self.flatten_messages_as_text,
+            flatten_messages_as_text=(not self._is_vlm),
             **kwargs,
         )
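With the flag now internal, `flatten_messages_as_text` is simply the negation of the detection result: text-only chat templates want plain string content, while VLM processors need the structured list so image parts can be paired with the text. A rough sketch of what flattening means (helper name hypothetical; smolagents does this in its message-cleaning step):

def flatten_message(message: dict) -> dict:
    # Collapse [{"type": "text", "text": ...}, ...] content into one string,
    # the shape that text-only tokenizer chat templates expect.
    content = message["content"]
    if isinstance(content, list):
        content = "\n".join(part["text"] for part in content if part["type"] == "text")
    return {"role": message["role"], "content": content}

msg = {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
assert flatten_message(msg) == {"role": "user", "content": "Hello!"}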
@@ -595,9 +591,19 @@ class TransformersModel(Model):
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
-            parsed_output = json.loads(output)
-            tool_name = parsed_output.get("tool_name")
-            tool_arguments = parsed_output.get("tool_arguments")
+            try:
+                start_index = output.index("{")
+                end_index = output.rindex("}")
+                output = output[start_index : end_index + 1]
+            except Exception as e:
+                raise Exception("No json blob found in output!") from e
+
+            try:
+                parsed_output = json.loads(output)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
+            tool_name = parsed_output.get("name")
+            tool_arguments = parsed_output.get("arguments")
             return ChatMessage(
                 role="assistant",
                 content="",
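The reworked parsing above no longer assumes the model emits bare JSON: it slices from the first "{" to the last "}" before decoding, and the expected keys are now `name`/`arguments` rather than `tool_name`/`tool_arguments`. Roughly, on an invented sample output:

import json

raw = 'Action: calling the tool now {"name": "get_weather", "arguments": {"location": "Paris"}} hope that helps!'
output = raw.split("Action:", 1)[1].strip()
# Keep only the outermost JSON blob, tolerating prose on either side.
output = output[output.index("{") : output.rindex("}") + 1]
parsed = json.loads(output)
assert parsed["name"] == "get_weather"
assert parsed["arguments"] == {"location": "Paris"}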
@@ -28,11 +28,7 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.models import (
-    ChatMessage,
-    ChatMessageToolCall,
-    ChatMessageToolCallDefinition,
-)
+from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
@@ -620,3 +616,26 @@ nested_answer()
 
         output = agent.run("Count to 3")
         assert output == "Correct!"
+
+    def test_transformers_toolcalling_agent(self):
+        @tool
+        def get_weather(location: str, celsius: bool = False) -> str:
+            """
+            Get weather in the next days at given location.
+            Secretly this tool does not care about the location, it hates the weather everywhere.
+
+            Args:
+                location: the location
+                celsius: the temperature type
+            """
+            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
+            max_new_tokens=100,
+            device_map="auto",
+            do_sample=False,
+        )
+        agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
+        agent.run("What's the weather in Paris?")
+        assert agent.logs[2].tool_calls[0].name == "get_weather"
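Two assumptions are baked into this new test: `@tool` derives the tool's schema from the function's type hints and the `Args:` section of its docstring (hence the strict docstring format), and `agent.logs[2]` presumes the log list opens with the system-prompt and task entries, making index 2 the first action step.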
@@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=True,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         output = model(messages, stop_sequences=["great"]).content
@@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=False,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
         output = model(messages, stop_sequences=["great"]).content
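After these test updates, callers never pass the flag; detection decides. A usage sketch mirroring the test parameters (the model ID is assumed here, since the surrounding test code is cut off):

from smolagents.models import TransformersModel

# A small text-only checkpoint: detection leaves _is_vlm False, so the
# structured message content is flattened to a plain string internally.
model = TransformersModel(
    model_id="HuggingFaceTB/SmolLM2-135M-Instruct",  # assumed checkpoint
    max_new_tokens=5,
    device_map="auto",
    do_sample=False,
)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
print(model(messages, stop_sequences=["great"]).content)
# A VLM checkpoint would instead flip _is_vlm, keeping the list-form content
# so that {"type": "image", ...} parts reach the processor.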