Make import time faster (optional deps + delay imports) (#253)

* adapt docs

* optional in pyproject.toml

* get rid of some transformers imports

* optional transformers in models.py

* gradio, transformers, litellm

* small refacto AgentType

* merge conflicts

* mouaif

* fix tests

* AgentText no longer a str

* Add back AgentType as str/Image

* fixed for good
Lucain 2025-01-20 10:59:43 +01:00 committed by GitHub
parent a2ca95107f
commit d19ebc7a48
17 changed files with 593 additions and 259 deletions
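
The bulk of the diff applies one pattern: heavy dependencies (`transformers`, `gradio`, `litellm`, `torch`, `accelerate`, `soundfile`) become optional extras, and the code only imports them inside the functions that actually need them, failing with an actionable install hint when they are missing. A schematic sketch of that pattern, not the literal smolagents code (the function name and body are illustrative):

```python
def launch_ui(agent):
    # Before this change, `import gradio as gr` sat at module scope, so a plain
    # `import smolagents` paid gradio's import cost even if the UI was never used.
    try:
        import gradio as gr  # deferred: imported only when the UI is actually launched
    except ImportError as e:
        raise ModuleNotFoundError(
            "Please install the 'gradio' extra: `pip install 'smolagents[gradio]'`"
        ) from e

    with gr.Blocks() as demo:
        demo.launch()
```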

View File

@@ -55,6 +55,7 @@ agent.run(
<hfoption id="Local Transformers Model">
```python
# !pip install smolagents[transformers]
from smolagents import CodeAgent, TransformersModel
model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -72,6 +73,7 @@ agent.run(
To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass `api_key` variable upon initialization.
```python
# !pip install smolagents[litellm]
from smolagents import CodeAgent, LiteLLMModel
model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could use 'gpt-4o'
@@ -85,6 +87,7 @@ agent.run(
<hfoption id="Ollama">
```python
# !pip install smolagents[litellm]
from smolagents import CodeAgent, LiteLLMModel
model = LiteLLMModel(

View File

@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
### GradioUI
> [!TIP]
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
[[autodoc]] GradioUI
## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
>>> What a
```
> [!TIP]
> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
[[autodoc]] TransformersModel
### HfApiModel

View File

@@ -61,6 +61,7 @@ agent.run(
<hfoption id="本地Transformers模型">
```python
# !pip install smolagents[transformers]
from smolagents import CodeAgent, TransformersModel
model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -78,6 +79,7 @@ agent.run(
要使用 `LiteLLMModel`,您需要设置环境变量 `ANTHROPIC_API_KEY` 或 `OPENAI_API_KEY`,或者在初始化时传递 `api_key` 变量。
```python
# !pip install smolagents[litellm]
from smolagents import CodeAgent, LiteLLMModel
model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # 也可以使用 'gpt-4o'
@@ -91,6 +93,7 @@ agent.run(
<hfoption id="Ollama">
```python
# !pip install smolagents[litellm]
from smolagents import CodeAgent, LiteLLMModel
model = LiteLLMModel(

View File

@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
### GradioUI
> [!TIP]
> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
[[autodoc]] GradioUI
## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
>>> What a
```
> [!TIP]
> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
[[autodoc]] TransformersModel
### HfApiModel

View File

@@ -12,14 +12,12 @@ authors = [
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"transformers>=4.0.0",
"requests>=2.32.3",
"rich>=13.9.4",
"pandas>=2.2.3",
"jinja2>=3.1.4",
"pillow>=11.0.0",
"markdownify>=0.14.1",
"gradio>=5.8.0",
"duckduckgo-search>=6.3.7",
"python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3",
@@ -27,12 +25,20 @@ dependencies = [
]
[project.optional-dependencies]
audio = [
"soundfile",
]
torch = [
"torch",
]
audio = [
"soundfile",
"smolagents[torch]",
]
transformers = [
"accelerate",
"transformers>=4.0.0",
"smolagents[torch]",
]
gradio = [
"gradio>=5.8.0",
]
litellm = [
"litellm>=1.55.10",
@@ -47,9 +53,12 @@ openai = [
quality = [
"ruff>=0.9.0",
]
all = [
"smolagents[accelerate,audio,gradio,litellm,mcp,openai,transformers]",
]
test = [
"pytest>=8.1.0",
"smolagents[audio,litellm,mcp,openai,torch]",
"smolagents[all]",
]
dev = [
"smolagents[quality,test]",

View File

@@ -16,36 +16,14 @@
# limitations under the License.
__version__ = "1.5.0.dev"
from typing import TYPE_CHECKING
from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .agents import *
from .default_tools import *
from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .models import *
from .monitoring import *
from .prompts import *
from .tools import *
from .types import *
from .utils import *
else:
import sys
_file = globals()["__file__"]
import_structure = define_import_structure(_file)
import_structure[""] = {"__version__": __version__}
sys.modules[__name__] = _LazyModule(
__name__,
_file,
import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__},
)
from .agents import *
from .default_tools import *
from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .models import *
from .monitoring import *
from .prompts import *
from .tools import *
from .types import *
from .utils import *
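
Dropping the `_LazyModule` machinery only stays cheap because the submodules themselves no longer import heavy packages at the top level. A quick, hedged way to check the effect locally (numbers vary by machine and installed extras):

```python
import time

start = time.perf_counter()
import smolagents  # noqa: F401

print(f"import smolagents took {time.perf_counter() - start:.2f}s")
# For a per-module breakdown, CPython's built-in flag also works:
#   python -X importtime -c "import smolagents"
```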

View File

@@ -0,0 +1,388 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains utilities exclusively taken from `transformers` repository.
Since they are not specific to `transformers` and that `transformers` is an heavy dependencies, those helpers have
been duplicated.
TODO: move them to `huggingface_hub` to avoid code duplication.
"""
import inspect
import json
import os
import re
import types
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
get_args,
get_origin,
get_type_hints,
)
from huggingface_hub.utils import is_torch_available
from .utils import _is_pillow_available
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
"""
Extracts all the libraries (not relative imports this time) that are imported in a file.
Args:
filename (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of all packages required to use the input module.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
# filter out try/except block so in custom code we can have try/except imports
content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
content = re.sub(
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
"",
content,
flags=re.MULTILINE,
)
# Imports of the form `import xxx`
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
class TypeHintParsingException(Exception):
"""Exception raised for errors in parsing type hints to generate JSON schemas"""
class DocstringParsingException(Exception):
"""Exception raised for errors in parsing docstrings to generate JSON schemas"""
def get_json_schema(func: Callable) -> Dict:
"""
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
that the function has a docstring, and that each argument has a description in the docstring, in the standard
Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
optional because most chat templates ignore the return value of the function.
Args:
func: The function to generate a JSON schema for.
Returns:
A dictionary containing the JSON schema for the function.
Examples:
```python
>>> def multiply(x: float, y: float):
>>> '''
>>> A function that multiplies two numbers
>>>
>>> Args:
>>> x: The first number to multiply
>>> y: The second number to multiply
>>> '''
>>> return x * y
>>>
>>> print(get_json_schema(multiply))
{
"name": "multiply",
"description": "A function that multiplies two numbers",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "number", "description": "The first number to multiply"},
"y": {"type": "number", "description": "The second number to multiply"}
},
"required": ["x", "y"]
}
}
```
The general use for these schemas is that they are used to generate tool descriptions for chat templates that
support them, like so:
```python
>>> from transformers import AutoTokenizer
>>> from transformers.utils import get_json_schema
>>>
>>> def multiply(x: float, y: float):
>>> '''
>>> A function that multiplies two numbers
>>>
>>> Args:
>>> x: The first number to multiply
>>> y: The second number to multiply
>>> return x * y
>>> '''
>>>
>>> multiply_schema = get_json_schema(multiply)
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
>>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
>>> formatted_chat = tokenizer.apply_chat_template(
>>> messages,
>>> tools=[multiply_schema],
>>> chat_template="tool_use",
>>> return_dict=True,
>>> return_tensors="pt",
>>> add_generation_prompt=True
>>> )
>>> # The formatted chat can now be passed to model.generate()
```
Each argument description can also have an optional `(choices: ...)` block at the end, such as
`(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
only be parsed correctly if it is at the end of the line:
```python
>>> def drink_beverage(beverage: str):
>>> '''
>>> A function that drinks a beverage
>>>
>>> Args:
>>> beverage: The beverage to drink (choices: ["tea", "coffee"])
>>> '''
>>> pass
>>>
>>> print(get_json_schema(drink_beverage))
```
{
'name': 'drink_beverage',
'description': 'A function that drinks a beverage',
'parameters': {
'type': 'object',
'properties': {
'beverage': {
'type': 'string',
'enum': ['tea', 'coffee'],
'description': 'The beverage to drink'
}
},
'required': ['beverage']
}
}
"""
doc = inspect.getdoc(func)
if not doc:
raise DocstringParsingException(
f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
)
doc = doc.strip()
main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)
json_schema = _convert_type_hints_to_json_schema(func)
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
return_dict["description"] = return_doc
for arg, schema in json_schema["properties"].items():
if arg not in param_descriptions:
raise DocstringParsingException(
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
)
desc = param_descriptions[arg]
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
if enum_choices:
schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
desc = enum_choices.string[: enum_choices.start()].strip()
schema["description"] = desc
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}
# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments
args_split_re = re.compile(
r"""
(?:^|\n) # Match the start of the args block, or a newline
\s*(\w+):\s* # Capture the argument name and strip spacing
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
""",
re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
def _parse_google_format_docstring(
docstring: str,
) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
"""
Parses a Google-style docstring to extract the function description,
argument descriptions, and return description.
Args:
docstring (str): The docstring to parse.
Returns:
The function description, arguments, and return description.
"""
# Extract the sections
description_match = description_re.search(docstring)
args_match = args_re.search(docstring)
returns_match = returns_re.search(docstring)
# Clean and store the sections
description = description_match.group(1).strip() if description_match else None
docstring_args = args_match.group(1).strip() if args_match else None
returns = returns_match.group(1).strip() if returns_match else None
# Parsing the arguments into a dictionary
if docstring_args is not None:
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
matches = args_split_re.findall(docstring_args)
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
else:
args_dict = {}
return description, args_dict, returns
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
required = []
for param_name, param in signature.parameters.items():
if param.annotation == inspect.Parameter.empty:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param.default == inspect.Parameter.empty:
required.append(param_name)
properties = {}
for param_name, param_type in type_hints.items():
properties[param_name] = _parse_type_hint(param_type)
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
def _parse_type_hint(hint: str) -> Dict:
origin = get_origin(hint)
args = get_args(hint)
if origin is None:
try:
return _get_json_schema_type(hint)
except KeyError:
raise TypeHintParsingException(
"Couldn't parse this type hint, likely due to a custom class or object: ",
hint,
)
elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
if len(subtypes) == 1:
# A single non-null type can be expressed directly
return_dict = subtypes[0]
elif all(isinstance(subtype["type"], str) for subtype in subtypes):
# A union of basic types can be expressed as a list in the schema
return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
else:
# A union of more complex types requires "anyOf"
return_dict = {"anyOf": subtypes}
if type(None) in args:
return_dict["nullable"] = True
return return_dict
elif origin is list:
if not args:
return {"type": "array"}
else:
# Lists can only have a single type argument, so recurse into it
return {"type": "array", "items": _parse_type_hint(args[0])}
elif origin is tuple:
if not args:
return {"type": "array"}
if len(args) == 1:
raise TypeHintParsingException(
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
"more than one element, we recommend "
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
"pass the element directly."
)
if ... in args:
raise TypeHintParsingException(
"Conversion of '...' is not supported in Tuple type hints. "
"Use List[] types for variable-length"
" inputs instead."
)
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
elif origin is dict:
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
# However, we can specify the type of the dict values with "additionalProperties"
out = {"type": "object"}
if len(args) == 2:
out["additionalProperties"] = _parse_type_hint(args[1])
return out
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
_BASE_TYPE_MAPPING = {
int: {"type": "integer"},
float: {"type": "number"},
str: {"type": "string"},
bool: {"type": "boolean"},
Any: {},
}
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
if param_type in _BASE_TYPE_MAPPING:
return _BASE_TYPE_MAPPING[param_type]
if str(param_type) == "Image" and _is_pillow_available():
from PIL.Image import Image
if param_type == Image:
return {"type": "image"}
if str(param_type) == "Tensor" and is_torch_available():
from torch import Tensor
if param_type == Tensor:
return {"type": "audio"}
return {"type": "object"}
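
This new module vendors `get_imports`, `get_json_schema`, `_parse_type_hint` and the docstring parser from `transformers` so that tool serialization no longer needs the full library. A small usage sketch, assuming the file lands as `smolagents._transformers_utils` (the file name is not shown in this view, but the later `from ._transformers_utils import ...` hunk implies it):

```python
import tempfile

from smolagents._transformers_utils import get_imports, get_json_schema

# get_imports: list the top-level third-party modules a tool file depends on.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("import requests\nfrom PIL import Image\n")
print(get_imports(f.name))  # e.g. ['requests', 'PIL'] (order not guaranteed)


# get_json_schema: turn a documented, type-hinted function into a tool schema.
def greet(name: str):
    """Greets someone.

    Args:
        name: The person to greet.
    """
    return f"Hello {name}"


print(get_json_schema(greet)["function"]["name"])  # 'greet'
```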

View File

@@ -14,33 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from dataclasses import dataclass
from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode, is_torch_available
from .local_python_executor import (
BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS,
evaluate_python_code,
)
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .tools import PipelineTool, Tool
from .types import AgentAudio
if is_torch_available():
from transformers.models.whisper import (
WhisperForConditionalGeneration,
WhisperProcessor,
)
else:
WhisperForConditionalGeneration = object
WhisperProcessor = object
@dataclass
class PreTool:
name: str
@@ -51,31 +37,6 @@ class PreTool:
repo_id: str
def get_remote_tools(logger, organization="huggingface-tools"):
if is_offline_mode():
logger.info("You are in offline mode, so remote tools are not available.")
return {}
spaces = list_spaces(author=organization)
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
task = repo_id.split("/")[-1]
tools[config["name"]] = PreTool(
task=task,
description=config["description"],
repo_id=repo_id,
name=task,
inputs=config["inputs"],
output_type=config["output_type"],
)
return tools
class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
@@ -150,10 +111,10 @@ class DuckDuckGoSearchTool(Tool):
self.max_results = max_results
try:
from duckduckgo_search import DDGS
except ImportError:
except ImportError as e:
raise ImportError(
"You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
)
) from e
self.ddgs = DDGS()
def forward(self, query: str) -> str:
@@ -259,10 +220,10 @@ class VisitWebpageTool(Tool):
from requests.exceptions import RequestException
from smolagents.utils import truncate_content
except ImportError:
except ImportError as e:
raise ImportError(
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
)
) from e
try:
# Send a GET request to the URL
response = requests.get(url)
@@ -286,9 +247,6 @@ class SpeechToTextTool(PipelineTool):
default_checkpoint = "openai/whisper-large-v3-turbo"
description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
name = "transcriber"
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = {
"audio": {
"type": "audio",
@@ -297,6 +255,18 @@ class SpeechToTextTool(PipelineTool):
}
output_type = "string"
def __new__(cls):
from transformers.models.whisper import (
WhisperForConditionalGeneration,
WhisperProcessor,
)
if not hasattr(cls, "pre_processor_class"):
cls.pre_processor_class = WhisperProcessor
if not hasattr(cls, "model_class"):
cls.model_class = WhisperForConditionalGeneration
return super().__new__(cls)
def encode(self, audio):
audio = AgentAudio(audio).to_raw()
return self.pre_processor(audio, return_tensors="pt")
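
The switch from a bare `except ImportError:` to `except ImportError as e: ... raise ImportError(...) from e` keeps the original traceback chained to the friendlier message. A minimal illustration of the idiom as used above:

```python
def load_search_backend():
    try:
        from duckduckgo_search import DDGS  # optional dependency
    except ImportError as e:
        # `from e` preserves the original ImportError in the traceback chain,
        # so the user still sees which module actually failed to import.
        raise ImportError(
            "You must install package `duckduckgo_search` to run this tool: "
            "for instance run `pip install duckduckgo-search`."
        ) from e
    return DDGS()
```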

View File

@@ -19,14 +19,15 @@ import re
import shutil
from typing import Optional
import gradio as gr
from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from .utils import _is_package_available
def pull_messages_from_step(step_log: AgentStepLog):
"""Extract ChatMessage objects from agent steps"""
import gradio as gr
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
if step_log.tool_calls is not None:
@@ -57,6 +58,11 @@ def stream_to_gradio(
additional_args: Optional[dict] = None,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
if not _is_package_available("gradio"):
raise ModuleNotFoundError(
"Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`"
)
import gradio as gr
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
for message in pull_messages_from_step(step_log):
@@ -88,6 +94,10 @@ class GradioUI:
"""A one-line interface to launch your agent in Gradio"""
def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
if not _is_package_available("gradio"):
raise ModuleNotFoundError(
"Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`"
)
self.agent = agent
self.file_upload_folder = file_upload_folder
if self.file_upload_folder is not None:
@@ -95,6 +105,8 @@ class GradioUI:
os.mkdir(file_upload_folder)
def interact_with_agent(self, prompt, messages):
import gradio as gr
messages.append(gr.ChatMessage(role="user", content=prompt))
yield messages
for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
@@ -115,6 +127,7 @@ class GradioUI:
"""
Handle file uploads, default allowed types are .pdf, .docx, and .txt
"""
import gradio as gr
if file is None:
return gr.Textbox("No file uploaded", visible=True), file_uploads_log
@@ -161,6 +174,8 @@ class GradioUI:
)
def launch(self):
import gradio as gr
with gr.Blocks() as demo:
stored_messages = gr.State([])
file_uploads_log = gr.State([])
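
With `gradio` now imported lazily inside each method, launching the UI looks the same from the user side; a sketch of the intended usage (the model choice is illustrative):

```python
from smolagents import CodeAgent, GradioUI, HfApiModel

agent = CodeAgent(tools=[], model=HfApiModel())
# Raises ModuleNotFoundError with an install hint if the gradio extra is missing.
GradioUI(agent).launch()
```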

View File

@@ -21,20 +21,18 @@ import random
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
is_torch_available,
)
from huggingface_hub.utils import is_torch_available
from .tools import Tool
from .utils import _is_package_available
if TYPE_CHECKING:
from transformers import StoppingCriteriaList
logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@@ -320,6 +318,9 @@ class TransformersModel(Model):
This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
> [!TIP]
> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
Parameters:
model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
@@ -358,9 +359,12 @@
**kwargs,
):
super().__init__()
if not is_torch_available():
raise ImportError("Please install torch in order to use TransformersModel.")
if not is_torch_available() or not _is_package_available("transformers"):
raise ModuleNotFoundError(
"Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None:
@@ -387,7 +391,9 @@
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
def make_stopping_criteria(self, stop_sequences: List[str]) -> "StoppingCriteriaList":
from transformers import StoppingCriteria, StoppingCriteriaList
class StopOnStrings(StoppingCriteria):
def __init__(self, stop_strings: List[str], tokenizer):
self.stop_strings = stop_strings
@@ -491,6 +497,7 @@ class LiteLLMModel(Model):
raise ModuleNotFoundError(
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
)
super().__init__()
self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@@ -506,9 +513,10 @@
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
import litellm
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
if tools_to_call_from:
response = litellm.completion(
model=self.model_id,
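
Note the return annotation becoming a string (`-> "StoppingCriteriaList"`) together with the `TYPE_CHECKING` guard: `transformers` stays visible to type checkers but is never imported at module import time. A minimal sketch of that pattern, under the assumption that the heavy package is optional:

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Seen only by static type checkers; nothing is imported at runtime.
    from transformers import StoppingCriteriaList


def make_stopping_criteria(stop_sequences: List[str]) -> "StoppingCriteriaList":
    # Deferred import: transformers is only loaded when criteria are actually built.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnStrings(StoppingCriteria):
        def __init__(self, stop_strings: List[str]):
            self.stop_strings = stop_strings

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # Placeholder check; the real implementation decodes and compares text.
            return False

    return StoppingCriteriaList([StopOnStrings(stop_sequences)])
```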

View File

@@ -35,54 +35,22 @@ from huggingface_hub import (
metadata_update,
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError
from huggingface_hub.utils import is_torch_available
from packaging import version
from transformers.dynamic_module_utils import get_imports
from transformers.utils import (
TypeHintParsingException,
cached_file,
get_json_schema,
is_accelerate_available,
is_torch_available,
)
from transformers.utils.chat_template_utils import _parse_type_hint
from ._transformers_utils import (
TypeHintParsingException,
_parse_type_hint,
get_imports,
get_json_schema,
)
from .tool_validation import MethodChecker, validate_tool_attributes
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .types import handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source
from .utils import _is_package_available, _is_pillow_available, instance_to_source
logger = logging.getLogger(__name__)
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
if is_torch_available():
from transformers import AutoProcessor
else:
AutoProcessor = object
TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None:
return repo_type
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
return "space"
except RepositoryNotFoundError:
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:
return "space"
def validate_after_init(cls):
original_init = cls.__init__
@@ -337,12 +305,8 @@
)
# Save requirements file
imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
requirements_file = os.path.join(output_dir, "requirements.txt")
imports = []
for module in [tool_file]:
imports.extend(get_imports(module))
imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")
@@ -439,53 +403,27 @@
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
assert trust_remote_code, (
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
)
if not trust_remote_code:
raise ValueError(
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
)
hub_kwargs_names = [
"cache_dir",
"force_download",
"resume_download",
"proxies",
"revision",
"repo_type",
"subfolder",
"local_files_only",
]
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
tool_file = "tool.py"
# Get the tool's tool.py file.
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
resolved_tool_file = cached_file(
repo_id,
tool_file,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
tool_code = resolved_tool_file is not None
if resolved_tool_file is None:
resolved_tool_file = cached_file(
repo_id,
tool_file,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_tool_file is None:
raise EnvironmentError(
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
)
with open(resolved_tool_file, encoding="utf-8") as reader:
tool_code = "".join(reader.readlines())
tool_file = hf_hub_download(
repo_id,
"tool.py",
token=token,
repo_type="space",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download"),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
revision=kwargs.get("revision"),
subfolder=kwargs.get("subfolder"),
local_files_only=kwargs.get("local_files_only"),
)
tool_code = Path(tool_file).read_text()
# Find the Tool subclass in the namespace
with tempfile.TemporaryDirectory() as temp_dir:
@@ -613,7 +551,10 @@
def sanitize_argument_for_prediction(self, arg):
from gradio_client.utils import is_http_url_like
if isinstance(arg, ImageType):
if _is_pillow_available():
from PIL.Image import Image
if _is_pillow_available() and isinstance(arg, Image):
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name)
arg = temp_file.name
@@ -988,13 +929,13 @@
- **model_class** (`type`) -- The class to use to load the model in this tool.
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
- **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
- **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
pre-processor
- **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
- **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
post-processor (when different from the pre-processor).
Args:
model (`str` or [`PreTrainedModel`], *optional*):
model (`str` or [`transformers.PreTrainedModel`], *optional*):
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
value of the class attribute `default_checkpoint`.
pre_processor (`str` or `Any`, *optional*):
@@ -1019,9 +960,9 @@
Any additional keyword argument to send to the methods that will load the data from the Hub.
"""
pre_processor_class = AutoProcessor
pre_processor_class = None
model_class = None
post_processor_class = AutoProcessor
post_processor_class = None
default_checkpoint = None
description = "This is a pipeline tool"
name = "pipeline"
@@ -1040,11 +981,10 @@
token=None,
**hub_kwargs,
):
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if not is_accelerate_available():
raise ImportError("Please install accelerate in order to use this tool.")
if not is_torch_available() or not _is_package_available("accelerate"):
raise ModuleNotFoundError(
"Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`"
)
if model is None:
if self.default_checkpoint is None:
@@ -1071,6 +1011,10 @@
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
if self.pre_processor_class is None:
from transformers import AutoProcessor
self.pre_processor_class = AutoProcessor
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
@@ -1079,12 +1023,18 @@
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
if self.post_processor_class is None:
from transformers import AutoProcessor
self.post_processor_class = AutoProcessor
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
from accelerate import PartialState
self.device = PartialState().default_device
if self.device_map is None:
@@ -1115,6 +1065,7 @@
def __call__(self, *args, **kwargs):
import torch
from accelerate.utils import send_to_device
args, kwargs = handle_agent_input_types(*args, **kwargs)
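
`Tool.from_hub` now fetches `tool.py` with a plain `hf_hub_download` call instead of `transformers.utils.cached_file`, and the old `assert` becomes a proper `ValueError`. Usage is unchanged; a sketch with an example repo id:

```python
from smolagents import Tool

# Loading a tool from the Hub still requires explicitly opting in to remote code.
image_tool = Tool.from_hub("m-ric/text-to-image", trust_remote_code=True)
print(image_tool.name, image_tool.description)
```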

View File

@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import logging
import os
import pathlib
@@ -22,26 +21,15 @@ from io import BytesIO
import numpy as np
import requests
from transformers.utils import (
is_torch_available,
is_vision_available,
)
from huggingface_hub.utils import is_torch_available
from PIL import Image
from PIL.Image import Image as ImageType
from .utils import _is_package_available
logger = logging.getLogger(__name__)
if is_vision_available():
from PIL import Image
from PIL.Image import Image as ImageType
else:
ImageType = object
if is_torch_available():
import torch
from torch import Tensor
else:
Tensor = object
class AgentType:
"""
@@ -94,9 +82,6 @@ class AgentImage(AgentType, ImageType):
AgentType.__init__(self, value)
ImageType.__init__(self)
if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.")
self._path = None
self._raw = None
self._tensor = None
@@ -109,11 +94,15 @@ class AgentImage(AgentType, ImageType):
self._raw = Image.open(BytesIO(value))
elif isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value)
else:
elif is_torch_available():
import torch
if isinstance(value, torch.Tensor):
self._tensor = value
if isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value)
if self._path is None and self._raw is None and self._tensor is None:
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
@@ -183,10 +172,12 @@ class AgentAudio(AgentType, str):
"""
def __init__(self, value, samplerate=16_000):
if importlib.util.find_spec("soundfile") is None:
if not _is_package_available("soundfile") or not is_torch_available():
raise ModuleNotFoundError(
"Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
)
import torch
super().__init__(value)
self._path = None
@@ -223,6 +214,8 @@
if self._tensor is not None:
return self._tensor
import torch
if self._path is not None:
if "://" in str(self._path):
response = requests.get(self._path)
@@ -250,15 +243,7 @@
return self._path
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
_AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {
str: AgentText,
ImageType: AgentImage,
Tensor: AgentAudio,
}
if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
def handle_agent_input_types(*args, **kwargs):
@@ -268,17 +253,22 @@ def handle_agent_input_types(*args, **kwargs):
def handle_agent_output_types(output, output_type=None):
if output_type in AGENT_TYPE_MAPPING:
if output_type in _AGENT_TYPE_MAPPING:
# If the class has defined outputs, we can map directly according to the class definition
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
return decoded_outputs
else:
# If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(output, _k):
if _k is not object: # avoid converting to audio if torch is not installed
return _v(output)
return output
if isinstance(output, str):
return AgentText(output)
if isinstance(output, ImageType):
return AgentImage(output)
if is_torch_available():
import torch
if isinstance(output, torch.Tensor):
return AgentAudio(output)
return output
__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
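
The output mapping is now spelled out per type instead of iterating over `INSTANCE_TYPE_MAPPING`, and torch tensors are only considered when torch is installed. A rough sketch of the resulting behavior (assuming pillow, a core dependency, is present):

```python
from smolagents.types import handle_agent_output_types

print(type(handle_agent_output_types("hello")))                      # AgentText: plain strings are wrapped
print(type(handle_agent_output_types("hi", output_type="string")))   # AgentText via _AGENT_TYPE_MAPPING
print(type(handle_agent_output_types(3.14)))                         # float: no mapping applies, returned as-is
```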

View File

@@ -15,18 +15,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import importlib.metadata
import importlib.util
import inspect
import json
import re
import types
from functools import lru_cache
from typing import Dict, Tuple, Union
from rich.console import Console
def is_pygments_available():
return importlib.util.find_spec("soundfile") is not None
@lru_cache
def _is_package_available(package_name: str) -> bool:
try:
importlib.metadata.version(package_name)
return True
except importlib.metadata.PackageNotFoundError:
return False
@lru_cache
def _is_pillow_available():
return importlib.util.find_spec("PIL") is not None
console = Console()
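
A quick sketch of the new helpers (package names are just examples): `_is_package_available` checks installed distribution metadata without importing anything, and `lru_cache` makes repeated checks free; Pillow gets its own `find_spec`-based helper, presumably because its distribution name (`pillow`) differs from the importable module (`PIL`).

```python
from smolagents.utils import _is_package_available, _is_pillow_available

print(_is_package_available("requests"))              # True if the distribution is installed
print(_is_package_available("surely-not-installed"))  # False, and the result is cached
print(_is_pillow_available())                         # looks up the PIL module via importlib.util.find_spec
```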

View File

@@ -17,7 +17,7 @@ import unittest
import pytest
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.types import _AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
@@ -46,7 +46,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def test_agent_type_output(self):
inputs = ["2 * 2"]
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self):
@@ -56,13 +56,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
_inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
_inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_imports_work(self):

View File

@@ -22,7 +22,7 @@ from transformers import is_torch_available
from transformers.testing_utils import get_tests_dir, require_torch
from smolagents.default_tools import FinalAnswerTool
from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.types import _AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
@@ -55,5 +55,5 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
inputs = self.create_inputs()
for input_type, input in inputs.items():
output = self.tool(**input, sanitize_inputs_outputs=True)
agent_type = AGENT_TYPE_MAPPING[input_type]
agent_type = _AGENT_TYPE_MAPPING[input_type]
self.assertTrue(isinstance(output, agent_type))

View File

@@ -21,16 +21,12 @@ from unittest.mock import MagicMock, patch
import mcp
import numpy as np
import pytest
import torch
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir
from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
from smolagents.types import (
AGENT_TYPE_MAPPING,
AgentAudio,
AgentImage,
AgentText,
)
from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
if is_torch_available():
@@ -96,7 +92,7 @@ class ToolTesterMixin:
inputs = create_inputs(self.tool.inputs)
output = self.tool(**inputs, sanitize_inputs_outputs=True)
if self.tool.output_type != "any":
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))

View File

@@ -121,4 +121,3 @@ class AgentTextTests(unittest.TestCase):
self.assertEqual(string, agent_type.to_string())
self.assertEqual(string, agent_type.to_raw())
self.assertEqual(string, agent_type)