TransformersModel auto-detects VLMs (#378)
* TransformersModel auto-detects VLMs
parent a5290590c8, commit 4579a6f7cc
@@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
 available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
 chosen_inference = "transformers"
 
-print(f"Chose model {chosen_inference}")
+print(f"Chose model: '{chosen_inference}'")
 
 if chosen_inference == "hf_api":
     model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
@@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
             tool_arguments = tool_call.function.arguments
 
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
+            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
 
         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
 
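
Adding `from e` chains the original model error onto the re-raised `AgentGenerationError` instead of masking it. A minimal sketch of the effect, using a local stand-in for `AgentGenerationError` (the real one also takes a logger argument):

```python
# Sketch: `raise ... from e` preserves the root failure as __cause__,
# so the underlying error stays visible in tracebacks and agent logs.
class AgentGenerationError(Exception):
    pass

try:
    try:
        raise TimeoutError("model endpoint did not respond")
    except Exception as e:
        raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") from e
except AgentGenerationError as err:
    assert isinstance(err.__cause__, TimeoutError)  # chained cause survives
```
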
@@ -418,9 +418,6 @@ class TransformersModel(Model):
             The torch_dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
-        flatten_messages_as_text (`bool`, default `True`):
-            Whether to flatten messages as text: this must be set to False to use VLMs (as opposed to LLMs, for which this flag can be ignored).
-            Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
         **kwargs:
@@ -449,7 +446,6 @@ class TransformersModel(Model):
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
-        flatten_messages_as_text: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -469,6 +465,7 @@ class TransformersModel(Model):
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
+        self._is_vlm = False
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -481,6 +478,7 @@ class TransformersModel(Model):
             if "Unrecognized configuration class" in str(e):
                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
                 self.processor = AutoProcessor.from_pretrained(model_id)
+                self._is_vlm = True
             else:
                 raise e
         except Exception as e:
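
This existing fallback is what makes auto-detection possible: `transformers` raises a `ValueError` whose message contains "Unrecognized configuration class" when a VLM checkpoint is loaded through `AutoModelForCausalLM`, so the new `_is_vlm` flag can simply be set in that branch. A standalone sketch of the same pattern (the helper name and structure are illustrative, not smolagents' API):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

def load_text_or_vision_model(model_id: str, device_map: str = "cpu"):
    """Return (model, tokenizer_or_processor, is_vlm) for any supported checkpoint."""
    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
        return model, AutoTokenizer.from_pretrained(model_id), False
    except ValueError as e:
        # VLM checkpoints fail AutoModelForCausalLM with exactly this message,
        # which is what TransformersModel keys on above.
        if "Unrecognized configuration class" not in str(e):
            raise
        model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
        return model, AutoProcessor.from_pretrained(model_id), True
```
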
@@ -490,7 +488,6 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
-        self.flatten_messages_as_text = flatten_messages_as_text
 
     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
         from transformers import StoppingCriteria, StoppingCriteriaList
@@ -526,8 +523,7 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
-            flatten_messages_as_text=self.flatten_messages_as_text,
+            flatten_messages_as_text=(not self._is_vlm),
             **kwargs,
         )
 
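
With the user-facing flag gone, flattening now follows the detected model type: text-only models get their chat messages collapsed to plain strings, while VLMs keep the structured content lists so image parts survive. A hedged sketch of what "flattening messages as text" amounts to (the helper name is illustrative):

```python
def flatten_messages_as_text(messages):
    """Collapse each message's structured content list into one plain string."""
    flattened = []
    for message in messages:
        text = "".join(part["text"] for part in message["content"] if part["type"] == "text")
        flattened.append({"role": message["role"], "content": text})
    return flattened

messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
assert flatten_messages_as_text(messages) == [{"role": "user", "content": "Hello!"}]
```
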
@@ -595,9 +591,19 @@ class TransformersModel(Model):
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
+            try:
+                start_index = output.index("{")
+                end_index = output.rindex("}")
+                output = output[start_index : end_index + 1]
+            except Exception as e:
+                raise Exception("No json blob found in output!") from e
+
-            parsed_output = json.loads(output)
-            tool_name = parsed_output.get("tool_name")
-            tool_arguments = parsed_output.get("tool_arguments")
+            try:
+                parsed_output = json.loads(output)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
+            tool_name = parsed_output.get("name")
+            tool_arguments = parsed_output.get("arguments")
             return ChatMessage(
                 role="assistant",
                 content="",
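
The brace slicing makes the parser tolerant of chatter around the JSON blob, and the expected keys move from `tool_name`/`tool_arguments` to the more standard `name`/`arguments`. A worked example of the same steps on a synthetic completion:

```python
import json

# A noisy completion of the kind a small model might produce.
raw = 'I will check the weather.\nAction: {"name": "get_weather", "arguments": {"location": "Paris"}}\nThat should work.'
output = raw.split("Action:", 1)[1].strip()
start_index = output.index("{")   # first opening brace
end_index = output.rindex("}")    # last closing brace
blob = output[start_index : end_index + 1]
parsed = json.loads(blob)
assert parsed["name"] == "get_weather"
assert parsed["arguments"] == {"location": "Paris"}
```
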
@@ -28,11 +28,7 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.models import (
-    ChatMessage,
-    ChatMessageToolCall,
-    ChatMessageToolCallDefinition,
-)
+from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
@@ -620,3 +616,26 @@ nested_answer()
 
         output = agent.run("Count to 3")
         assert output == "Correct!"
+
+    def test_transformers_toolcalling_agent(self):
+        @tool
+        def get_weather(location: str, celsius: bool = False) -> str:
+            """
+            Get weather in the next days at given location.
+            Secretly this tool does not care about the location, it hates the weather everywhere.
+
+            Args:
+                location: the location
+                celsius: the temperature type
+            """
+            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
+            max_new_tokens=100,
+            device_map="auto",
+            do_sample=False,
+        )
+        agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
+        agent.run("What's the weather in Paris?")
+        assert agent.logs[2].tool_calls[0].name == "get_weather"
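
The `logs[2]` index assumes smolagents' usual log layout (system prompt step, task step, then the first action step). A hedged illustration of what the assertion inspects, using dataclass stand-ins rather than smolagents' real classes:

```python
from dataclasses import dataclass, field

@dataclass
class ToolCall:
    name: str
    arguments: dict

@dataclass
class ActionStep:
    tool_calls: list = field(default_factory=list)

# agent.logs is roughly [system_prompt_step, task_step, first_action_step, ...]
logs = ["system_prompt_step", "task_step",
        ActionStep([ToolCall("get_weather", {"location": "Paris"})])]
assert logs[2].tool_calls[0].name == "get_weather"
```
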
@@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=True,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         output = model(messages, stop_sequences=["great"]).content
@@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=False,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
         output = model(messages, stop_sequences=["great"]).content
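
After this change the LLM and VLM tests construct the model identically; detection is automatic. A usage sketch (the VLM checkpoint id is an example; any image-text-to-text model should behave the same):

```python
from smolagents.models import TransformersModel

# Text-only model: messages are flattened to plain strings internally.
llm = TransformersModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_new_tokens=5)

# VLM: auto-detected, structured content (including images) is preserved.
vlm = TransformersModel(model_id="llava-hf/llava-interleave-qwen-0.5b-hf", max_new_tokens=5)
```
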