TransformersModel auto-detects VLMs (#378)

* TransformersModel auto-detects VLMs
This commit is contained in:
Aymeric Roucher 2025-01-27 20:09:14 +01:00 committed by GitHub
parent a5290590c8
commit 4579a6f7cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 42 additions and 19 deletions

View File

@ -9,7 +9,7 @@ from smolagents.agents import CodeAgent, ToolCallingAgent
available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
chosen_inference = "transformers"
print(f"Chose model {chosen_inference}")
print(f"Chose model: '{chosen_inference}'")
if chosen_inference == "hf_api":
model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")

View File

@ -798,7 +798,7 @@ class ToolCallingAgent(MultiStepAgent):
tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]

View File

@ -418,9 +418,6 @@ class TransformersModel(Model):
The torch_dtype to initialize your model with.
trust_remote_code (bool, default `False`):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
flatten_messages_as_text (`bool`, default `True`):
Whether to flatten messages as text: this must be sent to False to use VLMs (as opposed to LLMs for which this flag can be ignored).
Caution: this parameter is experimental and will be removed in an upcoming PR as we auto-detect VLMs.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
**kwargs:
@ -449,7 +446,6 @@ class TransformersModel(Model):
device_map: Optional[str] = None,
torch_dtype: Optional[str] = None,
trust_remote_code: bool = False,
flatten_messages_as_text: bool = True,
**kwargs,
):
super().__init__(**kwargs)
@ -469,6 +465,7 @@ class TransformersModel(Model):
if device_map is None:
device_map = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device_map}")
self._is_vlm = False
try:
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
@ -481,6 +478,7 @@ class TransformersModel(Model):
if "Unrecognized configuration class" in str(e):
self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
self.processor = AutoProcessor.from_pretrained(model_id)
self._is_vlm = True
else:
raise e
except Exception as e:
@ -490,7 +488,6 @@ class TransformersModel(Model):
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
self.flatten_messages_as_text = flatten_messages_as_text
def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
from transformers import StoppingCriteria, StoppingCriteriaList
@ -526,8 +523,7 @@ class TransformersModel(Model):
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
tools_to_call_from=tools_to_call_from,
flatten_messages_as_text=self.flatten_messages_as_text,
flatten_messages_as_text=(not self._is_vlm),
**kwargs,
)
@ -595,9 +591,19 @@ class TransformersModel(Model):
else:
if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
parsed_output = json.loads(output)
tool_name = parsed_output.get("tool_name")
tool_arguments = parsed_output.get("tool_arguments")
try:
start_index = output.index("{")
end_index = output.rindex("}")
output = output[start_index : end_index + 1]
except Exception as e:
raise Exception("No json blob found in output!") from e
try:
parsed_output = json.loads(output)
except json.JSONDecodeError as e:
raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
tool_name = parsed_output.get("name")
tool_arguments = parsed_output.get("arguments")
return ChatMessage(
role="assistant",
content="",

View File

@ -28,11 +28,7 @@ from smolagents.agents import (
ToolCallingAgent,
)
from smolagents.default_tools import PythonInterpreterTool
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.utils import BASE_BUILTIN_MODULES
@ -620,3 +616,26 @@ nested_answer()
output = agent.run("Count to 3")
assert output == "Correct!"
def test_transformers_toolcalling_agent(self):
@tool
def get_weather(location: str, celsius: bool = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature type
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
model = TransformersModel(
model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
max_new_tokens=100,
device_map="auto",
do_sample=False,
)
agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
agent.run("What's the weather in Paris?")
assert agent.logs[2].tool_calls[0].name == "get_weather"

View File

@ -57,7 +57,6 @@ class ModelTests(unittest.TestCase):
max_new_tokens=5,
device_map="auto",
do_sample=False,
flatten_messages_as_text=True,
)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
output = model(messages, stop_sequences=["great"]).content
@ -72,7 +71,6 @@ class ModelTests(unittest.TestCase):
max_new_tokens=5,
device_map="auto",
do_sample=False,
flatten_messages_as_text=False,
)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}, {"type": "image", "image": img}]}]
output = model(messages, stop_sequences=["great"]).content