From 4579a6f7ccf79f9b347c87510f9ec2a503b3a150 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Mon, 27 Jan 2025 20:09:14 +0100
Subject: [PATCH] TransformersModel auto-detects VLMs (#378)

* TransformersModel auto-detects VLMs
---
 examples/agent_from_any_llm.py |  2 +-
 src/smolagents/agents.py       |  2 +-
 src/smolagents/models.py       | 26 ++++++++++++++++----------
 tests/test_agents.py           | 29 ++++++++++++++++++++++++-----
 tests/test_models.py           |  2 --
 5 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/examples/agent_from_any_llm.py b/examples/agent_from_any_llm.py
index eff667f..eb07991 100644
--- a/examples/agent_from_any_llm.py
+++ b/examples/agent_from_any_llm.py
@@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
 available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
 chosen_inference = "transformers"
 
-print(f"Chose model {chosen_inference}")
+print(f"Chose model: '{chosen_inference}'")
 
 if chosen_inference == "hf_api":
     model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index b7111e8..2ad9af5 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
             tool_arguments = tool_call.function.arguments
 
         except Exception as e:
-            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
+            raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
 
         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
 
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 9a43005..6e31c89 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -418,9 +418,6 @@ class TransformersModel(Model):
             The torch_dtype to initialize your model with.
         trust_remote_code (bool, default `False`):
             Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
-        flatten_messages_as_text (`bool`, default `True`):
-            Whether to flatten messages as text: this must be sent to False to use VLMs (as opposed to LLMs for which this flag can be ignored).
-            Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
         **kwargs:
@@ -449,7 +446,6 @@ class TransformersModel(Model):
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
-        flatten_messages_as_text: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -469,6 +465,7 @@ class TransformersModel(Model):
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
+        self._is_vlm = False
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
@@ -481,6 +478,7 @@ class TransformersModel(Model):
             if "Unrecognized configuration class" in str(e):
                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
                 self.processor = AutoProcessor.from_pretrained(model_id)
+                self._is_vlm = True
             else:
                 raise e
         except Exception as e:
@@ -490,7 +488,6 @@ class TransformersModel(Model):
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
-        self.flatten_messages_as_text = flatten_messages_as_text
 
     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
         from transformers import StoppingCriteria, StoppingCriteriaList
@@ -526,8 +523,7 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
-            flatten_messages_as_text=self.flatten_messages_as_text,
+            flatten_messages_as_text=(not self._is_vlm),
             **kwargs,
         )
 
@@ -595,9 +591,19 @@ class TransformersModel(Model):
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
-            parsed_output = json.loads(output)
-            tool_name = parsed_output.get("tool_name")
-            tool_arguments = parsed_output.get("tool_arguments")
+            try:
+                start_index = output.index("{")
+                end_index = output.rindex("}")
+                output = output[start_index : end_index + 1]
+            except Exception as e:
+                raise Exception("No json blob found in output!") from e
+
+            try:
+                parsed_output = json.loads(output)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
+            tool_name = parsed_output.get("name")
+            tool_arguments = parsed_output.get("arguments")
         return ChatMessage(
             role="assistant",
             content="",
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 1dcb5e9..d4cbda2 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -28,11 +28,7 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.models import (
-    ChatMessage,
-    ChatMessageToolCall,
-    ChatMessageToolCallDefinition,
-)
+from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
 from smolagents.tools import tool
 from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
@@ -620,3 +616,26 @@ nested_answer()
         output = agent.run("Count to 3")
 
         assert output == "Correct!"
+
+    def test_transformers_toolcalling_agent(self):
+        @tool
+        def get_weather(location: str, celsius: bool = False) -> str:
+            """
+            Get weather in the next days at given location.
+            Secretly this tool does not care about the location, it hates the weather everywhere.
+
+            Args:
+                location: the location
+                celsius: the temperature type
+            """
+            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+        model = TransformersModel(
+            model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
+            max_new_tokens=100,
+            device_map="auto",
+            do_sample=False,
+        )
+        agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
+        agent.run("What's the weather in Paris?")
+        assert agent.logs[2].tool_calls[0].name == "get_weather"
diff --git a/tests/test_models.py b/tests/test_models.py
index cd3c96f..1857ccd 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=True,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         output = model(messages, stop_sequences=["great"]).content
@@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
             max_new_tokens=5,
             device_map="auto",
             do_sample=False,
-            flatten_messages_as_text=False,
         )
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
         output = model(messages, stop_sequences=["great"]).content
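
A minimal sketch of the behavior this patch enables, for context. It assumes the top-level `smolagents` export of `TransformersModel`; the model ids are illustrative stand-ins (any causal-LM or image-text-to-text checkpoint should exercise the same paths):

    from smolagents import TransformersModel

    # Plain LLM: AutoModelForCausalLM loads it, so _is_vlm stays False and
    # message content is flattened to plain text before chat templating.
    llm = TransformersModel(model_id="HuggingFaceTB/SmolLM2-360M-Instruct", max_new_tokens=32)

    # VLM: AutoModelForCausalLM fails with "Unrecognized configuration class",
    # the constructor falls back to AutoModelForImageTextToText and sets
    # _is_vlm = True, so structured (text + image) content is kept intact.
    vlm = TransformersModel(model_id="HuggingFaceTB/SmolVLM-Instruct", max_new_tokens=32)

    messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    print(llm(messages).content)

Callers no longer choose the flattening mode themselves: `__call__` derives `flatten_messages_as_text` from the detected model class, which is why the explicit `flatten_messages_as_text` argument is removed from the docstring, the constructor, and both tests.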