From d19ebc7a481346bf884791588a7215cbe3c3e661 Mon Sep 17 00:00:00 2001
From: Lucain
Date: Mon, 20 Jan 2025 10:59:43 +0100
Subject: [PATCH] Make import time faster (optional deps + delay imports)
 (#253)

* adapt docs

* optional in pyproject.toml

* get rid of some transformers imports

* optional transformers in models.py

* gradio, transformers, litellm

* small refactor of AgentType

* merge conflicts

* fix tests

* AgentText no longer a str

* Add back AgentType as str/Image

* fixed for good
---
 docs/source/en/guided_tour.md         |   3 +
 docs/source/en/reference/agents.md    |   6 +
 docs/source/zh/guided_tour.md         |   3 +
 docs/source/zh/reference/agents.md    |   6 +
 pyproject.toml                        |  21 +-
 src/smolagents/__init__.py            |  44 +--
 src/smolagents/_transformers_utils.py | 388 ++++++++++++++++++++++++++
 src/smolagents/default_tools.py       |  64 ++---
 src/smolagents/gradio_ui.py           |  19 +-
 src/smolagents/models.py              |  32 ++-
 src/smolagents/tools.py               | 147 ++++------
 src/smolagents/types.py               |  78 +++---
 src/smolagents/utils.py               |  16 +-
 tests/test_default_tools.py           |  10 +-
 tests/test_final_answer.py            |   4 +-
 tests/test_tools.py                   |  10 +-
 tests/test_types.py                   |   1 -
 17 files changed, 593 insertions(+), 259 deletions(-)
 create mode 100644 src/smolagents/_transformers_utils.py

diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md
index dd6a821..ee91553 100644
--- a/docs/source/en/guided_tour.md
+++ b/docs/source/en/guided_tour.md
@@ -55,6 +55,7 @@ agent.run(
 
 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel
 
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -72,6 +73,7 @@ agent.run(
 To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass `api_key` variable upon initialization.
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could use 'gpt-4o'
@@ -85,6 +87,7 @@ agent.run(
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(
diff --git a/docs/source/en/reference/agents.md b/docs/source/en/reference/agents.md
index 76b2ecb..60087bb 100644
--- a/docs/source/en/reference/agents.md
+++ b/docs/source/en/reference/agents.md
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
 
 ### GradioUI
 
+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not already installed.
+
 [[autodoc]] GradioUI
 
 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```
 
+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they're not already installed.
+
 [[autodoc]] TransformersModel
 
 ### HfApiModel
diff --git a/docs/source/zh/guided_tour.md b/docs/source/zh/guided_tour.md
index d8a30da..537e594 100644
--- a/docs/source/zh/guided_tour.md
+++ b/docs/source/zh/guided_tour.md
@@ -61,6 +61,7 @@ agent.run(
 
 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel
 
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -78,6 +79,7 @@ agent.run(
 要使用 `LiteLLMModel`,您需要设置环境变量 `ANTHROPIC_API_KEY` 或 `OPENAI_API_KEY`,或者在初始化时传递 `api_key` 变量。
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # 也可以使用 'gpt-4o'
@@ -91,6 +93,7 @@ agent.run(
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(
diff --git a/docs/source/zh/reference/agents.md b/docs/source/zh/reference/agents.md
index dc011d3..3b05a6d 100644
--- a/docs/source/zh/reference/agents.md
+++ b/docs/source/zh/reference/agents.md
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
 
 ### GradioUI
 
+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not already installed.
+
 [[autodoc]] GradioUI
 
 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```
 
+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they're not already installed.
+
 [[autodoc]] TransformersModel
 
 ### HfApiModel
diff --git a/pyproject.toml b/pyproject.toml
index 1cee73d..f968afe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,14 +12,12 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "transformers>=4.0.0",
     "requests>=2.32.3",
     "rich>=13.9.4",
     "pandas>=2.2.3",
     "jinja2>=3.1.4",
     "pillow>=11.0.0",
     "markdownify>=0.14.1",
-    "gradio>=5.8.0",
     "duckduckgo-search>=6.3.7",
     "python-dotenv>=1.0.1",
     "e2b-code-interpreter>=1.0.3",
@@ -27,12 +25,20 @@ dependencies = [
 
 [project.optional-dependencies]
-audio = [
-    "soundfile",
-]
 torch = [
     "torch",
+]
+audio = [
+    "soundfile",
+    "smolagents[torch]",
+]
+transformers = [
     "accelerate",
+    "transformers>=4.0.0",
+    "smolagents[torch]",
+]
+gradio = [
+    "gradio>=5.8.0",
 ]
 litellm = [
     "litellm>=1.55.10",
@@ -47,9 +53,12 @@ openai = [
 quality = [
     "ruff>=0.9.0",
 ]
+all = [
+    "smolagents[audio,gradio,litellm,mcp,openai,transformers]",
+]
 test = [
     "pytest>=8.1.0",
-    "smolagents[audio,litellm,mcp,openai,torch]",
+    "smolagents[all]",
 ]
 dev = [
     "smolagents[quality,test]",
diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py
index 93ea574..f3b79a2 100644
--- a/src/smolagents/__init__.py
+++ b/src/smolagents/__init__.py
@@ -16,36 +16,14 @@
 # limitations under the License.
__version__ = "1.5.0.dev" -from typing import TYPE_CHECKING - -from transformers.utils import _LazyModule -from transformers.utils.import_utils import define_import_structure - - -if TYPE_CHECKING: - from .agents import * - from .default_tools import * - from .e2b_executor import * - from .gradio_ui import * - from .local_python_executor import * - from .models import * - from .monitoring import * - from .prompts import * - from .tools import * - from .types import * - from .utils import * - - -else: - import sys - - _file = globals()["__file__"] - import_structure = define_import_structure(_file) - import_structure[""] = {"__version__": __version__} - sys.modules[__name__] = _LazyModule( - __name__, - _file, - import_structure, - module_spec=__spec__, - extra_objects={"__version__": __version__}, - ) +from .agents import * +from .default_tools import * +from .e2b_executor import * +from .gradio_ui import * +from .local_python_executor import * +from .models import * +from .monitoring import * +from .prompts import * +from .tools import * +from .types import * +from .utils import * diff --git a/src/smolagents/_transformers_utils.py b/src/smolagents/_transformers_utils.py new file mode 100644 index 0000000..fcbcf26 --- /dev/null +++ b/src/smolagents/_transformers_utils.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains utilities exclusively taken from `transformers` repository. + +Since they are not specific to `transformers` and that `transformers` is an heavy dependencies, those helpers have +been duplicated. + +TODO: move them to `huggingface_hub` to avoid code duplication. +""" + +import inspect +import json +import os +import re +import types +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + get_args, + get_origin, + get_type_hints, +) + +from huggingface_hub.utils import is_torch_available + +from .utils import _is_pillow_available + + +def get_imports(filename: Union[str, os.PathLike]) -> List[str]: + """ + Extracts all the libraries (not relative imports this time) that are imported in a file. + + Args: + filename (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of all packages required to use the input module. 
+ """ + with open(filename, "r", encoding="utf-8") as f: + content = f.read() + + # filter out try/except block so in custom code we can have try/except imports + content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL) + + # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment + content = re.sub( + r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", + "", + content, + flags=re.MULTILINE, + ) + + # Imports of the form `import xxx` + imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from xxx import yyy` + imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + # Only keep the top-level module + imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + return list(set(imports)) + + +class TypeHintParsingException(Exception): + """Exception raised for errors in parsing type hints to generate JSON schemas""" + + +class DocstringParsingException(Exception): + """Exception raised for errors in parsing docstrings to generate JSON schemas""" + + +def get_json_schema(func: Callable) -> Dict: + """ + This function generates a JSON schema for a given function, based on its docstring and type hints. This is + mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of + the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires + that the function has a docstring, and that each argument has a description in the docstring, in the standard + Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint. + + Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is + optional because most chat templates ignore the return value of the function. + + Args: + func: The function to generate a JSON schema for. + + Returns: + A dictionary containing the JSON schema for the function. 
+
+    Examples:
+    ```python
+    >>> def multiply(x: float, y: float):
+    >>>     '''
+    >>>     A function that multiplies two numbers
+    >>>
+    >>>     Args:
+    >>>         x: The first number to multiply
+    >>>         y: The second number to multiply
+    >>>     '''
+    >>>     return x * y
+    >>>
+    >>> print(get_json_schema(multiply))
+    {
+        "name": "multiply",
+        "description": "A function that multiplies two numbers",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "x": {"type": "number", "description": "The first number to multiply"},
+                "y": {"type": "number", "description": "The second number to multiply"}
+            },
+            "required": ["x", "y"]
+        }
+    }
+    ```
+
+    The general use for these schemas is to generate tool descriptions for chat templates that
+    support them, like so:
+
+    ```python
+    >>> from transformers import AutoTokenizer
+    >>> from transformers.utils import get_json_schema
+    >>>
+    >>> def multiply(x: float, y: float):
+    >>>     '''
+    >>>     A function that multiplies two numbers
+    >>>
+    >>>     Args:
+    >>>         x: The first number to multiply
+    >>>         y: The second number to multiply
+    >>>     '''
+    >>>     return x * y
+    >>>
+    >>> multiply_schema = get_json_schema(multiply)
+    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
+    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
+    >>> formatted_chat = tokenizer.apply_chat_template(
+    >>>     messages,
+    >>>     tools=[multiply_schema],
+    >>>     chat_template="tool_use",
+    >>>     return_dict=True,
+    >>>     return_tensors="pt",
+    >>>     add_generation_prompt=True
+    >>> )
+    >>> # The formatted chat can now be passed to model.generate()
+    ```
+
+    Each argument description can also have an optional `(choices: ...)` block at the end, such as
+    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
+    only be parsed correctly if it is at the end of the line:
+
+    ```python
+    >>> def drink_beverage(beverage: str):
+    >>>     '''
+    >>>     A function that drinks a beverage
+    >>>
+    >>>     Args:
+    >>>         beverage: The beverage to drink (choices: ["tea", "coffee"])
+    >>>     '''
+    >>>     pass
+    >>>
+    >>> print(get_json_schema(drink_beverage))
+    {
+        'name': 'drink_beverage',
+        'description': 'A function that drinks a beverage',
+        'parameters': {
+            'type': 'object',
+            'properties': {
+                'beverage': {
+                    'type': 'string',
+                    'enum': ['tea', 'coffee'],
+                    'description': 'The beverage to drink'
+                }
+            },
+            'required': ['beverage']
+        }
+    }
+    ```
+    """
+    doc = inspect.getdoc(func)
+    if not doc:
+        raise DocstringParsingException(
+            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
+        )
+    doc = doc.strip()
+    main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)
+
+    json_schema = _convert_type_hints_to_json_schema(func)
+    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
+        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
+            return_dict["description"] = return_doc
+    for arg, schema in json_schema["properties"].items():
+        if arg not in param_descriptions:
+            raise DocstringParsingException(
+                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
+            )
+        desc = param_descriptions[arg]
+        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
+        if enum_choices:
+            schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
+            desc = enum_choices.string[: enum_choices.start()].strip()
+        schema["description"] = desc
+
+    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
+    if return_dict is not None:
+        output["return"] = return_dict
+    return {"type": "function", "function": output}
+
+
+# Extracts the initial segment of the docstring, containing the function description
+description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
+# Extracts the Args: block from the docstring
+args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
+# Splits the Args: block into individual arguments
+args_split_re = re.compile(
+    r"""
+(?:^|\n)  # Match the start of the args block, or a newline
+\s*(\w+):\s*  # Capture the argument name and strip spacing
+(.*?)\s*  # Capture the argument description, which can span multiple lines, and strip trailing spacing
+(?=\n\s*\w+:|\Z)  # Stop when you hit the next argument or the end of the block
+""",
+    re.DOTALL | re.VERBOSE,
+)
+# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
+returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
+
+
+def _parse_google_format_docstring(
+    docstring: str,
+) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
+    """
+    Parses a Google-style docstring to extract the function description,
+    argument descriptions, and return description.
+
+    Args:
+        docstring (str): The docstring to parse.
+
+    Returns:
+        The function description, arguments, and return description.
+ """ + + # Extract the sections + description_match = description_re.search(docstring) + args_match = args_re.search(docstring) + returns_match = returns_re.search(docstring) + + # Clean and store the sections + description = description_match.group(1).strip() if description_match else None + docstring_args = args_match.group(1).strip() if args_match else None + returns = returns_match.group(1).strip() if returns_match else None + + # Parsing the arguments into a dictionary + if docstring_args is not None: + docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines + matches = args_split_re.findall(docstring_args) + args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches} + else: + args_dict = {} + + return description, args_dict, returns + + +def _convert_type_hints_to_json_schema(func: Callable) -> Dict: + type_hints = get_type_hints(func) + signature = inspect.signature(func) + required = [] + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty: + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") + if param.default == inspect.Parameter.empty: + required.append(param_name) + + properties = {} + for param_name, param_type in type_hints.items(): + properties[param_name] = _parse_type_hint(param_type) + + schema = {"type": "object", "properties": properties} + if required: + schema["required"] = required + + return schema + + +def _parse_type_hint(hint: str) -> Dict: + origin = get_origin(hint) + args = get_args(hint) + + if origin is None: + try: + return _get_json_schema_type(hint) + except KeyError: + raise TypeHintParsingException( + "Couldn't parse this type hint, likely due to a custom class or object: ", + hint, + ) + + elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType): + # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end + subtypes = [_parse_type_hint(t) for t in args if t is not type(None)] + if len(subtypes) == 1: + # A single non-null type can be expressed directly + return_dict = subtypes[0] + elif all(isinstance(subtype["type"], str) for subtype in subtypes): + # A union of basic types can be expressed as a list in the schema + return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])} + else: + # A union of more complex types requires "anyOf" + return_dict = {"anyOf": subtypes} + if type(None) in args: + return_dict["nullable"] = True + return return_dict + + elif origin is list: + if not args: + return {"type": "array"} + else: + # Lists can only have a single type argument, so recurse into it + return {"type": "array", "items": _parse_type_hint(args[0])} + + elif origin is tuple: + if not args: + return {"type": "array"} + if len(args) == 1: + raise TypeHintParsingException( + f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which " + "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain " + "more than one element, we recommend " + "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just " + "pass the element directly." + ) + if ... in args: + raise TypeHintParsingException( + "Conversion of '...' is not supported in Tuple type hints. " + "Use List[] types for variable-length" + " inputs instead." 
+            )
+        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
+
+    elif origin is dict:
+        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
+        # However, we can specify the type of the dict values with "additionalProperties"
+        out = {"type": "object"}
+        if len(args) == 2:
+            out["additionalProperties"] = _parse_type_hint(args[1])
+        return out
+
+    raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
+
+
+_BASE_TYPE_MAPPING = {
+    int: {"type": "integer"},
+    float: {"type": "number"},
+    str: {"type": "string"},
+    bool: {"type": "boolean"},
+    Any: {},
+}
+
+
+def _get_json_schema_type(param_type: Any) -> Dict[str, str]:
+    if param_type in _BASE_TYPE_MAPPING:
+        return _BASE_TYPE_MAPPING[param_type]
+    if str(param_type) == "Image" and _is_pillow_available():
+        from PIL.Image import Image
+
+        if param_type == Image:
+            return {"type": "image"}
+    if str(param_type) == "Tensor" and is_torch_available():
+        from torch import Tensor
+
+        if param_type == Tensor:
+            return {"type": "audio"}
+    return {"type": "object"}
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 6a0913e..876f36a 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -14,33 +14,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import re
 from dataclasses import dataclass
 from typing import Dict, Optional
 
-from huggingface_hub import hf_hub_download, list_spaces
-from transformers.utils import is_offline_mode, is_torch_available
-
 from .local_python_executor import (
     BASE_BUILTIN_MODULES,
     BASE_PYTHON_TOOLS,
     evaluate_python_code,
 )
-from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
+from .tools import PipelineTool, Tool
 from .types import AgentAudio
 
 
-if is_torch_available():
-    from transformers.models.whisper import (
-        WhisperForConditionalGeneration,
-        WhisperProcessor,
-    )
-else:
-    WhisperForConditionalGeneration = object
-    WhisperProcessor = object
-
-
 @dataclass
 class PreTool:
     name: str
@@ -51,31 +37,6 @@ class PreTool:
     repo_id: str
 
 
-def get_remote_tools(logger, organization="huggingface-tools"):
-    if is_offline_mode():
-        logger.info("You are in offline mode, so remote tools are not available.")
-        return {}
-
-    spaces = list_spaces(author=organization)
-    tools = {}
-    for space_info in spaces:
-        repo_id = space_info.id
-        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
-        with open(resolved_config_file, encoding="utf-8") as reader:
-            config = json.load(reader)
-        task = repo_id.split("/")[-1]
-        tools[config["name"]] = PreTool(
-            task=task,
-            description=config["description"],
-            repo_id=repo_id,
-            name=task,
-            inputs=config["inputs"],
-            output_type=config["output_type"],
-        )
-
-    return tools
-
-
 class PythonInterpreterTool(Tool):
     name = "python_interpreter"
     description = "This is a tool that evaluates python code. It can be used to perform calculations."
@@ -150,10 +111,10 @@ class DuckDuckGoSearchTool(Tool):
         self.max_results = max_results
         try:
             from duckduckgo_search import DDGS
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
-            )
+            ) from e
         self.ddgs = DDGS()
 
     def forward(self, query: str) -> str:
@@ -259,10 +220,10 @@ class VisitWebpageTool(Tool):
             from requests.exceptions import RequestException
 
             from smolagents.utils import truncate_content
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
-            )
+            ) from e
         try:
             # Send a GET request to the URL
             response = requests.get(url)
@@ -286,9 +247,6 @@ class SpeechToTextTool(PipelineTool):
     default_checkpoint = "openai/whisper-large-v3-turbo"
     description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
     name = "transcriber"
-    pre_processor_class = WhisperProcessor
-    model_class = WhisperForConditionalGeneration
-
     inputs = {
         "audio": {
             "type": "audio",
@@ -297,6 +255,18 @@ class SpeechToTextTool(PipelineTool):
     }
     output_type = "string"
 
+    def __new__(cls, *args, **kwargs):
+        from transformers.models.whisper import (
+            WhisperForConditionalGeneration,
+            WhisperProcessor,
+        )
+
+        if not hasattr(cls, "pre_processor_class"):
+            cls.pre_processor_class = WhisperProcessor
+        if not hasattr(cls, "model_class"):
+            cls.model_class = WhisperForConditionalGeneration
+        return super().__new__(cls)
+
     def encode(self, audio):
         audio = AgentAudio(audio).to_raw()
         return self.pre_processor(audio, return_tensors="pt")
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index 6115fca..52f952b 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -19,14 +19,15 @@ import re
 import shutil
 from typing import Optional
 
-import gradio as gr
-
 from .agents import ActionStep, AgentStepLog, MultiStepAgent
 from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
+from .utils import _is_package_available
 
 
 def pull_messages_from_step(step_log: AgentStepLog):
     """Extract ChatMessage objects from agent steps"""
+    import gradio as gr
+
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
         if step_log.tool_calls is not None:
@@ -57,6 +58,11 @@ def stream_to_gradio(
     additional_args: Optional[dict] = None,
 ):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+    if not _is_package_available("gradio"):
+        raise ModuleNotFoundError(
+            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+        )
+    import gradio as gr
 
     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
         for message in pull_messages_from_step(step_log):
@@ -88,6 +94,10 @@ class GradioUI:
     """A one-line interface to launch your agent in Gradio"""
 
     def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
+        if not _is_package_available("gradio"):
+            raise ModuleNotFoundError(
+                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+            )
         self.agent = agent
         self.file_upload_folder = file_upload_folder
         if self.file_upload_folder is not None:
@@ -95,6 +105,8 @@ class GradioUI:
             os.mkdir(file_upload_folder)
 
     def interact_with_agent(self, prompt, messages):
+        import gradio as gr
+
         messages.append(gr.ChatMessage(role="user", content=prompt))
         yield messages
         for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
@@ -115,6 +127,7 @@ class GradioUI:
         """
         Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
+        import gradio as gr
 
         if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log
@@ -161,6 +174,8 @@ class GradioUI:
         )
 
     def launch(self):
+        import gradio as gr
+
         with gr.Blocks() as demo:
             stored_messages = gr.State([])
             file_uploads_log = gr.State([])
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 94d8a57..30bcb1c 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -21,20 +21,18 @@ import random
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from huggingface_hub import InferenceClient
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    is_torch_available,
-)
+from huggingface_hub.utils import is_torch_available
 
 from .tools import Tool
+from .utils import _is_package_available
 
+if TYPE_CHECKING:
+    from transformers import StoppingCriteriaList
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@@ -320,6 +318,9 @@ class TransformersModel(Model):
     This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both
     serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
+    > [!TIP]
+    > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they're not already installed.
+
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging
             Face model hub.
@@ -358,9 +359,12 @@ class TransformersModel(Model):
         **kwargs,
     ):
         super().__init__()
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use TransformersModel.")
+        if not is_torch_available() or not _is_package_available("transformers"):
+            raise ModuleNotFoundError(
+                "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
+            )
         import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
 
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -387,7 +391,9 @@ class TransformersModel(Model):
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
         self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
 
-    def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
+    def make_stopping_criteria(self, stop_sequences: List[str]) -> "StoppingCriteriaList":
+        from transformers import StoppingCriteria, StoppingCriteriaList
+
         class StopOnStrings(StoppingCriteria):
             def __init__(self, stop_strings: List[str], tokenizer):
                 self.stop_strings = stop_strings
@@ -491,6 +497,7 @@ class LiteLLMModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )
+        super().__init__()
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@@ -506,9 +513,10 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
         import litellm
 
+        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+
         if tools_to_call_from:
             response = litellm.completion(
                 model=self.model_id,
diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py
index 6af27e4..b9f4141 100644
--- a/src/smolagents/tools.py
+++ b/src/smolagents/tools.py
@@ -35,54 +35,22 @@ from huggingface_hub import (
     metadata_update,
     upload_folder,
 )
-from huggingface_hub.utils import RepositoryNotFoundError
+from huggingface_hub.utils import is_torch_available
 from packaging import version
-from transformers.dynamic_module_utils import get_imports
-from transformers.utils import (
-    TypeHintParsingException,
-    cached_file,
-    get_json_schema,
-    is_accelerate_available,
-    is_torch_available,
-)
-from transformers.utils.chat_template_utils import _parse_type_hint
 
+from ._transformers_utils import (
+    TypeHintParsingException,
+    _parse_type_hint,
+    get_imports,
+    get_json_schema,
+)
 from .tool_validation import MethodChecker, validate_tool_attributes
-from .types import ImageType, handle_agent_input_types, handle_agent_output_types
-from .utils import instance_to_source
+from .types import handle_agent_input_types, handle_agent_output_types
+from .utils import _is_package_available, _is_pillow_available, instance_to_source
 
 logger = logging.getLogger(__name__)
 
-if is_accelerate_available():
-    from accelerate import PartialState
-    from accelerate.utils import send_to_device
-
-if is_torch_available():
-    from transformers import AutoProcessor
-else:
-    AutoProcessor = object
-
-TOOL_CONFIG_FILE = "tool_config.json"
-
-
-def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
-    if repo_type is not None:
-        return repo_type
-    try:
-        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
-        return "space"
-    except RepositoryNotFoundError:
-        try:
-            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
-            return "model"
-        except RepositoryNotFoundError:
-            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
-        except Exception:
-            return "model"
-    except Exception:
-        return "space"
-
 
 def validate_after_init(cls):
     original_init = cls.__init__
@@ -337,12 +305,8 @@ class Tool:
             )
 
         # Save requirements file
+        imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
         requirements_file = os.path.join(output_dir, "requirements.txt")
-
-        imports = []
-        for module in [tool_file]:
-            imports.extend(get_imports(module))
-        imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
         with open(requirements_file, "w", encoding="utf-8") as f:
             f.write("\n".join(imports) + "\n")
 
@@ -439,53 +403,27 @@ class Tool:
             `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
             others will be passed along to its init.
         """
-        assert trust_remote_code, (
-            "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
-        )
-
-        hub_kwargs_names = [
-            "cache_dir",
-            "force_download",
-            "resume_download",
-            "proxies",
-            "revision",
-            "repo_type",
-            "subfolder",
-            "local_files_only",
-        ]
-        hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
-
-        tool_file = "tool.py"
+        if not trust_remote_code:
+            raise ValueError(
+                "Loading a tool from the Hub requires you to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
+            )
 
         # Get the tool's tool.py file.
- hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) - resolved_tool_file = cached_file( + tool_file = hf_hub_download( repo_id, - tool_file, + "tool.py", token=token, - **hub_kwargs, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, + repo_type="space", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download"), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + revision=kwargs.get("revision"), + subfolder=kwargs.get("subfolder"), + local_files_only=kwargs.get("local_files_only"), ) - tool_code = resolved_tool_file is not None - if resolved_tool_file is None: - resolved_tool_file = cached_file( - repo_id, - tool_file, - token=token, - **hub_kwargs, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - ) - if resolved_tool_file is None: - raise EnvironmentError( - f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." - ) - with open(resolved_tool_file, encoding="utf-8") as reader: - tool_code = "".join(reader.readlines()) + tool_code = Path(tool_file).read_text() # Find the Tool subclass in the namespace with tempfile.TemporaryDirectory() as temp_dir: @@ -613,7 +551,10 @@ class Tool: def sanitize_argument_for_prediction(self, arg): from gradio_client.utils import is_http_url_like - if isinstance(arg, ImageType): + if _is_pillow_available(): + from PIL.Image import Image + + if _is_pillow_available() and isinstance(arg, Image): temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) arg.save(temp_file.name) arg = temp_file.name @@ -988,13 +929,13 @@ class PipelineTool(Tool): - **model_class** (`type`) -- The class to use to load the model in this tool. - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one. - - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the + - **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the pre-processor - - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the + - **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the post-processor (when different from the pre-processor). Args: - model (`str` or [`PreTrainedModel`], *optional*): + model (`str` or [`transformers.PreTrainedModel`], *optional*): The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the value of the class attribute `default_checkpoint`. pre_processor (`str` or `Any`, *optional*): @@ -1019,9 +960,9 @@ class PipelineTool(Tool): Any additional keyword argument to send to the methods that will load the data from the Hub. 
""" - pre_processor_class = AutoProcessor + pre_processor_class = None model_class = None - post_processor_class = AutoProcessor + post_processor_class = None default_checkpoint = None description = "This is a pipeline tool" name = "pipeline" @@ -1040,11 +981,10 @@ class PipelineTool(Tool): token=None, **hub_kwargs, ): - if not is_torch_available(): - raise ImportError("Please install torch in order to use this tool.") - - if not is_accelerate_available(): - raise ImportError("Please install accelerate in order to use this tool.") + if not is_torch_available() or not _is_package_available("accelerate"): + raise ModuleNotFoundError( + "Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`" + ) if model is None: if self.default_checkpoint is None: @@ -1071,6 +1011,10 @@ class PipelineTool(Tool): Instantiates the `pre_processor`, `model` and `post_processor` if necessary. """ if isinstance(self.pre_processor, str): + if self.pre_processor_class is None: + from transformers import AutoProcessor + + self.pre_processor_class = AutoProcessor self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs) if isinstance(self.model, str): @@ -1079,12 +1023,18 @@ class PipelineTool(Tool): if self.post_processor is None: self.post_processor = self.pre_processor elif isinstance(self.post_processor, str): + if self.post_processor_class is None: + from transformers import AutoProcessor + + self.post_processor_class = AutoProcessor self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs) if self.device is None: if self.device_map is not None: self.device = list(self.model.hf_device_map.values())[0] else: + from accelerate import PartialState + self.device = PartialState().default_device if self.device_map is None: @@ -1115,6 +1065,7 @@ class PipelineTool(Tool): def __call__(self, *args, **kwargs): import torch + from accelerate.utils import send_to_device args, kwargs = handle_agent_input_types(*args, **kwargs) diff --git a/src/smolagents/types.py b/src/smolagents/types.py index e18de51..7077daa 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib.util
 import logging
 import os
 import pathlib
@@ -22,26 +21,15 @@ from io import BytesIO
 
 import numpy as np
 import requests
-from transformers.utils import (
-    is_torch_available,
-    is_vision_available,
-)
+from huggingface_hub.utils import is_torch_available
+from PIL import Image
+from PIL.Image import Image as ImageType
+
+from .utils import _is_package_available
 
 logger = logging.getLogger(__name__)
 
-if is_vision_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageType
-else:
-    ImageType = object
-
-if is_torch_available():
-    import torch
-    from torch import Tensor
-else:
-    Tensor = object
-
 
 class AgentType:
     """
@@ -94,9 +82,6 @@ class AgentImage(AgentType, ImageType):
         AgentType.__init__(self, value)
         ImageType.__init__(self)
 
-        if not is_vision_available():
-            raise ImportError("PIL must be installed in order to handle images.")
-
         self._path = None
         self._raw = None
         self._tensor = None
@@ -109,11 +94,15 @@ class AgentImage(AgentType, ImageType):
             self._raw = Image.open(BytesIO(value))
         elif isinstance(value, (str, pathlib.Path)):
             self._path = value
-        elif isinstance(value, torch.Tensor):
-            self._tensor = value
-        elif isinstance(value, np.ndarray):
-            self._tensor = torch.from_numpy(value)
-        else:
+        elif is_torch_available():
+            import torch
+
+            if isinstance(value, torch.Tensor):
+                self._tensor = value
+            if isinstance(value, np.ndarray):
+                self._tensor = torch.from_numpy(value)
+
+        if self._path is None and self._raw is None and self._tensor is None:
             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
 
     def _ipython_display_(self, include=None, exclude=None):
@@ -183,10 +172,12 @@ class AgentAudio(AgentType, str):
     """
 
     def __init__(self, value, samplerate=16_000):
-        if importlib.util.find_spec("soundfile") is None:
+        if not _is_package_available("soundfile") or not is_torch_available():
             raise ModuleNotFoundError(
                 "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
             )
+        import torch
+
         super().__init__(value)
 
         self._path = None
@@ -223,6 +214,8 @@ class AgentAudio(AgentType, str):
         if self._tensor is not None:
             return self._tensor
 
+        import torch
+
         if self._path is not None:
             if "://" in str(self._path):
                 response = requests.get(self._path)
@@ -250,15 +243,7 @@ class AgentAudio(AgentType, str):
         return self._path
 
 
-AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
-INSTANCE_TYPE_MAPPING = {
-    str: AgentText,
-    ImageType: AgentImage,
-    Tensor: AgentAudio,
-}
-
-if is_torch_available():
-    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
+_AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
 
 
 def handle_agent_input_types(*args, **kwargs):
@@ -268,17 +253,22 @@ def handle_agent_input_types(*args, **kwargs):
 
 
 def handle_agent_output_types(output, output_type=None):
-    if output_type in AGENT_TYPE_MAPPING:
+    if output_type in _AGENT_TYPE_MAPPING:
         # If the class has defined outputs, we can map directly according to the class definition
-        decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
+        decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
         return decoded_outputs
-    else:
-        # If the class does not have defined output, then we map according to the type
-        for _k, _v in INSTANCE_TYPE_MAPPING.items():
-            if isinstance(output, _k):
-                if _k is not object:  # avoid converting to audio if torch is not installed
-                    return _v(output)
-        return output
+
+    # If the class does not have defined output, then we map according to the type
+    if isinstance(output, str):
+        return AgentText(output)
+    if isinstance(output, ImageType):
+        return AgentImage(output)
+    if is_torch_available():
+        import torch
+
+        if isinstance(output, torch.Tensor):
+            return AgentAudio(output)
+    return output
 
 
 __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py
index 196f21c..30bcf61 100644
--- a/src/smolagents/utils.py
+++ b/src/smolagents/utils.py
@@ -15,18 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import ast
+import importlib.metadata
 import importlib.util
 import inspect
 import json
 import re
 import types
+from functools import lru_cache
 from typing import Dict, Tuple, Union
 
 from rich.console import Console
 
 
-def is_pygments_available():
-    return importlib.util.find_spec("soundfile") is not None
+@lru_cache
+def _is_package_available(package_name: str) -> bool:
+    try:
+        importlib.metadata.version(package_name)
+        return True
+    except importlib.metadata.PackageNotFoundError:
+        return False
+
+
+@lru_cache
+def _is_pillow_available():
+    return importlib.util.find_spec("PIL") is not None
 
 
 console = Console()
diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py
index d92b387..91c40c6 100644
--- a/tests/test_default_tools.py
+++ b/tests/test_default_tools.py
@@ -17,7 +17,7 @@ import unittest
 import pytest
 
 from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING
 
 from .test_tools import ToolTesterMixin
 
@@ -46,7 +46,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
     def test_agent_type_output(self):
         inputs = ["2 * 2"]
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))
 
     def test_agent_types_inputs(self):
@@ -56,13 +56,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
         for _input, expected_input in zip(inputs, self.tool.inputs.values()):
             input_type = expected_input["type"]
             if isinstance(input_type, list):
-                _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
+                _inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
             else:
-                _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+                _inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))
 
         # Should not raise an error
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))
 
     def test_imports_work(self):
diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py
index 89f4ffc..7bb1e5e 100644
--- a/tests/test_final_answer.py
+++ b/tests/test_final_answer.py
@@ -22,7 +22,7 @@ from transformers import is_torch_available
 from transformers.testing_utils import get_tests_dir, require_torch
 
 from smolagents.default_tools import FinalAnswerTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING
 
 from .test_tools import ToolTesterMixin
 
@@ -55,5 +55,5 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
         inputs = self.create_inputs()
         for input_type, input in inputs.items():
             output = self.tool(**input, sanitize_inputs_outputs=True)
-            agent_type = AGENT_TYPE_MAPPING[input_type]
+            agent_type = _AGENT_TYPE_MAPPING[input_type]
             self.assertTrue(isinstance(output, agent_type))
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 917bcf1..67bd2f6 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -21,16 +21,12 @@ from unittest.mock import MagicMock, patch
 import mcp
 import numpy as np
 import pytest
+import torch
 from transformers import is_torch_available, is_vision_available
 from transformers.testing_utils import get_tests_dir
 
 from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
-from smolagents.types import (
-    AGENT_TYPE_MAPPING,
-    AgentAudio,
-    AgentImage,
-    AgentText,
-)
+from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
 
 
 if is_torch_available():
@@ -96,7 +92,7 @@ class ToolTesterMixin:
         inputs = create_inputs(self.tool.inputs)
         output = self.tool(**inputs, sanitize_inputs_outputs=True)
         if self.tool.output_type != "any":
-            agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+            agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
             self.assertTrue(isinstance(output, agent_type))
 
 
diff --git a/tests/test_types.py b/tests/test_types.py
index 244875c..9350da1 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -121,4 +121,3 @@ class AgentTextTests(unittest.TestCase):
 
         self.assertEqual(string, agent_type.to_string())
         self.assertEqual(string, agent_type.to_raw())
-        self.assertEqual(string, agent_type)
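The same optional-dependency pattern recurs throughout the patch: a cheap check against installed package metadata, an actionable error naming the extra to install, and the heavy import deferred into the function or method that actually needs it. A minimal self-contained sketch of that pattern, assuming only the standard library (the stripped-down `GradioUI` body here is illustrative shorthand, not the full smolagents implementation):

```python
import importlib.metadata
from functools import lru_cache


@lru_cache
def _is_package_available(package_name: str) -> bool:
    # Reading installed distribution metadata is cheap; it never imports the package itself.
    try:
        importlib.metadata.version(package_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


class GradioUI:
    # Illustrative skeleton only: the real class also wires up the agent and file uploads.
    def __init__(self):
        # Fail fast at construction time, with a message that names the extra to install.
        if not _is_package_available("gradio"):
            raise ModuleNotFoundError(
                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
            )

    def launch(self):
        # Deferred import: the cost of importing gradio is only paid when the UI is used,
        # so `import smolagents` stays fast.
        import gradio as gr

        with gr.Blocks() as demo:
            gr.Markdown("smolagents UI placeholder")
        demo.launch()
```

Because the metadata check never loads the package, `import smolagents` no longer pulls in `gradio`, `transformers`, or `litellm`; each one is imported lazily, inside the code path that uses it, which is what makes the top-level import fast.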