Make import time faster (optional deps + delay imports) (#253)
* adapt docs
* optional in pyproject.toml
* get rid of some transformers imports
* optional transformers in models.py
* gradio, transformers, litellm
* small refacto AgentType
* merge conflicts
* mouaif
* fix tests
* AgentText no longer a str
* Add back AgentType as str/Image
* fixed for good
Parent: a2ca95107f
Commit: d19ebc7a48
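The change makes `import smolagents` cheap by moving `transformers` and `gradio` behind optional extras and deferring heavy imports to the call sites that need them. A rough way to check the effect locally (a sketch, not part of the commit):

```python
# Hypothetical timing harness: compare cold-import cost before and after checking out this commit.
import subprocess
import sys
import time

start = time.perf_counter()
subprocess.run([sys.executable, "-c", "import smolagents"], check=True)
print(f"cold 'import smolagents' took {time.perf_counter() - start:.2f}s")
```

For a per-module breakdown of where the remaining import time goes, `python -X importtime -c "import smolagents"` works as well.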
@@ -55,6 +55,7 @@ agent.run(
 <hfoption id="Local Transformers Model">

 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel

 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -72,6 +73,7 @@ agent.run(
 To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass `api_key` variable upon initialization.

 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel

 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could use 'gpt-4o'
@@ -85,6 +87,7 @@ agent.run(
 <hfoption id="Ollama">

 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel

 model = LiteLLMModel(
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.

 ### GradioUI

+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it is not already installed.
+
 [[autodoc]] GradioUI

 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```

+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they are not already installed.
+
 [[autodoc]] TransformersModel

 ### HfApiModel
@@ -61,6 +61,7 @@ agent.run(
 <hfoption id="本地Transformers模型">

 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel

 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -78,6 +79,7 @@ agent.run(
 To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass the `api_key` variable at initialization.

 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel

 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could also use 'gpt-4o'
@@ -91,6 +93,7 @@ agent.run(
 <hfoption id="Ollama">

 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel

 model = LiteLLMModel(
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.

 ### GradioUI

+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it is not already installed.
+
 [[autodoc]] GradioUI

 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```

+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they are not already installed.
+
 [[autodoc]] TransformersModel

 ### HfApiModel
@@ -12,14 +12,12 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-  "transformers>=4.0.0",
   "requests>=2.32.3",
   "rich>=13.9.4",
   "pandas>=2.2.3",
   "jinja2>=3.1.4",
   "pillow>=11.0.0",
   "markdownify>=0.14.1",
-  "gradio>=5.8.0",
   "duckduckgo-search>=6.3.7",
   "python-dotenv>=1.0.1",
   "e2b-code-interpreter>=1.0.3",
@@ -27,12 +25,20 @@ dependencies = [
 ]

 [project.optional-dependencies]
-audio = [
-  "soundfile",
-]
 torch = [
   "torch",
+]
+audio = [
+  "soundfile",
+  "smolagents[torch]",
+]
+transformers = [
   "accelerate",
+  "transformers>=4.0.0",
+  "smolagents[torch]",
+]
+gradio = [
+  "gradio>=5.8.0",
 ]
 litellm = [
   "litellm>=1.55.10",
@@ -47,9 +53,12 @@ openai = [
 quality = [
   "ruff>=0.9.0",
 ]
+all = [
+  "smolagents[accelerate,audio,gradio,litellm,mcp,openai,transformers]",
+]
 test = [
   "pytest>=8.1.0",
-  "smolagents[audio,litellm,mcp,openai,torch]",
+  "smolagents[all]",
 ]
 dev = [
   "smolagents[quality,test]",
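With `transformers` and `gradio` gone from the core dependencies, each feature area now declares its own extra, and `all` aggregates them for tests. At runtime the code probes availability instead of importing eagerly; the `_is_package_available` helper used throughout the diff is presumably along these lines (a sketch; the real implementation lives in `smolagents.utils` and may differ):

```python
# Assumed shape of the availability probe: locate the package without importing it.
import importlib.util

def _is_package_available(package_name: str) -> bool:
    return importlib.util.find_spec(package_name) is not None

# Typical call-site pattern from this PR: fail with an actionable message
# only when the optional feature is actually touched.
if not _is_package_available("gradio"):
    raise ModuleNotFoundError("Please run `pip install 'smolagents[gradio]'`")
```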
@@ -16,13 +16,6 @@
 # limitations under the License.
 __version__ = "1.5.0.dev"

-from typing import TYPE_CHECKING
-
-from transformers.utils import _LazyModule
-from transformers.utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
 from .agents import *
 from .default_tools import *
 from .e2b_executor import *
@@ -34,18 +27,3 @@ if TYPE_CHECKING:
 from .tools import *
 from .types import *
 from .utils import *
-
-
-else:
-    import sys
-
-    _file = globals()["__file__"]
-    import_structure = define_import_structure(_file)
-    import_structure[""] = {"__version__": __version__}
-    sys.modules[__name__] = _LazyModule(
-        __name__,
-        _file,
-        import_structure,
-        module_spec=__spec__,
-        extra_objects={"__version__": __version__},
-    )
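Dropping `_LazyModule` removes the last hard dependency on `transformers` at import time; the trade-off is that `from .agents import *` and friends now run eagerly, which stays cheap once each submodule defers its own heavy imports. If lazy top-level attributes were still wanted without `transformers`, PEP 562 offers a dependency-free alternative (a sketch, not what this commit does):

```python
# __init__.py-style lazy attribute loading via PEP 562 (illustrative only;
# the submodule names here are hypothetical).
import importlib

_SUBMODULES = {"agents", "models", "tools"}

def __getattr__(name: str):
    # Called only when `name` is not found normally; defers the submodule import.
    if name in _SUBMODULES:
        return importlib.import_module(f".{name}", __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```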
@@ -0,0 +1,388 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains utilities taken exclusively from the `transformers` repository.
+
+Since they are not specific to `transformers`, and `transformers` is a heavy dependency, these helpers have
+been duplicated.
+
+TODO: move them to `huggingface_hub` to avoid code duplication.
+"""
+
+import inspect
+import json
+import os
+import re
+import types
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+from huggingface_hub.utils import is_torch_available
+
+from .utils import _is_pillow_available
+
+
+def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
+    """
+    Extracts all the libraries (not relative imports) that are imported in a file.
+
+    Args:
+        filename (`str` or `os.PathLike`): The module file to inspect.
+
+    Returns:
+        `List[str]`: The list of all packages required to use the input module.
+    """
+    with open(filename, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # filter out try/except blocks so that custom code can have try/except imports
+    content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
+
+    # filter out imports under is_flash_attn_2_available blocks to avoid import issues in CPU-only environments
+    content = re.sub(
+        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
+        "",
+        content,
+        flags=re.MULTILINE,
+    )
+
+    # Imports of the form `import xxx`
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from xxx import yyy`
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+    # Only keep the top-level module
+    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+    return list(set(imports))
+
+
+class TypeHintParsingException(Exception):
+    """Exception raised for errors in parsing type hints to generate JSON schemas"""
+
+
+class DocstringParsingException(Exception):
+    """Exception raised for errors in parsing docstrings to generate JSON schemas"""
+
+
+def get_json_schema(func: Callable) -> Dict:
+    """
+    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
+    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
+    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
+    that the function has a docstring, and that each argument has a description in the docstring, in the standard
+    Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
+
+    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
+    optional because most chat templates ignore the return value of the function.
+
+    Args:
+        func: The function to generate a JSON schema for.
+
+    Returns:
+        A dictionary containing the JSON schema for the function.
+
+    Examples:
+    ```python
+    >>> def multiply(x: float, y: float):
+    >>>    '''
+    >>>    A function that multiplies two numbers
+    >>>
+    >>>    Args:
+    >>>        x: The first number to multiply
+    >>>        y: The second number to multiply
+    >>>    '''
+    >>>    return x * y
+    >>>
+    >>> print(get_json_schema(multiply))
+    {
+        "name": "multiply",
+        "description": "A function that multiplies two numbers",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "x": {"type": "number", "description": "The first number to multiply"},
+                "y": {"type": "number", "description": "The second number to multiply"}
+            },
+            "required": ["x", "y"]
+        }
+    }
+    ```
+
+    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
+    support them, like so:
+
+    ```python
+    >>> from transformers import AutoTokenizer
+    >>> from transformers.utils import get_json_schema
+    >>>
+    >>> def multiply(x: float, y: float):
+    >>>    '''
+    >>>    A function that multiplies two numbers
+    >>>
+    >>>    Args:
+    >>>        x: The first number to multiply
+    >>>        y: The second number to multiply
+    >>>    '''
+    >>>    return x * y
+    >>>
+    >>> multiply_schema = get_json_schema(multiply)
+    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
+    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
+    >>> formatted_chat = tokenizer.apply_chat_template(
+    >>>     messages,
+    >>>     tools=[multiply_schema],
+    >>>     chat_template="tool_use",
+    >>>     return_dict=True,
+    >>>     return_tensors="pt",
+    >>>     add_generation_prompt=True
+    >>> )
+    >>> # The formatted chat can now be passed to model.generate()
+    ```
+
+    Each argument description can also have an optional `(choices: ...)` block at the end, such as
+    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
+    only be parsed correctly if it is at the end of the line:
+
+    ```python
+    >>> def drink_beverage(beverage: str):
+    >>>    '''
+    >>>    A function that drinks a beverage
+    >>>
+    >>>    Args:
+    >>>        beverage: The beverage to drink (choices: ["tea", "coffee"])
+    >>>    '''
+    >>>    pass
+    >>>
+    >>> print(get_json_schema(drink_beverage))
+    {
+        'name': 'drink_beverage',
+        'description': 'A function that drinks a beverage',
+        'parameters': {
+            'type': 'object',
+            'properties': {
+                'beverage': {
+                    'type': 'string',
+                    'enum': ['tea', 'coffee'],
+                    'description': 'The beverage to drink'
+                    }
+                },
+            'required': ['beverage']
+        }
+    }
+    ```
+    """
+    doc = inspect.getdoc(func)
+    if not doc:
+        raise DocstringParsingException(
+            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
+        )
+    doc = doc.strip()
+    main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)
+
+    json_schema = _convert_type_hints_to_json_schema(func)
+    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
+        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
+            return_dict["description"] = return_doc
+    for arg, schema in json_schema["properties"].items():
+        if arg not in param_descriptions:
+            raise DocstringParsingException(
+                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
+            )
+        desc = param_descriptions[arg]
+        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
+        if enum_choices:
+            schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
+            desc = enum_choices.string[: enum_choices.start()].strip()
+        schema["description"] = desc
+
+    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
+    if return_dict is not None:
+        output["return"] = return_dict
+    return {"type": "function", "function": output}
+
+
+# Extracts the initial segment of the docstring, containing the function description
+description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
+# Extracts the Args: block from the docstring
+args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
+# Splits the Args: block into individual arguments
+args_split_re = re.compile(
+    r"""
+(?:^|\n)  # Match the start of the args block, or a newline
+\s*(\w+):\s*  # Capture the argument name and strip spacing
+(.*?)\s*  # Capture the argument description, which can span multiple lines, and strip trailing spacing
+(?=\n\s*\w+:|\Z)  # Stop when you hit the next argument or the end of the block
+""",
+    re.DOTALL | re.VERBOSE,
+)
+# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
+returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
+
+
+def _parse_google_format_docstring(
+    docstring: str,
+) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
+    """
+    Parses a Google-style docstring to extract the function description,
+    argument descriptions, and return description.
+
+    Args:
+        docstring (str): The docstring to parse.
+
+    Returns:
+        The function description, arguments, and return description.
+    """
+
+    # Extract the sections
+    description_match = description_re.search(docstring)
+    args_match = args_re.search(docstring)
+    returns_match = returns_re.search(docstring)
+
+    # Clean and store the sections
+    description = description_match.group(1).strip() if description_match else None
+    docstring_args = args_match.group(1).strip() if args_match else None
+    returns = returns_match.group(1).strip() if returns_match else None
+
+    # Parsing the arguments into a dictionary
+    if docstring_args is not None:
+        docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()])  # Remove blank lines
+        matches = args_split_re.findall(docstring_args)
+        args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
+    else:
+        args_dict = {}
+
+    return description, args_dict, returns
+
+
+def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    required = []
+    for param_name, param in signature.parameters.items():
+        if param.annotation == inspect.Parameter.empty:
+            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
+        if param.default == inspect.Parameter.empty:
+            required.append(param_name)
+
+    properties = {}
+    for param_name, param_type in type_hints.items():
+        properties[param_name] = _parse_type_hint(param_type)
+
+    schema = {"type": "object", "properties": properties}
+    if required:
+        schema["required"] = required
+
+    return schema
+
+
+def _parse_type_hint(hint: Any) -> Dict:
+    origin = get_origin(hint)
+    args = get_args(hint)
+
+    if origin is None:
+        try:
+            return _get_json_schema_type(hint)
+        except KeyError:
+            raise TypeHintParsingException(
+                "Couldn't parse this type hint, likely due to a custom class or object: ",
+                hint,
+            )
+
+    elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+        # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
+        subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
+        if len(subtypes) == 1:
+            # A single non-null type can be expressed directly
+            return_dict = subtypes[0]
+        elif all(isinstance(subtype["type"], str) for subtype in subtypes):
+            # A union of basic types can be expressed as a list in the schema
+            return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
+        else:
+            # A union of more complex types requires "anyOf"
+            return_dict = {"anyOf": subtypes}
+        if type(None) in args:
+            return_dict["nullable"] = True
+        return return_dict
+
+    elif origin is list:
+        if not args:
+            return {"type": "array"}
+        else:
+            # Lists can only have a single type argument, so recurse into it
+            return {"type": "array", "items": _parse_type_hint(args[0])}
+
+    elif origin is tuple:
+        if not args:
+            return {"type": "array"}
+        if len(args) == 1:
+            raise TypeHintParsingException(
+                f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
+                "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
+                "more than one element, we recommend "
+                "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
+                "pass the element directly."
+            )
+        if ... in args:
+            raise TypeHintParsingException(
+                "Conversion of '...' is not supported in Tuple type hints. "
+                "Use List[] types for variable-length"
+                " inputs instead."
+            )
+        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}

+    elif origin is dict:
+        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
+        # However, we can specify the type of the dict values with "additionalProperties"
+        out = {"type": "object"}
+        if len(args) == 2:
+            out["additionalProperties"] = _parse_type_hint(args[1])
+        return out
+
+    raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
+
+
+_BASE_TYPE_MAPPING = {
+    int: {"type": "integer"},
+    float: {"type": "number"},
+    str: {"type": "string"},
+    bool: {"type": "boolean"},
+    Any: {},
+}
+
+
+def _get_json_schema_type(param_type: Any) -> Dict[str, str]:
+    if param_type in _BASE_TYPE_MAPPING:
+        return _BASE_TYPE_MAPPING[param_type]
+    if str(param_type) == "Image" and _is_pillow_available():
+        from PIL.Image import Image
+
+        if param_type == Image:
+            return {"type": "image"}
+    if str(param_type) == "Tensor" and is_torch_available():
+        from torch import Tensor
+
+        if param_type == Tensor:
+            return {"type": "audio"}
+    return {"type": "object"}
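The vendored `get_imports` is what `Tool.save` uses further down to build a `requirements.txt`. A quick check of its behavior (a sketch; note that relative imports are deliberately dropped):

```python
# Exercise get_imports on a throwaway module file.
import tempfile

from smolagents._transformers_utils import get_imports  # module added in this commit

code = "import numpy\nfrom PIL import Image\nfrom .local_helpers import helper\n"
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(code)
    path = f.name

print(sorted(get_imports(path)))  # ['PIL', 'numpy'] (the relative import is skipped)
```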
|  | @ -14,33 +14,19 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import json |  | ||||||
| import re | import re | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Dict, Optional | from typing import Dict, Optional | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import hf_hub_download, list_spaces |  | ||||||
| from transformers.utils import is_offline_mode, is_torch_available |  | ||||||
| 
 |  | ||||||
| from .local_python_executor import ( | from .local_python_executor import ( | ||||||
|     BASE_BUILTIN_MODULES, |     BASE_BUILTIN_MODULES, | ||||||
|     BASE_PYTHON_TOOLS, |     BASE_PYTHON_TOOLS, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
| ) | ) | ||||||
| from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | from .tools import PipelineTool, Tool | ||||||
| from .types import AgentAudio | from .types import AgentAudio | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): |  | ||||||
|     from transformers.models.whisper import ( |  | ||||||
|         WhisperForConditionalGeneration, |  | ||||||
|         WhisperProcessor, |  | ||||||
|     ) |  | ||||||
| else: |  | ||||||
|     WhisperForConditionalGeneration = object |  | ||||||
|     WhisperProcessor = object |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @dataclass | @dataclass | ||||||
| class PreTool: | class PreTool: | ||||||
|     name: str |     name: str | ||||||
|  | @ -51,31 +37,6 @@ class PreTool: | ||||||
|     repo_id: str |     repo_id: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_remote_tools(logger, organization="huggingface-tools"): |  | ||||||
|     if is_offline_mode(): |  | ||||||
|         logger.info("You are in offline mode, so remote tools are not available.") |  | ||||||
|         return {} |  | ||||||
| 
 |  | ||||||
|     spaces = list_spaces(author=organization) |  | ||||||
|     tools = {} |  | ||||||
|     for space_info in spaces: |  | ||||||
|         repo_id = space_info.id |  | ||||||
|         resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space") |  | ||||||
|         with open(resolved_config_file, encoding="utf-8") as reader: |  | ||||||
|             config = json.load(reader) |  | ||||||
|         task = repo_id.split("/")[-1] |  | ||||||
|         tools[config["name"]] = PreTool( |  | ||||||
|             task=task, |  | ||||||
|             description=config["description"], |  | ||||||
|             repo_id=repo_id, |  | ||||||
|             name=task, |  | ||||||
|             inputs=config["inputs"], |  | ||||||
|             output_type=config["output_type"], |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     return tools |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class PythonInterpreterTool(Tool): | class PythonInterpreterTool(Tool): | ||||||
|     name = "python_interpreter" |     name = "python_interpreter" | ||||||
|     description = "This is a tool that evaluates python code. It can be used to perform calculations." |     description = "This is a tool that evaluates python code. It can be used to perform calculations." | ||||||
|  | @ -150,10 +111,10 @@ class DuckDuckGoSearchTool(Tool): | ||||||
|         self.max_results = max_results |         self.max_results = max_results | ||||||
|         try: |         try: | ||||||
|             from duckduckgo_search import DDGS |             from duckduckgo_search import DDGS | ||||||
|         except ImportError: |         except ImportError as e: | ||||||
|             raise ImportError( |             raise ImportError( | ||||||
|                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`." |                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`." | ||||||
|             ) |             ) from e | ||||||
|         self.ddgs = DDGS() |         self.ddgs = DDGS() | ||||||
| 
 | 
 | ||||||
|     def forward(self, query: str) -> str: |     def forward(self, query: str) -> str: | ||||||
|  | @ -259,10 +220,10 @@ class VisitWebpageTool(Tool): | ||||||
|             from requests.exceptions import RequestException |             from requests.exceptions import RequestException | ||||||
| 
 | 
 | ||||||
|             from smolagents.utils import truncate_content |             from smolagents.utils import truncate_content | ||||||
|         except ImportError: |         except ImportError as e: | ||||||
|             raise ImportError( |             raise ImportError( | ||||||
|                 "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`." |                 "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`." | ||||||
|             ) |             ) from e | ||||||
|         try: |         try: | ||||||
|             # Send a GET request to the URL |             # Send a GET request to the URL | ||||||
|             response = requests.get(url) |             response = requests.get(url) | ||||||
|  | @ -286,9 +247,6 @@ class SpeechToTextTool(PipelineTool): | ||||||
|     default_checkpoint = "openai/whisper-large-v3-turbo" |     default_checkpoint = "openai/whisper-large-v3-turbo" | ||||||
|     description = "This is a tool that transcribes an audio into text. It returns the transcribed text." |     description = "This is a tool that transcribes an audio into text. It returns the transcribed text." | ||||||
|     name = "transcriber" |     name = "transcriber" | ||||||
|     pre_processor_class = WhisperProcessor |  | ||||||
|     model_class = WhisperForConditionalGeneration |  | ||||||
| 
 |  | ||||||
|     inputs = { |     inputs = { | ||||||
|         "audio": { |         "audio": { | ||||||
|             "type": "audio", |             "type": "audio", | ||||||
|  | @ -297,6 +255,18 @@ class SpeechToTextTool(PipelineTool): | ||||||
|     } |     } | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|  |     def __new__(cls): | ||||||
|  |         from transformers.models.whisper import ( | ||||||
|  |             WhisperForConditionalGeneration, | ||||||
|  |             WhisperProcessor, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if not hasattr(cls, "pre_processor_class"): | ||||||
|  |             cls.pre_processor_class = WhisperProcessor | ||||||
|  |         if not hasattr(cls, "model_class"): | ||||||
|  |             cls.model_class = WhisperForConditionalGeneration | ||||||
|  |         return super().__new__() | ||||||
|  | 
 | ||||||
|     def encode(self, audio): |     def encode(self, audio): | ||||||
|         audio = AgentAudio(audio).to_raw() |         audio = AgentAudio(audio).to_raw() | ||||||
|         return self.pre_processor(audio, return_tensors="pt") |         return self.pre_processor(audio, return_tensors="pt") | ||||||
|  |  | ||||||
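`SpeechToTextTool` now resolves its Whisper classes the first time it is instantiated instead of at module import. The same trick in isolation (a sketch with a stand-in import, not smolagents API):

```python
# Defer a heavy import from module-import time to first instantiation.
class LazyTool:
    model_class = None

    def __new__(cls):
        if cls.model_class is None:
            from decimal import Decimal  # stand-in for a heavy dependency

            cls.model_class = Decimal
        return super().__new__(cls)

tool = LazyTool()
print(tool.model_class.__name__)  # Decimal; resolved exactly once, then cached on the class
```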
@@ -19,14 +19,15 @@ import re
 import shutil
 from typing import Optional

-import gradio as gr
-
 from .agents import ActionStep, AgentStepLog, MultiStepAgent
 from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
+from .utils import _is_package_available


 def pull_messages_from_step(step_log: AgentStepLog):
     """Extract ChatMessage objects from agent steps"""
+    import gradio as gr
+
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
         if step_log.tool_calls is not None:
@@ -57,6 +58,11 @@ def stream_to_gradio(
     additional_args: Optional[dict] = None,
 ):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+    if not _is_package_available("gradio"):
+        raise ModuleNotFoundError(
+            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+        )
+    import gradio as gr

     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
         for message in pull_messages_from_step(step_log):
@@ -88,6 +94,10 @@ class GradioUI:
     """A one-line interface to launch your agent in Gradio"""

     def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
+        if not _is_package_available("gradio"):
+            raise ModuleNotFoundError(
+                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
+            )
         self.agent = agent
         self.file_upload_folder = file_upload_folder
         if self.file_upload_folder is not None:
@@ -95,6 +105,8 @@ class GradioUI:
                 os.mkdir(file_upload_folder)

     def interact_with_agent(self, prompt, messages):
+        import gradio as gr
+
         messages.append(gr.ChatMessage(role="user", content=prompt))
         yield messages
         for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
@@ -115,6 +127,7 @@ class GradioUI:
         """
         Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
+        import gradio as gr

         if file is None:
             return gr.Textbox("No file uploaded", visible=True), file_uploads_log
@@ -161,6 +174,8 @@ class GradioUI:
         )

     def launch(self):
+        import gradio as gr
+
         with gr.Blocks() as demo:
             stored_messages = gr.State([])
             file_uploads_log = gr.State([])
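Every `gradio` use in this file now follows the same two-step pattern: a cheap availability check with an actionable error at the entry points (`GradioUI.__init__`, `stream_to_gradio`), then a function-local `import gradio as gr` wherever the module is actually needed. In miniature (a sketch, assuming gradio is installed when the function is called):

```python
# A function-local import keeps `import smolagents` cheap; gradio loads on first use.
def launch_demo() -> None:
    import gradio as gr  # deferred import: paid only if the UI is actually launched

    with gr.Blocks() as demo:
        gr.Markdown("Hello from a lazily imported UI")
    demo.launch()
```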
@@ -21,20 +21,18 @@ import random
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from huggingface_hub import InferenceClient
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    is_torch_available,
-)
+from huggingface_hub.utils import is_torch_available

 from .tools import Tool
+from .utils import _is_package_available


+if TYPE_CHECKING:
+    from transformers import StoppingCriteriaList
+
 logger = logging.getLogger(__name__)

 DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@@ -320,6 +318,9 @@ class TransformersModel(Model):

     This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.

+    > [!TIP]
+    > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if they are not already installed.
+
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
@@ -358,9 +359,12 @@ class TransformersModel(Model):
         **kwargs,
     ):
         super().__init__()
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use TransformersModel.")
+        if not is_torch_available() or not _is_package_available("transformers"):
+            raise ModuleNotFoundError(
+                "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
+            )
         import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer

         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -387,7 +391,9 @@ class TransformersModel(Model):
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)

-    def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
+    def make_stopping_criteria(self, stop_sequences: List[str]) -> "StoppingCriteriaList":
+        from transformers import StoppingCriteria, StoppingCriteriaList
+
         class StopOnStrings(StoppingCriteria):
             def __init__(self, stop_strings: List[str], tokenizer):
                 self.stop_strings = stop_strings
@@ -491,6 +497,7 @@ class LiteLLMModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )
+
         super().__init__()
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@@ -506,9 +513,10 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
         import litellm

+        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+
         if tools_to_call_from:
             response = litellm.completion(
                 model=self.model_id,
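The `TYPE_CHECKING` import plus the quoted return annotation is the standard recipe for keeping annotations checkable while avoiding the runtime import. Reduced to its essentials (a sketch mirroring the `make_stopping_criteria` change above):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from transformers import StoppingCriteriaList

def make_criteria() -> "StoppingCriteriaList":
    # The real import is paid only when the function is called.
    from transformers import StoppingCriteriaList

    return StoppingCriteriaList()
```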
|  | @ -35,54 +35,22 @@ from huggingface_hub import ( | ||||||
|     metadata_update, |     metadata_update, | ||||||
|     upload_folder, |     upload_folder, | ||||||
| ) | ) | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError | from huggingface_hub.utils import is_torch_available | ||||||
| from packaging import version | from packaging import version | ||||||
| from transformers.dynamic_module_utils import get_imports |  | ||||||
| from transformers.utils import ( |  | ||||||
|     TypeHintParsingException, |  | ||||||
|     cached_file, |  | ||||||
|     get_json_schema, |  | ||||||
|     is_accelerate_available, |  | ||||||
|     is_torch_available, |  | ||||||
| ) |  | ||||||
| from transformers.utils.chat_template_utils import _parse_type_hint |  | ||||||
| 
 | 
 | ||||||
|  | from ._transformers_utils import ( | ||||||
|  |     TypeHintParsingException, | ||||||
|  |     _parse_type_hint, | ||||||
|  |     get_imports, | ||||||
|  |     get_json_schema, | ||||||
|  | ) | ||||||
| from .tool_validation import MethodChecker, validate_tool_attributes | from .tool_validation import MethodChecker, validate_tool_attributes | ||||||
| from .types import ImageType, handle_agent_input_types, handle_agent_output_types | from .types import handle_agent_input_types, handle_agent_output_types | ||||||
| from .utils import instance_to_source | from .utils import _is_package_available, _is_pillow_available, instance_to_source | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| if is_accelerate_available(): |  | ||||||
|     from accelerate import PartialState |  | ||||||
|     from accelerate.utils import send_to_device |  | ||||||
| 
 |  | ||||||
| if is_torch_available(): |  | ||||||
|     from transformers import AutoProcessor |  | ||||||
| else: |  | ||||||
|     AutoProcessor = object |  | ||||||
| 
 |  | ||||||
| TOOL_CONFIG_FILE = "tool_config.json" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_repo_type(repo_id, repo_type=None, **hub_kwargs): |  | ||||||
|     if repo_type is not None: |  | ||||||
|         return repo_type |  | ||||||
|     try: |  | ||||||
|         hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs) |  | ||||||
|         return "space" |  | ||||||
|     except RepositoryNotFoundError: |  | ||||||
|         try: |  | ||||||
|             hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) |  | ||||||
|             return "model" |  | ||||||
|         except RepositoryNotFoundError: |  | ||||||
|             raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.") |  | ||||||
|         except Exception: |  | ||||||
|             return "model" |  | ||||||
|     except Exception: |  | ||||||
|         return "space" |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def validate_after_init(cls): | def validate_after_init(cls): | ||||||
|     original_init = cls.__init__ |     original_init = cls.__init__ | ||||||
|  | @ -337,12 +305,8 @@ class Tool: | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # Save requirements file |         # Save requirements file | ||||||
|  |         imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"} | ||||||
|         requirements_file = os.path.join(output_dir, "requirements.txt") |         requirements_file = os.path.join(output_dir, "requirements.txt") | ||||||
| 
 |  | ||||||
|         imports = [] |  | ||||||
|         for module in [tool_file]: |  | ||||||
|             imports.extend(get_imports(module)) |  | ||||||
|         imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names])) |  | ||||||
|         with open(requirements_file, "w", encoding="utf-8") as f: |         with open(requirements_file, "w", encoding="utf-8") as f: | ||||||
|             f.write("\n".join(imports) + "\n") |             f.write("\n".join(imports) + "\n") | ||||||
| 
 | 
 | ||||||
|  | @ -439,53 +403,27 @@ class Tool: | ||||||
|                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the |                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the | ||||||
|                 others will be passed along to its init. |                 others will be passed along to its init. | ||||||
|         """ |         """ | ||||||
|         assert trust_remote_code, ( |         if not trust_remote_code: | ||||||
|  |             raise ValueError( | ||||||
|                 "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." |                 "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         hub_kwargs_names = [ |  | ||||||
|             "cache_dir", |  | ||||||
|             "force_download", |  | ||||||
|             "resume_download", |  | ||||||
|             "proxies", |  | ||||||
|             "revision", |  | ||||||
|             "repo_type", |  | ||||||
|             "subfolder", |  | ||||||
|             "local_files_only", |  | ||||||
|         ] |  | ||||||
|         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names} |  | ||||||
| 
 |  | ||||||
|         tool_file = "tool.py" |  | ||||||
| 
 |  | ||||||
|         # Get the tool's tool.py file. |         # Get the tool's tool.py file. | ||||||
|         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) |         tool_file = hf_hub_download( | ||||||
|         resolved_tool_file = cached_file( |  | ||||||
|             repo_id, |             repo_id, | ||||||
|             tool_file, |             "tool.py", | ||||||
|             token=token, |             token=token, | ||||||
|             **hub_kwargs, |             repo_type="space", | ||||||
|             _raise_exceptions_for_gated_repo=False, |             cache_dir=kwargs.get("cache_dir"), | ||||||
|             _raise_exceptions_for_missing_entries=False, |             force_download=kwargs.get("force_download"), | ||||||
|             _raise_exceptions_for_connection_errors=False, |             resume_download=kwargs.get("resume_download"), | ||||||
|         ) |             proxies=kwargs.get("proxies"), | ||||||
|         tool_code = resolved_tool_file is not None |             revision=kwargs.get("revision"), | ||||||
|         if resolved_tool_file is None: |             subfolder=kwargs.get("subfolder"), | ||||||
|             resolved_tool_file = cached_file( |             local_files_only=kwargs.get("local_files_only"), | ||||||
|                 repo_id, |  | ||||||
|                 tool_file, |  | ||||||
|                 token=token, |  | ||||||
|                 **hub_kwargs, |  | ||||||
|                 _raise_exceptions_for_gated_repo=False, |  | ||||||
|                 _raise_exceptions_for_missing_entries=False, |  | ||||||
|                 _raise_exceptions_for_connection_errors=False, |  | ||||||
|             ) |  | ||||||
|         if resolved_tool_file is None: |  | ||||||
|             raise EnvironmentError( |  | ||||||
|                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." |  | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         with open(resolved_tool_file, encoding="utf-8") as reader: |         tool_code = Path(tool_file).read_text() | ||||||
|             tool_code = "".join(reader.readlines()) |  | ||||||
| 
 | 
 | ||||||
|         # Find the Tool subclass in the namespace |         # Find the Tool subclass in the namespace | ||||||
|         with tempfile.TemporaryDirectory() as temp_dir: |         with tempfile.TemporaryDirectory() as temp_dir: | ||||||
@@ -613,7 +551,10 @@ class Tool:
             def sanitize_argument_for_prediction(self, arg):
                 from gradio_client.utils import is_http_url_like

-                if isinstance(arg, ImageType):
+                if _is_pillow_available():
+                    from PIL.Image import Image
+
+                if _is_pillow_available() and isinstance(arg, Image):
                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                     arg.save(temp_file.name)
                     arg = temp_file.name
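The repeated `_is_pillow_available()` call is what keeps `Image` from being evaluated as an unbound name when Pillow is missing: `and` short-circuits before the `isinstance` check is reached. The same pattern, distilled into a self-contained sketch that needs only the standard library:

```python
import importlib.util
from functools import lru_cache


@lru_cache
def _is_pillow_available() -> bool:
    # find_spec checks importability without paying the import cost
    return importlib.util.find_spec("PIL") is not None


def describe(arg) -> str:
    if _is_pillow_available():
        from PIL.Image import Image
    # Short-circuit: `Image` is only referenced when the import above ran.
    if _is_pillow_available() and isinstance(arg, Image):
        return "PIL image"
    return type(arg).__name__


print(describe("hello"))  # -> "str", with or without Pillow installed
```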
@@ -988,13 +929,13 @@ class PipelineTool(Tool):

     - **model_class** (`type`) -- The class to use to load the model in this tool.
     - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
-    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+    - **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
       pre-processor
-    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+    - **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
       post-processor (when different from the pre-processor).

     Args:
-        model (`str` or [`PreTrainedModel`], *optional*):
+        model (`str` or [`transformers.PreTrainedModel`], *optional*):
             The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
             value of the class attribute `default_checkpoint`.
         pre_processor (`str` or `Any`, *optional*):
@@ -1019,9 +960,9 @@ class PipelineTool(Tool):
             Any additional keyword argument to send to the methods that will load the data from the Hub.
     """

-    pre_processor_class = AutoProcessor
+    pre_processor_class = None
     model_class = None
-    post_processor_class = AutoProcessor
+    post_processor_class = None
     default_checkpoint = None
     description = "This is a pipeline tool"
     name = "pipeline"
@@ -1040,11 +981,10 @@ class PipelineTool(Tool):
         token=None,
         **hub_kwargs,
     ):
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use this tool.")
-
-        if not is_accelerate_available():
-            raise ImportError("Please install accelerate in order to use this tool.")
+        if not is_torch_available() or not _is_package_available("accelerate"):
+            raise ModuleNotFoundError(
+                "Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`"
+            )

         if model is None:
             if self.default_checkpoint is None:
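With `pre_processor_class` and `post_processor_class` defaulting to `None`, importing a `PipelineTool` subclass no longer pulls in `transformers`; `AutoProcessor` is resolved lazily in `setup()` instead. A hedged sketch of a subclass wiring these class attributes; the checkpoint is hypothetical and `encode`/`forward`/`decode` are elided, so this shows the shape rather than a working tool:

```python
from smolagents.tools import PipelineTool


class TranscriberTool(PipelineTool):
    # Hypothetical checkpoint id; any speech-to-text model would do here.
    default_checkpoint = "openai/whisper-tiny"
    description = "Transcribes an audio file into text"
    name = "transcriber"
    inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
    output_type = "string"
    # model_class would name the transformers class to instantiate; leaving it
    # unset keeps this sketch free of any transformers import.
```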
@@ -1071,6 +1011,10 @@ class PipelineTool(Tool):
         Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
         """
         if isinstance(self.pre_processor, str):
+            if self.pre_processor_class is None:
+                from transformers import AutoProcessor
+
+                self.pre_processor_class = AutoProcessor
             self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

         if isinstance(self.model, str):
@@ -1079,12 +1023,18 @@ class PipelineTool(Tool):
         if self.post_processor is None:
             self.post_processor = self.pre_processor
         elif isinstance(self.post_processor, str):
+            if self.post_processor_class is None:
+                from transformers import AutoProcessor
+
+                self.post_processor_class = AutoProcessor
             self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

         if self.device is None:
             if self.device_map is not None:
                 self.device = list(self.model.hf_device_map.values())[0]
             else:
+                from accelerate import PartialState
+
                 self.device = PartialState().default_device

         if self.device_map is None:
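The same deferral appears three times in `setup()`: the heavy import happens inside the branch that needs it, and the resolved class is cached on the attribute so later calls skip the import machinery. Distilled into a standalone sketch (requires `transformers` and a reachable checkpoint when the string branch is taken):

```python
class LazyProcessorMixin:
    processor_class = None  # resolved on first use instead of at import time

    def resolve_processor(self, processor):
        if isinstance(processor, str):
            if self.processor_class is None:
                from transformers import AutoProcessor  # deferred heavy import

                self.processor_class = AutoProcessor
            processor = self.processor_class.from_pretrained(processor)
        return processor
```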
@@ -1115,6 +1065,7 @@ class PipelineTool(Tool):

     def __call__(self, *args, **kwargs):
         import torch
+        from accelerate.utils import send_to_device

         args, kwargs = handle_agent_input_types(*args, **kwargs)
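Moving `send_to_device` into `__call__` keeps `accelerate` off the import path until a tool actually runs. The helper recurses through nested containers, so a whole batch moves in one call; a rough usage sketch (requires torch and accelerate):

```python
import torch
from accelerate.utils import send_to_device

batch = {"input_ids": torch.ones(1, 4, dtype=torch.long), "masks": [torch.zeros(4)]}
batch = send_to_device(batch, "cpu")  # moves every tensor it finds, however nested
print(batch["input_ids"].device)
```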

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import logging
 import os
 import pathlib
@@ -22,26 +21,15 @@ from io import BytesIO

 import numpy as np
 import requests
-from transformers.utils import (
-    is_torch_available,
-    is_vision_available,
-)
+from huggingface_hub.utils import is_torch_available
+from PIL import Image
+from PIL.Image import Image as ImageType
+
+from .utils import _is_package_available


 logger = logging.getLogger(__name__)

-if is_vision_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageType
-else:
-    ImageType = object
-
-if is_torch_available():
-    import torch
-    from torch import Tensor
-else:
-    Tensor = object
-

 class AgentType:
     """
@@ -94,9 +82,6 @@ class AgentImage(AgentType, ImageType):
         AgentType.__init__(self, value)
         ImageType.__init__(self)

-        if not is_vision_available():
-            raise ImportError("PIL must be installed in order to handle images.")
-
         self._path = None
         self._raw = None
         self._tensor = None
@@ -109,11 +94,15 @@ class AgentImage(AgentType, ImageType):
             self._raw = Image.open(BytesIO(value))
         elif isinstance(value, (str, pathlib.Path)):
             self._path = value
-        elif isinstance(value, torch.Tensor):
-            self._tensor = value
-        elif isinstance(value, np.ndarray):
-            self._tensor = torch.from_numpy(value)
-        else:
+        elif is_torch_available():
+            import torch
+
+            if isinstance(value, torch.Tensor):
+                self._tensor = value
+            if isinstance(value, np.ndarray):
+                self._tensor = torch.from_numpy(value)
+
+        if self._path is None and self._raw is None and self._tensor is None:
             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

     def _ipython_display_(self, include=None, exclude=None):
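After the rework, the tensor and array branches only run when torch is importable, and a value matching no branch falls through to a single `TypeError` at the end. A quick constructor sketch, assuming the unshown first branch accepts a PIL image as the `AgentImage(AgentType, ImageType)` bases suggest (needs Pillow; the array case also needs torch):

```python
import numpy as np
from PIL import Image
from smolagents.types import AgentImage

pil_img = Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))
img = AgentImage(pil_img)  # kept as the raw PIL image
arr_img = AgentImage(np.zeros((16, 16, 3), dtype=np.uint8))  # converted via torch.from_numpy
```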
@@ -183,10 +172,12 @@ class AgentAudio(AgentType, str):
     """

     def __init__(self, value, samplerate=16_000):
-        if importlib.util.find_spec("soundfile") is None:
+        if not _is_package_available("soundfile") or not is_torch_available():
             raise ModuleNotFoundError(
                 "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
             )
+        import torch
+
         super().__init__(value)

         self._path = None
@@ -223,6 +214,8 @@ class AgentAudio(AgentType, str):
         if self._tensor is not None:
             return self._tensor

+        import torch
+
         if self._path is not None:
             if "://" in str(self._path):
                 response = requests.get(self._path)
@@ -250,15 +243,7 @@ class AgentAudio(AgentType, str):
             return self._path


-AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
-INSTANCE_TYPE_MAPPING = {
-    str: AgentText,
-    ImageType: AgentImage,
-    Tensor: AgentAudio,
-}
-
-if is_torch_available():
-    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
+_AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}


 def handle_agent_input_types(*args, **kwargs):
@@ -268,16 +253,21 @@ def handle_agent_input_types(*args, **kwargs):


 def handle_agent_output_types(output, output_type=None):
-    if output_type in AGENT_TYPE_MAPPING:
+    if output_type in _AGENT_TYPE_MAPPING:
         # If the class has defined outputs, we can map directly according to the class definition
-        decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
+        decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
         return decoded_outputs
-    else:
-        # If the class does not have defined output, then we map according to the type
-        for _k, _v in INSTANCE_TYPE_MAPPING.items():
-            if isinstance(output, _k):
-                if _k is not object:  # avoid converting to audio if torch is not installed
-                    return _v(output)
+
+    # If the class does not have defined output, then we map according to the type
+    if isinstance(output, str):
+        return AgentText(output)
+    if isinstance(output, ImageType):
+        return AgentImage(output)
+    if is_torch_available():
+        import torch
+
+        if isinstance(output, torch.Tensor):
+            return AgentAudio(output)
     return output
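The rewritten fallback keeps the dispatch usable without torch: strings and PIL images are wrapped unconditionally, tensors only when torch imports. A quick sanity check, assuming the module layout shown above:

```python
from smolagents.types import AgentText, handle_agent_output_types

# Explicit mapping via the tool's declared output_type:
out = handle_agent_output_types("4", output_type="string")
assert isinstance(out, AgentText)

# Fallback isinstance dispatch when no output_type is declared:
out = handle_agent_output_types("4")
assert isinstance(out, AgentText)
```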

@@ -15,18 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import ast
+import importlib.metadata
 import importlib.util
 import inspect
 import json
 import re
 import types
+from functools import lru_cache
 from typing import Dict, Tuple, Union

 from rich.console import Console


-def is_pygments_available():
-    return importlib.util.find_spec("soundfile") is not None
+@lru_cache
+def _is_package_available(package_name: str) -> bool:
+    try:
+        importlib.metadata.version(package_name)
+        return True
+    except importlib.metadata.PackageNotFoundError:
+        return False
+
+
+@lru_cache
+def _is_pillow_available():
+    return importlib.util.find_spec("PIL") is not None


 console = Console()
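`importlib.metadata.version` confirms a distribution is installed without importing it, which is the point of the import-time work here, and `lru_cache` makes repeated checks free. Usage sketch:

```python
from smolagents.utils import _is_package_available

if _is_package_available("soundfile"):
    import soundfile as sf  # safe: the distribution is known to be installed

    print(sf.__version__)
else:
    print("audio support disabled; run `pip install 'smolagents[audio]'`")
```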

@@ -17,7 +17,7 @@ import unittest
 import pytest

 from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING

 from .test_tools import ToolTesterMixin

@@ -46,7 +46,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
     def test_agent_type_output(self):
         inputs = ["2 * 2"]
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))

     def test_agent_types_inputs(self):
@@ -56,13 +56,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
         for _input, expected_input in zip(inputs, self.tool.inputs.values()):
             input_type = expected_input["type"]
             if isinstance(input_type, list):
-                _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
+                _inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
             else:
-                _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+                _inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))

         # Should not raise an error
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))

     def test_imports_work(self):

@@ -22,7 +22,7 @@ from transformers import is_torch_available
 from transformers.testing_utils import get_tests_dir, require_torch

 from smolagents.default_tools import FinalAnswerTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING

 from .test_tools import ToolTesterMixin

@@ -55,5 +55,5 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
         inputs = self.create_inputs()
         for input_type, input in inputs.items():
             output = self.tool(**input, sanitize_inputs_outputs=True)
-            agent_type = AGENT_TYPE_MAPPING[input_type]
+            agent_type = _AGENT_TYPE_MAPPING[input_type]
             self.assertTrue(isinstance(output, agent_type))

@@ -21,16 +21,12 @@ from unittest.mock import MagicMock, patch
 import mcp
 import numpy as np
 import pytest
+import torch
 from transformers import is_torch_available, is_vision_available
 from transformers.testing_utils import get_tests_dir

 from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
-from smolagents.types import (
-    AGENT_TYPE_MAPPING,
-    AgentAudio,
-    AgentImage,
-    AgentText,
-)
+from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText


 if is_torch_available():
@@ -96,7 +92,7 @@ class ToolTesterMixin:
         inputs = create_inputs(self.tool.inputs)
         output = self.tool(**inputs, sanitize_inputs_outputs=True)
         if self.tool.output_type != "any":
-            agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+            agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
             self.assertTrue(isinstance(output, agent_type))

@@ -121,4 +121,3 @@ class AgentTextTests(unittest.TestCase):

         self.assertEqual(string, agent_type.to_string())
         self.assertEqual(string, agent_type.to_raw())
-        self.assertEqual(string, agent_type)