Make import time faster (optional deps + delay imports) (#253)

* adapt docs
* optional in pyproject.toml
* get rid of some transformers imports
* optional transformers in models.py
* gradio, transformers, litellm
* small refacto AgentType
* merge conflicts
* mouaif
* fix tests
* AgentText no longer a str
* Add back AgentType as str/Image
* fixed for good

This commit is contained in:
parent a2ca95107f
commit d19ebc7a48
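The whole diff applies one core technique: heavy dependencies (`transformers`, `gradio`, `torch`, `litellm`) move out of module scope and into the functions that actually need them, gated behind optional extras. A minimal sketch of the pattern, with illustrative names rather than the exact smolagents code:

```python
# Sketch of the deferred-import pattern this commit applies throughout
# (illustrative names, not the exact smolagents code).
def transcribe(audio_path: str) -> str:
    # The heavy dependency is imported on first call, not at module import,
    # so `import mypackage` stays fast for users who never touch this feature.
    try:
        import transformers  # noqa: F401
    except ImportError as e:
        raise ModuleNotFoundError(
            "Please install the 'transformers' extra: `pip install 'mypackage[transformers]'`"
        ) from e
    ...
```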
@@ -55,6 +55,7 @@ agent.run(
 <hfoption id="Local Transformers Model">
 
 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel
 
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -72,6 +73,7 @@ agent.run(
 To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass `api_key` variable upon initialization.
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could use 'gpt-4o'
@@ -85,6 +87,7 @@ agent.run(
 <hfoption id="Ollama">
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
 
 ### GradioUI
 
+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
+
 [[autodoc]] GradioUI
 
 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```
 
+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
+
 [[autodoc]] TransformersModel
 
 ### HfApiModel
@@ -61,6 +61,7 @@ agent.run(
 <hfoption id="本地Transformers模型">
 
 ```python
+# !pip install smolagents[transformers]
 from smolagents import CodeAgent, TransformersModel
 
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
@@ -78,6 +79,7 @@ agent.run(
 To use `LiteLLMModel`, you need to set the environment variable `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`, or pass the `api_key` variable at initialization.
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(model_id="anthropic/claude-3-5-sonnet-latest", api_key="YOUR_ANTHROPIC_API_KEY") # Could also use 'gpt-4o'
@@ -91,6 +93,7 @@ agent.run(
 <hfoption id="Ollama">
 
 ```python
+# !pip install smolagents[litellm]
 from smolagents import CodeAgent, LiteLLMModel
 
 model = LiteLLMModel(
@@ -55,6 +55,9 @@ Both require arguments `model` and list of tools `tools` at initialization.
 
 ### GradioUI
 
+> [!TIP]
+> You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
+
 [[autodoc]] GradioUI
 
 ## Models
@@ -99,6 +102,9 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 >>> What a
 ```
 
+> [!TIP]
+> You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
+
 [[autodoc]] TransformersModel
 
 ### HfApiModel
@@ -12,14 +12,12 @@ authors = [
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-  "transformers>=4.0.0",
   "requests>=2.32.3",
   "rich>=13.9.4",
   "pandas>=2.2.3",
   "jinja2>=3.1.4",
   "pillow>=11.0.0",
   "markdownify>=0.14.1",
-  "gradio>=5.8.0",
   "duckduckgo-search>=6.3.7",
   "python-dotenv>=1.0.1",
   "e2b-code-interpreter>=1.0.3",
@@ -27,12 +25,20 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-audio = [
-  "soundfile",
-]
 torch = [
   "torch",
+]
+audio = [
+  "soundfile",
+  "smolagents[torch]",
+]
+transformers = [
   "accelerate",
+  "transformers>=4.0.0",
+  "smolagents[torch]",
+]
+gradio = [
+  "gradio>=5.8.0",
 ]
 litellm = [
   "litellm>=1.55.10",
@@ -47,9 +53,12 @@ openai = [
 quality = [
   "ruff>=0.9.0",
 ]
+all = [
+  "smolagents[accelerate,audio,gradio,litellm,mcp,openai,transformers]",
+]
 test = [
   "pytest>=8.1.0",
-  "smolagents[audio,litellm,mcp,openai,torch]",
+  "smolagents[all]",
 ]
 dev = [
   "smolagents[quality,test]",
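With `transformers` and `gradio` demoted to extras, callers need a cheap way to test for an optional package before using it. The `_is_package_available` helper referenced throughout the hunks below lives in `smolagents.utils`; a sketch of how such a check is typically written (an assumed implementation, not copied from the repo):

```python
import importlib.util

def _is_package_available(package_name: str) -> bool:
    # find_spec consults import metadata without importing the package,
    # so the check itself costs almost nothing at call time.
    return importlib.util.find_spec(package_name) is not None

# Example guard, mirroring the usage in gradio_ui.py below:
if not _is_package_available("gradio"):
    raise ModuleNotFoundError(
        "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
    )
```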
@@ -16,36 +16,14 @@
 # limitations under the License.
 __version__ = "1.5.0.dev"
 
-from typing import TYPE_CHECKING
-
-from transformers.utils import _LazyModule
-from transformers.utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
-    from .agents import *
-    from .default_tools import *
-    from .e2b_executor import *
-    from .gradio_ui import *
-    from .local_python_executor import *
-    from .models import *
-    from .monitoring import *
-    from .prompts import *
-    from .tools import *
-    from .types import *
-    from .utils import *
-
-
-else:
-    import sys
-
-    _file = globals()["__file__"]
-    import_structure = define_import_structure(_file)
-    import_structure[""] = {"__version__": __version__}
-    sys.modules[__name__] = _LazyModule(
-        __name__,
-        _file,
-        import_structure,
-        module_spec=__spec__,
-        extra_objects={"__version__": __version__},
-    )
+from .agents import *
+from .default_tools import *
+from .e2b_executor import *
+from .gradio_ui import *
+from .local_python_executor import *
+from .models import *
+from .monitoring import *
+from .prompts import *
+from .tools import *
+from .types import *
+from .utils import *
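Dropping `_LazyModule` in favor of eager `from .x import *` only stays cheap because the submodules no longer import `transformers`, `gradio`, or `torch` at top level. A quick way to check the effect locally (a sketch; absolute numbers depend on the machine and environment):

```python
import time

start = time.perf_counter()
import smolagents  # noqa: F401  # heavy deps are deferred, so this should be fast
print(f"import smolagents took {time.perf_counter() - start:.3f}s")
```

For a per-module breakdown, `python -X importtime -c "import smolagents"` shows exactly which imports dominate.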
@@ -0,0 +1,388 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains utilities exclusively taken from `transformers` repository.
+
+Since they are not specific to `transformers` and that `transformers` is an heavy dependencies, those helpers have
+been duplicated.
+
+TODO: move them to `huggingface_hub` to avoid code duplication.
+"""
+
+import inspect
+import json
+import os
+import re
+import types
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+from huggingface_hub.utils import is_torch_available
+
+from .utils import _is_pillow_available
+
+
+def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
+    """
+    Extracts all the libraries (not relative imports this time) that are imported in a file.
+
+    Args:
+        filename (`str` or `os.PathLike`): The module file to inspect.
+
+    Returns:
+        `List[str]`: The list of all packages required to use the input module.
+    """
+    with open(filename, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # filter out try/except block so in custom code we can have try/except imports
+    content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
+
+    # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
+    content = re.sub(
+        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
+        "",
+        content,
+        flags=re.MULTILINE,
+    )
+
+    # Imports of the form `import xxx`
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from xxx import yyy`
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+    # Only keep the top-level module
+    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+    return list(set(imports))
+
+
+class TypeHintParsingException(Exception):
+    """Exception raised for errors in parsing type hints to generate JSON schemas"""
+
+
+class DocstringParsingException(Exception):
+    """Exception raised for errors in parsing docstrings to generate JSON schemas"""
+
+
+def get_json_schema(func: Callable) -> Dict:
+    """
+    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
+    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
+    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
+    that the function has a docstring, and that each argument has a description in the docstring, in the standard
+    Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
+
+    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
+    optional because most chat templates ignore the return value of the function.
+
+    Args:
+        func: The function to generate a JSON schema for.
+
+    Returns:
+        A dictionary containing the JSON schema for the function.
+
+    Examples:
+    ```python
+    >>> def multiply(x: float, y: float):
+    >>>    '''
+    >>>    A function that multiplies two numbers
+    >>>
+    >>>    Args:
+    >>>        x: The first number to multiply
+    >>>        y: The second number to multiply
+    >>>    '''
+    >>>    return x * y
+    >>>
+    >>> print(get_json_schema(multiply))
+    {
+        "name": "multiply",
+        "description": "A function that multiplies two numbers",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "x": {"type": "number", "description": "The first number to multiply"},
+                "y": {"type": "number", "description": "The second number to multiply"}
+            },
+            "required": ["x", "y"]
+        }
+    }
+    ```
+
+    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
+    support them, like so:
+
+    ```python
+    >>> from transformers import AutoTokenizer
+    >>> from transformers.utils import get_json_schema
+    >>>
+    >>> def multiply(x: float, y: float):
+    >>>    '''
+    >>>    A function that multiplies two numbers
+    >>>
+    >>>    Args:
+    >>>        x: The first number to multiply
+    >>>        y: The second number to multiply
+    >>>    return x * y
+    >>>    '''
+    >>>
+    >>> multiply_schema = get_json_schema(multiply)
+    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
+    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
+    >>> formatted_chat = tokenizer.apply_chat_template(
+    >>>     messages,
+    >>>     tools=[multiply_schema],
+    >>>     chat_template="tool_use",
+    >>>     return_dict=True,
+    >>>     return_tensors="pt",
+    >>>     add_generation_prompt=True
+    >>> )
+    >>> # The formatted chat can now be passed to model.generate()
+    ```
+
+    Each argument description can also have an optional `(choices: ...)` block at the end, such as
+    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
+    only be parsed correctly if it is at the end of the line:
+
+    ```python
+    >>> def drink_beverage(beverage: str):
+    >>>    '''
+    >>>    A function that drinks a beverage
+    >>>
+    >>>    Args:
+    >>>        beverage: The beverage to drink (choices: ["tea", "coffee"])
+    >>>    '''
+    >>>    pass
+    >>>
+    >>> print(get_json_schema(drink_beverage))
+    ```
+    {
+        'name': 'drink_beverage',
+        'description': 'A function that drinks a beverage',
+        'parameters': {
+            'type': 'object',
+            'properties': {
+                'beverage': {
+                    'type': 'string',
+                    'enum': ['tea', 'coffee'],
+                    'description': 'The beverage to drink'
+                }
+            },
+            'required': ['beverage']
+        }
+    }
+    """
+    doc = inspect.getdoc(func)
+    if not doc:
+        raise DocstringParsingException(
+            f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
+        )
+    doc = doc.strip()
+    main_doc, param_descriptions, return_doc = _parse_google_format_docstring(doc)
+
+    json_schema = _convert_type_hints_to_json_schema(func)
+    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
+        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
+            return_dict["description"] = return_doc
+    for arg, schema in json_schema["properties"].items():
+        if arg not in param_descriptions:
+            raise DocstringParsingException(
+                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
+            )
+        desc = param_descriptions[arg]
+        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
+        if enum_choices:
+            schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
+            desc = enum_choices.string[: enum_choices.start()].strip()
+        schema["description"] = desc
+
+    output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
+    if return_dict is not None:
+        output["return"] = return_dict
+    return {"type": "function", "function": output}
+
+
+# Extracts the initial segment of the docstring, containing the function description
+description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
+# Extracts the Args: block from the docstring
+args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
+# Splits the Args: block into individual arguments
+args_split_re = re.compile(
+    r"""
+(?:^|\n)  # Match the start of the args block, or a newline
+\s*(\w+):\s*  # Capture the argument name and strip spacing
+(.*?)\s*  # Capture the argument description, which can span multiple lines, and strip trailing spacing
+(?=\n\s*\w+:|\Z)  # Stop when you hit the next argument or the end of the block
+""",
+    re.DOTALL | re.VERBOSE,
+)
+# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
+returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
+
+
+def _parse_google_format_docstring(
+    docstring: str,
+) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
+    """
+    Parses a Google-style docstring to extract the function description,
+    argument descriptions, and return description.
+
+    Args:
+        docstring (str): The docstring to parse.
+
+    Returns:
+        The function description, arguments, and return description.
+    """
+
+    # Extract the sections
+    description_match = description_re.search(docstring)
+    args_match = args_re.search(docstring)
+    returns_match = returns_re.search(docstring)
+
+    # Clean and store the sections
+    description = description_match.group(1).strip() if description_match else None
+    docstring_args = args_match.group(1).strip() if args_match else None
+    returns = returns_match.group(1).strip() if returns_match else None
+
+    # Parsing the arguments into a dictionary
+    if docstring_args is not None:
+        docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()])  # Remove blank lines
+        matches = args_split_re.findall(docstring_args)
+        args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
+    else:
+        args_dict = {}
+
+    return description, args_dict, returns
+
+
+def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    required = []
+    for param_name, param in signature.parameters.items():
+        if param.annotation == inspect.Parameter.empty:
+            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
+        if param.default == inspect.Parameter.empty:
+            required.append(param_name)
+
+    properties = {}
+    for param_name, param_type in type_hints.items():
+        properties[param_name] = _parse_type_hint(param_type)
+
+    schema = {"type": "object", "properties": properties}
+    if required:
+        schema["required"] = required
+
+    return schema
+
+
+def _parse_type_hint(hint: str) -> Dict:
+    origin = get_origin(hint)
+    args = get_args(hint)
+
+    if origin is None:
+        try:
+            return _get_json_schema_type(hint)
+        except KeyError:
+            raise TypeHintParsingException(
+                "Couldn't parse this type hint, likely due to a custom class or object: ",
+                hint,
+            )
+
+    elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+        # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
+        subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
+        if len(subtypes) == 1:
+            # A single non-null type can be expressed directly
+            return_dict = subtypes[0]
+        elif all(isinstance(subtype["type"], str) for subtype in subtypes):
+            # A union of basic types can be expressed as a list in the schema
+            return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
+        else:
+            # A union of more complex types requires "anyOf"
+            return_dict = {"anyOf": subtypes}
+        if type(None) in args:
+            return_dict["nullable"] = True
+        return return_dict
+
+    elif origin is list:
+        if not args:
+            return {"type": "array"}
+        else:
+            # Lists can only have a single type argument, so recurse into it
+            return {"type": "array", "items": _parse_type_hint(args[0])}
+
+    elif origin is tuple:
+        if not args:
+            return {"type": "array"}
+        if len(args) == 1:
+            raise TypeHintParsingException(
+                f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
+                "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
+                "more than one element, we recommend "
+                "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
+                "pass the element directly."
+            )
+        if ... in args:
+            raise TypeHintParsingException(
+                "Conversion of '...' is not supported in Tuple type hints. "
+                "Use List[] types for variable-length"
+                " inputs instead."
+            )
+        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
+
+    elif origin is dict:
+        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
+        # However, we can specify the type of the dict values with "additionalProperties"
+        out = {"type": "object"}
+        if len(args) == 2:
+            out["additionalProperties"] = _parse_type_hint(args[1])
+        return out
+
+    raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
+
+
+_BASE_TYPE_MAPPING = {
+    int: {"type": "integer"},
+    float: {"type": "number"},
+    str: {"type": "string"},
+    bool: {"type": "boolean"},
+    Any: {},
+}
+
+
+def _get_json_schema_type(param_type: str) -> Dict[str, str]:
+    if param_type in _BASE_TYPE_MAPPING:
+        return _BASE_TYPE_MAPPING[param_type]
+    if str(param_type) == "Image" and _is_pillow_available():
+        from PIL.Image import Image
+
+        if param_type == Image:
+            return {"type": "image"}
+    if str(param_type) == "Tensor" and is_torch_available():
+        from torch import Tensor
+
+        if param_type == Tensor:
+            return {"type": "audio"}
+    return {"type": "object"}
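To see what the vendored `get_imports` actually returns, here is a small usage sketch (the temporary file setup is illustrative):

```python
import tempfile

from smolagents._transformers_utils import get_imports

source = """
import numpy
from PIL import Image
from .local_module import helper  # relative imports are dropped

try:
    import soundfile  # guarded imports are dropped too
except ImportError:
    soundfile = None
"""

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(source)

print(sorted(get_imports(f.name)))  # expected: ['PIL', 'numpy']
```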
@@ -14,33 +14,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import re
 from dataclasses import dataclass
 from typing import Dict, Optional
 
-from huggingface_hub import hf_hub_download, list_spaces
-from transformers.utils import is_offline_mode, is_torch_available
-
 from .local_python_executor import (
     BASE_BUILTIN_MODULES,
     BASE_PYTHON_TOOLS,
     evaluate_python_code,
 )
-from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
+from .tools import PipelineTool, Tool
 from .types import AgentAudio
 
 
-if is_torch_available():
-    from transformers.models.whisper import (
-        WhisperForConditionalGeneration,
-        WhisperProcessor,
-    )
-else:
-    WhisperForConditionalGeneration = object
-    WhisperProcessor = object
-
-
 @dataclass
 class PreTool:
     name: str
@@ -51,31 +37,6 @@ class PreTool:
     repo_id: str
 
 
-def get_remote_tools(logger, organization="huggingface-tools"):
-    if is_offline_mode():
-        logger.info("You are in offline mode, so remote tools are not available.")
-        return {}
-
-    spaces = list_spaces(author=organization)
-    tools = {}
-    for space_info in spaces:
-        repo_id = space_info.id
-        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
-        with open(resolved_config_file, encoding="utf-8") as reader:
-            config = json.load(reader)
-        task = repo_id.split("/")[-1]
-        tools[config["name"]] = PreTool(
-            task=task,
-            description=config["description"],
-            repo_id=repo_id,
-            name=task,
-            inputs=config["inputs"],
-            output_type=config["output_type"],
-        )
-
-    return tools
-
-
 class PythonInterpreterTool(Tool):
     name = "python_interpreter"
     description = "This is a tool that evaluates python code. It can be used to perform calculations."
@@ -150,10 +111,10 @@ class DuckDuckGoSearchTool(Tool):
         self.max_results = max_results
         try:
             from duckduckgo_search import DDGS
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
-            )
+            ) from e
         self.ddgs = DDGS()
 
     def forward(self, query: str) -> str:
@@ -259,10 +220,10 @@ class VisitWebpageTool(Tool):
             from requests.exceptions import RequestException
 
             from smolagents.utils import truncate_content
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
-            )
+            ) from e
         try:
             # Send a GET request to the URL
             response = requests.get(url)
@@ -286,9 +247,6 @@ class SpeechToTextTool(PipelineTool):
     default_checkpoint = "openai/whisper-large-v3-turbo"
     description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
     name = "transcriber"
-    pre_processor_class = WhisperProcessor
-    model_class = WhisperForConditionalGeneration
-
     inputs = {
         "audio": {
             "type": "audio",
@@ -297,6 +255,18 @@ class SpeechToTextTool(PipelineTool):
         }
     }
     output_type = "string"
 
+    def __new__(cls):
+        from transformers.models.whisper import (
+            WhisperForConditionalGeneration,
+            WhisperProcessor,
+        )
+
+        if not hasattr(cls, "pre_processor_class"):
+            cls.pre_processor_class = WhisperProcessor
+        if not hasattr(cls, "model_class"):
+            cls.model_class = WhisperForConditionalGeneration
+        return super().__new__()
+
     def encode(self, audio):
         audio = AgentAudio(audio).to_raw()
         return self.pre_processor(audio, return_tensors="pt")
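The `except ImportError as e: ... from e` changes in this file are about exception chaining: with `from e`, the original failure is attached as `__cause__` and the traceback reads "The above exception was the direct cause of the following exception" instead of the misleading "During handling of the above exception, another exception occurred". In miniature:

```python
def _require_ddgs():
    try:
        from duckduckgo_search import DDGS
    except ImportError as e:
        # `from e` preserves the root ImportError as __cause__.
        raise ImportError(
            "You must install package `duckduckgo_search` to run this tool: "
            "for instance run `pip install duckduckgo-search`."
        ) from e
    return DDGS
```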
@@ -19,14 +19,15 @@ import re
 import shutil
 from typing import Optional
 
-import gradio as gr
-
 from .agents import ActionStep, AgentStepLog, MultiStepAgent
 from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
+from .utils import _is_package_available
 
 
 def pull_messages_from_step(step_log: AgentStepLog):
     """Extract ChatMessage objects from agent steps"""
+    import gradio as gr
+
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
         if step_log.tool_calls is not None:
@@ -57,6 +58,11 @@ def stream_to_gradio(
     additional_args: Optional[dict] = None,
 ):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
+    if not _is_package_available("gradio"):
+        raise ModuleNotFoundError(
+            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`"
+        )
+    import gradio as gr
 
     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
         for message in pull_messages_from_step(step_log):
@@ -88,6 +94,10 @@ class GradioUI:
     """A one-line interface to launch your agent in Gradio"""
 
     def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
+        if not _is_package_available("gradio"):
+            raise ModuleNotFoundError(
+                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`"
+            )
         self.agent = agent
         self.file_upload_folder = file_upload_folder
         if self.file_upload_folder is not None:
@@ -95,6 +105,8 @@ class GradioUI:
             os.mkdir(file_upload_folder)
 
     def interact_with_agent(self, prompt, messages):
+        import gradio as gr
+
         messages.append(gr.ChatMessage(role="user", content=prompt))
         yield messages
         for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
@@ -115,6 +127,7 @@ class GradioUI:
         """
         Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
+        import gradio as gr
 
         if file is None:
             return gr.Textbox("No file uploaded", visible=True), file_uploads_log
@@ -161,6 +174,8 @@ class GradioUI:
         )
 
     def launch(self):
+        import gradio as gr
+
         with gr.Blocks() as demo:
             stored_messages = gr.State([])
             file_uploads_log = gr.State([])
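Nothing changes for callers of `GradioUI`; `gradio` is simply imported on first use instead of at `import smolagents`. A typical launch still looks like this (a sketch; assumes the `gradio` extra is installed and a model endpoint is reachable):

```python
from smolagents import CodeAgent, GradioUI, HfApiModel

agent = CodeAgent(tools=[], model=HfApiModel())
GradioUI(agent).launch()  # raises ModuleNotFoundError early if gradio is absent
```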
@@ -21,20 +21,18 @@ import random
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from huggingface_hub import InferenceClient
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    is_torch_available,
-)
+from huggingface_hub.utils import is_torch_available
 
 from .tools import Tool
+from .utils import _is_package_available
+
+
+if TYPE_CHECKING:
+    from transformers import StoppingCriteriaList
 
 logger = logging.getLogger(__name__)
 
 DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@@ -320,6 +318,9 @@ class TransformersModel(Model):
 
     This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
 
+    > [!TIP]
+    > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
+
     Parameters:
         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
@@ -358,9 +359,12 @@ class TransformersModel(Model):
         **kwargs,
     ):
         super().__init__()
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use TransformersModel.")
+        if not is_torch_available() or not _is_package_available("transformers"):
+            raise ModuleNotFoundError(
+                "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
+            )
         import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
 
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -387,7 +391,9 @@ class TransformersModel(Model):
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
         self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
 
-    def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
+    def make_stopping_criteria(self, stop_sequences: List[str]) -> "StoppingCriteriaList":
+        from transformers import StoppingCriteria, StoppingCriteriaList
+
         class StopOnStrings(StoppingCriteria):
             def __init__(self, stop_strings: List[str], tokenizer):
                 self.stop_strings = stop_strings
@@ -491,6 +497,7 @@ class LiteLLMModel(Model):
             raise ModuleNotFoundError(
                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
             )
+
         super().__init__()
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@@ -506,9 +513,10 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         tools_to_call_from: Optional[List[Tool]] = None,
     ) -> ChatMessage:
-        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
         import litellm
 
+        messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
+
         if tools_to_call_from:
             response = litellm.completion(
                 model=self.model_id,
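The `TYPE_CHECKING` block plus the quoted return annotation is the standard recipe for keeping a type in signatures without paying for its import at runtime: the name only has to resolve for static analyzers, and the real import happens inside the method body. The shape of the pattern:

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime.
    from transformers import StoppingCriteriaList


def make_stopping_criteria(stop_sequences: List[str]) -> "StoppingCriteriaList":
    # The quoted annotation is a forward reference; transformers is only
    # imported once the function actually runs.
    from transformers import StoppingCriteria, StoppingCriteriaList
    ...
```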
@@ -35,54 +35,22 @@ from huggingface_hub import (
     metadata_update,
     upload_folder,
 )
-from huggingface_hub.utils import RepositoryNotFoundError
+from huggingface_hub.utils import is_torch_available
 from packaging import version
-from transformers.dynamic_module_utils import get_imports
-from transformers.utils import (
-    TypeHintParsingException,
-    cached_file,
-    get_json_schema,
-    is_accelerate_available,
-    is_torch_available,
-)
-from transformers.utils.chat_template_utils import _parse_type_hint
 
+from ._transformers_utils import (
+    TypeHintParsingException,
+    _parse_type_hint,
+    get_imports,
+    get_json_schema,
+)
 from .tool_validation import MethodChecker, validate_tool_attributes
-from .types import ImageType, handle_agent_input_types, handle_agent_output_types
-from .utils import instance_to_source
+from .types import handle_agent_input_types, handle_agent_output_types
+from .utils import _is_package_available, _is_pillow_available, instance_to_source
 
 
 logger = logging.getLogger(__name__)
 
-if is_accelerate_available():
-    from accelerate import PartialState
-    from accelerate.utils import send_to_device
-
-if is_torch_available():
-    from transformers import AutoProcessor
-else:
-    AutoProcessor = object
-
-TOOL_CONFIG_FILE = "tool_config.json"
-
-
-def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
-    if repo_type is not None:
-        return repo_type
-    try:
-        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
-        return "space"
-    except RepositoryNotFoundError:
-        try:
-            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
-            return "model"
-        except RepositoryNotFoundError:
-            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
-        except Exception:
-            return "model"
-    except Exception:
-        return "space"
-
-
 def validate_after_init(cls):
     original_init = cls.__init__
@@ -337,12 +305,8 @@ class Tool:
         )
 
         # Save requirements file
+        imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
         requirements_file = os.path.join(output_dir, "requirements.txt")
-
-        imports = []
-        for module in [tool_file]:
-            imports.extend(get_imports(module))
-        imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
         with open(requirements_file, "w", encoding="utf-8") as f:
             f.write("\n".join(imports) + "\n")
 
@@ -439,53 +403,27 @@ class Tool:
             `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
             others will be passed along to its init.
         """
-        assert trust_remote_code, (
-            "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
-        )
-
-        hub_kwargs_names = [
-            "cache_dir",
-            "force_download",
-            "resume_download",
-            "proxies",
-            "revision",
-            "repo_type",
-            "subfolder",
-            "local_files_only",
-        ]
-        hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
-
-        tool_file = "tool.py"
+        if not trust_remote_code:
+            raise ValueError(
+                "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
+            )
 
         # Get the tool's tool.py file.
-        hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
-        resolved_tool_file = cached_file(
+        tool_file = hf_hub_download(
             repo_id,
-            tool_file,
+            "tool.py",
             token=token,
-            **hub_kwargs,
-            _raise_exceptions_for_gated_repo=False,
-            _raise_exceptions_for_missing_entries=False,
-            _raise_exceptions_for_connection_errors=False,
+            repo_type="space",
+            cache_dir=kwargs.get("cache_dir"),
+            force_download=kwargs.get("force_download"),
+            resume_download=kwargs.get("resume_download"),
+            proxies=kwargs.get("proxies"),
+            revision=kwargs.get("revision"),
+            subfolder=kwargs.get("subfolder"),
+            local_files_only=kwargs.get("local_files_only"),
         )
-        tool_code = resolved_tool_file is not None
-        if resolved_tool_file is None:
-            resolved_tool_file = cached_file(
-                repo_id,
-                tool_file,
-                token=token,
-                **hub_kwargs,
-                _raise_exceptions_for_gated_repo=False,
-                _raise_exceptions_for_missing_entries=False,
-                _raise_exceptions_for_connection_errors=False,
-            )
-        if resolved_tool_file is None:
-            raise EnvironmentError(
-                f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
-            )
-
-        with open(resolved_tool_file, encoding="utf-8") as reader:
-            tool_code = "".join(reader.readlines())
+        tool_code = Path(tool_file).read_text()
 
         # Find the Tool subclass in the namespace
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -613,7 +551,10 @@ class Tool:
     def sanitize_argument_for_prediction(self, arg):
         from gradio_client.utils import is_http_url_like
 
-        if isinstance(arg, ImageType):
+        if _is_pillow_available():
+            from PIL.Image import Image
+
+        if _is_pillow_available() and isinstance(arg, Image):
             temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
             arg.save(temp_file.name)
             arg = temp_file.name
@@ -988,13 +929,13 @@ class PipelineTool(Tool):
 
     - **model_class** (`type`) -- The class to use to load the model in this tool.
     - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
-    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+    - **pre_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
       pre-processor
-    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
+    - **post_processor_class** (`type`, *optional*, defaults to [`transformers.AutoProcessor`]) -- The class to use to load the
       post-processor (when different from the pre-processor).
 
     Args:
-        model (`str` or [`PreTrainedModel`], *optional*):
+        model (`str` or [`transformers.PreTrainedModel`], *optional*):
             The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
             value of the class attribute `default_checkpoint`.
         pre_processor (`str` or `Any`, *optional*):
@@ -1019,9 +960,9 @@ class PipelineTool(Tool):
         Any additional keyword argument to send to the methods that will load the data from the Hub.
     """
 
-    pre_processor_class = AutoProcessor
+    pre_processor_class = None
     model_class = None
-    post_processor_class = AutoProcessor
+    post_processor_class = None
    default_checkpoint = None
     description = "This is a pipeline tool"
     name = "pipeline"
@@ -1040,11 +981,10 @@ class PipelineTool(Tool):
         token=None,
         **hub_kwargs,
     ):
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use this tool.")
-        if not is_accelerate_available():
-            raise ImportError("Please install accelerate in order to use this tool.")
+        if not is_torch_available() or not _is_package_available("accelerate"):
+            raise ModuleNotFoundError(
+                "Please install 'transformers' extra to use a PipelineTool: `pip install 'smolagents[transformers]'`"
+            )
 
         if model is None:
             if self.default_checkpoint is None:
@@ -1071,6 +1011,10 @@ class PipelineTool(Tool):
         Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
         """
         if isinstance(self.pre_processor, str):
+            if self.pre_processor_class is None:
+                from transformers import AutoProcessor
+
+                self.pre_processor_class = AutoProcessor
             self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
 
         if isinstance(self.model, str):
@@ -1079,12 +1023,18 @@ class PipelineTool(Tool):
         if self.post_processor is None:
             self.post_processor = self.pre_processor
         elif isinstance(self.post_processor, str):
+            if self.post_processor_class is None:
+                from transformers import AutoProcessor
+
+                self.post_processor_class = AutoProcessor
             self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
 
         if self.device is None:
             if self.device_map is not None:
                 self.device = list(self.model.hf_device_map.values())[0]
             else:
+                from accelerate import PartialState
+
                 self.device = PartialState().default_device
 
         if self.device_map is None:
@@ -1115,6 +1065,7 @@ class PipelineTool(Tool):
 
     def __call__(self, *args, **kwargs):
         import torch
+        from accelerate.utils import send_to_device
 
         args, kwargs = handle_agent_input_types(*args, **kwargs)
 
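The rewritten requirements logic in `tools.py` collapses four lines into one set expression: stdlib modules are filtered out via `sys.stdlib_module_names` (available since Python 3.10, matching `requires-python`) and `smolagents` is unioned in. A standalone check of the behavior:

```python
import sys

detected = ["requests", "json", "PIL"]  # stand-in for get_imports(tool_file)
imports = {el for el in detected if el not in sys.stdlib_module_names} | {"smolagents"}
print(sorted(imports))  # ['PIL', 'requests', 'smolagents']
```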
@ -12,7 +12,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib.util
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
@ -22,26 +21,15 @@ from io import BytesIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from transformers.utils import (
|
from huggingface_hub.utils import is_torch_available
|
||||||
is_torch_available,
|
from PIL import Image
|
||||||
is_vision_available,
|
from PIL.Image import Image as ImageType
|
||||||
)
|
|
||||||
|
from .utils import _is_package_available
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_vision_available():
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
else:
|
|
||||||
ImageType = object
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
else:
|
|
||||||
Tensor = object
|
|
||||||
|
|
||||||
|
|
||||||
class AgentType:
|
class AgentType:
|
||||||
"""
|
"""
|
||||||
|
@ -94,9 +82,6 @@ class AgentImage(AgentType, ImageType):
|
||||||
AgentType.__init__(self, value)
|
AgentType.__init__(self, value)
|
||||||
ImageType.__init__(self)
|
ImageType.__init__(self)
|
||||||
|
|
||||||
if not is_vision_available():
|
|
||||||
raise ImportError("PIL must be installed in order to handle images.")
|
|
||||||
|
|
||||||
self._path = None
|
self._path = None
|
||||||
self._raw = None
|
self._raw = None
|
||||||
self._tensor = None
|
self._tensor = None
|
||||||
|
@@ -109,11 +94,15 @@ class AgentImage(AgentType, ImageType):
             self._raw = Image.open(BytesIO(value))
         elif isinstance(value, (str, pathlib.Path)):
             self._path = value
-        elif isinstance(value, torch.Tensor):
-            self._tensor = value
-        elif isinstance(value, np.ndarray):
-            self._tensor = torch.from_numpy(value)
-        else:
+        elif is_torch_available():
+            import torch
+
+            if isinstance(value, torch.Tensor):
+                self._tensor = value
+            if isinstance(value, np.ndarray):
+                self._tensor = torch.from_numpy(value)
+
+        if self._path is None and self._raw is None and self._tensor is None:
             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
 
     def _ipython_display_(self, include=None, exclude=None):
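After this change `AgentImage.__init__` resolves the input type up front and only imports torch when a tensor could be involved, falling through to a single `TypeError` when nothing matched. A rough usage sketch (the file path is hypothetical):

```python
from smolagents.types import AgentImage

img = AgentImage("tests/fixtures/cat.png")  # stored as a path; PIL decodes lazily
raw = img.to_raw()       # a PIL.Image.Image
path = img.to_string()   # path on disk (serialized to a temp file when needed)
```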
@@ -183,10 +172,12 @@ class AgentAudio(AgentType, str):
     """
 
     def __init__(self, value, samplerate=16_000):
-        if importlib.util.find_spec("soundfile") is None:
+        if not _is_package_available("soundfile") or not is_torch_available():
             raise ModuleNotFoundError(
                 "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
             )
+        import torch
 
         super().__init__(value)
 
         self._path = None
@@ -223,6 +214,8 @@ class AgentAudio(AgentType, str):
         if self._tensor is not None:
             return self._tensor
 
+        import torch
+
         if self._path is not None:
             if "://" in str(self._path):
                 response = requests.get(self._path)
@@ -250,15 +243,7 @@ class AgentAudio(AgentType, str):
         return self._path
 
 
-AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
-INSTANCE_TYPE_MAPPING = {
-    str: AgentText,
-    ImageType: AgentImage,
-    Tensor: AgentAudio,
-}
-
-if is_torch_available():
-    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
+_AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
 
 
 def handle_agent_input_types(*args, **kwargs):
@@ -268,17 +253,22 @@ def handle_agent_input_types(*args, **kwargs):
 
 
 def handle_agent_output_types(output, output_type=None):
-    if output_type in AGENT_TYPE_MAPPING:
+    if output_type in _AGENT_TYPE_MAPPING:
         # If the class has defined outputs, we can map directly according to the class definition
-        decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
+        decoded_outputs = _AGENT_TYPE_MAPPING[output_type](output)
         return decoded_outputs
-    else:
-        # If the class does not have defined output, then we map according to the type
-        for _k, _v in INSTANCE_TYPE_MAPPING.items():
-            if isinstance(output, _k):
-                if _k is not object:  # avoid converting to audio if torch is not installed
-                    return _v(output)
-    return output
+
+    # If the class does not have defined output, then we map according to the type
+    if isinstance(output, str):
+        return AgentText(output)
+    if isinstance(output, ImageType):
+        return AgentImage(output)
+    if is_torch_available():
+        import torch
+
+        if isinstance(output, torch.Tensor):
+            return AgentAudio(output)
+    return output
 
 
 __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
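The rewritten `handle_agent_output_types` trades the `INSTANCE_TYPE_MAPPING` lookup table for explicit `isinstance` checks, so the tensor-to-audio branch is only reachable when torch is installed. Roughly, assuming the module layout above:

```python
from smolagents.types import AgentText, handle_agent_output_types

out = handle_agent_output_types("4", output_type="string")
assert isinstance(out, AgentText)  # declared type: mapped via _AGENT_TYPE_MAPPING

out = handle_agent_output_types(3.14)
assert out == 3.14  # no declared type and not str/image/tensor: passed through unchanged
```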
@@ -15,18 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import ast
+import importlib.metadata
 import importlib.util
 import inspect
 import json
 import re
 import types
+from functools import lru_cache
 from typing import Dict, Tuple, Union
 
 from rich.console import Console
 
 
-def is_pygments_available():
-    return importlib.util.find_spec("soundfile") is not None
+@lru_cache
+def _is_package_available(package_name: str) -> bool:
+    try:
+        importlib.metadata.version(package_name)
+        return True
+    except importlib.metadata.PackageNotFoundError:
+        return False
+
+
+@lru_cache
+def _is_pillow_available():
+    return importlib.util.find_spec("PIL") is not None
 
 
 console = Console()
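`_is_package_available` checks the installed distribution's metadata instead of importing the package, and `@lru_cache` makes repeated checks effectively free. A sketch of a call site, assuming the helper lives in `smolagents.utils`:

```python
from smolagents.utils import _is_package_available

def require_audio_extra():
    # Cheap, cached check: nothing is imported, only metadata is read.
    if not _is_package_available("soundfile"):
        raise ModuleNotFoundError("Please run `pip install 'smolagents[audio]'`")
```

Note that metadata lookups key on the distribution name, which can differ from the import name; that is why `_is_pillow_available` still uses `find_spec("PIL")` (the distribution is named `Pillow`).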
@@ -17,7 +17,7 @@ import unittest
 import pytest
 
 from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING
 
 from .test_tools import ToolTesterMixin
 
@@ -46,7 +46,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
     def test_agent_type_output(self):
         inputs = ["2 * 2"]
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))
 
     def test_agent_types_inputs(self):
@@ -56,13 +56,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
         for _input, expected_input in zip(inputs, self.tool.inputs.values()):
             input_type = expected_input["type"]
             if isinstance(input_type, list):
-                _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
+                _inputs.append([_AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
             else:
-                _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+                _inputs.append(_AGENT_TYPE_MAPPING[input_type](_input))
 
         # Should not raise an error
         output = self.tool(*inputs, sanitize_inputs_outputs=True)
-        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+        output_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
         self.assertTrue(isinstance(output, output_type))
 
     def test_imports_work(self):
@@ -22,7 +22,7 @@ from transformers import is_torch_available
 from transformers.testing_utils import get_tests_dir, require_torch
 
 from smolagents.default_tools import FinalAnswerTool
-from smolagents.types import AGENT_TYPE_MAPPING
+from smolagents.types import _AGENT_TYPE_MAPPING
 
 from .test_tools import ToolTesterMixin
 
@@ -55,5 +55,5 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
         inputs = self.create_inputs()
         for input_type, input in inputs.items():
             output = self.tool(**input, sanitize_inputs_outputs=True)
-            agent_type = AGENT_TYPE_MAPPING[input_type]
+            agent_type = _AGENT_TYPE_MAPPING[input_type]
             self.assertTrue(isinstance(output, agent_type))
@@ -21,16 +21,12 @@ from unittest.mock import MagicMock, patch
 import mcp
 import numpy as np
 import pytest
+import torch
 from transformers import is_torch_available, is_vision_available
 from transformers.testing_utils import get_tests_dir
 
 from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
-from smolagents.types import (
-    AGENT_TYPE_MAPPING,
-    AgentAudio,
-    AgentImage,
-    AgentText,
-)
+from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
 
 
 if is_torch_available():
@@ -96,7 +92,7 @@ class ToolTesterMixin:
         inputs = create_inputs(self.tool.inputs)
         output = self.tool(**inputs, sanitize_inputs_outputs=True)
         if self.tool.output_type != "any":
-            agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+            agent_type = _AGENT_TYPE_MAPPING[self.tool.output_type]
             self.assertTrue(isinstance(output, agent_type))
@@ -121,4 +121,3 @@ class AgentTextTests(unittest.TestCase):
 
         self.assertEqual(string, agent_type.to_string())
         self.assertEqual(string, agent_type.to_raw())
-        self.assertEqual(string, agent_type)
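To exercise the renamed `_AGENT_TYPE_MAPPING` end to end, the touched test modules can be run directly (file names inferred from the hunks above):

```python
# !pytest tests/test_types.py tests/test_tools.py tests/test_final_answer.py -q
```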