Consolidate pushing Tools to Hub

Aymeric 2024-12-19 16:05:17 +01:00
parent 00b9a71453
commit 584ce8f363
17 changed files with 506 additions and 213 deletions

@ -1,5 +1,5 @@
# Base Python image # Base Python image
FROM python:3.12-slim FROM python:3.9-slim
# Set working directory # Set working directory
WORKDIR /app WORKDIR /app
@ -25,4 +25,7 @@ RUN pip install -e .
COPY server.py /app/server.py COPY server.py /app/server.py
# Expose the port your server will run on
EXPOSE 65432
CMD ["python", "/app/server.py"] CMD ["python", "/app/server.py"]

@ -27,3 +27,20 @@ limitations under the License.
<h3 align="center"> <h3 align="center">
<p>Run agents! <p>Run agents!
</h3> </h3>
<div class="flex justify-center">
<img
class="block dark:hidden"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
/>
<img
class="hidden dark:block"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
/>
</div>
To build the Docker image, run `docker build -t pyrunner:latest .`
This uses the local Dockerfile to create your Docker image.

@ -23,11 +23,11 @@ Here, we're going to see advanced tool usage.
> If you're new to `agents`, make sure to first read the main [agents documentation](./agents). > If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
### Directly define a tool by subclassing Tool, and share it to the Hub ### Directly define a tool by subclassing Tool
Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator. Let's take again the tool example from the [quicktour](../quicktour), for which we had implemented a `@tool` decorator. The `@tool` decorator is the standard format, but sometimes you need more: several methods in a class for more clarity, or additional class attributes.
If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass. In this case, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
The custom tool needs: The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`. - An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
@ -67,19 +67,28 @@ tool = HFModelDownloadsTool()
Now the custom `HFModelDownloadsTool` class is ready. Now the custom `HFModelDownloadsTool` class is ready.
### Share your tool to the Hub
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access. You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
```python ```python
tool.push_to_hub("m-ric/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>") tool.push_to_hub("{your_username}/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>")
``` ```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. For the push to Hub to work, your tool will need to respect some rules:
- All methods are self-contained, e.g. they use variables that come either from their args or from within the method itself.
- If you override the `__init__` method, it can take no argument other than `self`. Arguments set during a specific tool instance's initialization are hard to track, which prevents sharing them properly to the Hub. In any case, the point of making a specific class is that you can set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). You can still create an instance attribute anywhere in your code by assigning to `self.your_variable`.
Once your tool is pushed to the Hub, you can load it with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`. Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
```python ```python
from agents import load_tool, CodeAgent from agents import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads", trust_remote_code=True) model_download_tool = load_tool(
"{your_username}/hf-model-downloads",
trust_remote_code=True
)
``` ```
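The rule about `__init__` arguments can be checked with the standard library alone. The sketch below is an illustration, not the library's implementation: `extra_init_params`, `ShareableTool` and `UnshareableTool` are hypothetical names, and the classes are plain Python classes standing in for `Tool` subclasses.

```python
import inspect


def extra_init_params(cls) -> list:
    """Return the names of `__init__` parameters other than `self`."""
    signature = inspect.signature(cls.__init__)
    return [name for name in signature.parameters if name != "self"]


class ShareableTool:
    # Hard-coded values live in class attributes, so the class source
    # alone is enough to rebuild the tool.
    url = "https://example.com"

    def __init__(self):
        self.is_initialized = False


class UnshareableTool:
    # The value of `url` is only known at instantiation time, so it
    # cannot be recovered from the class source.
    def __init__(self, url):
        self.url = url


print(extra_init_params(ShareableTool))    # []
print(extra_init_params(UnshareableTool))  # ['url']
```

A non-empty list is the situation in which saving or pushing the tool is expected to be rejected.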
### Import a Space as a tool 🚀 ### Import a Space as a tool 🚀
@ -215,7 +224,3 @@ agent.run("Please draw me a picture of rivers and lakes.")
``` ```
To speed up the start, tools are loaded only if called by the agent. To speed up the start, tools are loaded only if called by the agent.
This gets you this image:
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png">

@ -1,9 +1,8 @@
from agents.search import DuckDuckGoSearchTool from agents.tools.search import DuckDuckGoSearchTool
from agents.docker_alternative import DockerPythonInterpreter from agents.docker_alternative import DockerPythonInterpreter
test = """ from agents.tool import Tool
from agents.tools import Tool
class DummyTool(Tool): class DummyTool(Tool):
name = "echo" name = "echo"
@ -17,7 +16,6 @@ class DummyTool(Tool):
def forward(self, cmd: str) -> str: def forward(self, cmd: str) -> str:
return cmd return cmd
"""
container = DockerPythonInterpreter() container = DockerPythonInterpreter()
@ -30,10 +28,8 @@ breakpoint()
print("---------") print("---------")
output = container.execute(test)
print(output)
output = container.execute("res = DummyTool(cmd='echo this'); print(res)") output = container.execute("res = DummyTool(cmd='echo this'); print(res())")
print(output) print(output)
container.stop() container.stop()

@ -1,4 +1,4 @@
from agents.tools import Tool from agents.tool import Tool
class DummyTool(Tool): class DummyTool(Tool):

@ -30,11 +30,12 @@ if TYPE_CHECKING:
from .local_python_executor import * from .local_python_executor import *
from .monitoring import * from .monitoring import *
from .prompts import * from .prompts import *
from .search import * from .tools.search import *
from .tools import * from .tool import *
from .types import * from .types import *
from .utils import * from .utils import *
else: else:
import sys import sys

@ -43,7 +43,7 @@ from .prompts import (
SYSTEM_PROMPT_PLAN, SYSTEM_PROMPT_PLAN,
) )
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import ( from .tool import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool, Tool,
get_tool_description_with_args, get_tool_description_with_args,

@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode from transformers.utils import is_offline_mode
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, Tool from .tool import TOOL_CONFIG_FILE, Tool
def custom_print(*args): def custom_print(*args):

@ -3,7 +3,7 @@ from typing import List, Optional
import warnings import warnings
import socket import socket
from agents.tools import Tool from agents.tool import Tool
class DockerPythonInterpreter: class DockerPythonInterpreter:
def __init__(self): def __init__(self):

@ -47,8 +47,10 @@ from transformers.utils import (
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_inputs, handle_agent_outputs from .types import ImageType, handle_agent_inputs, handle_agent_outputs
from .utils import ImportFinder from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker
import logging import logging
@ -97,14 +99,6 @@ def setup_default_tools():
return default_tools return default_tools
# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from tool import {class_name}
launch_gradio_demo({class_name})
"""
def validate_after_init(cls, do_validate_forward: bool = True): def validate_after_init(cls, do_validate_forward: bool = True):
original_init = cls.__init__ original_init = cls.__init__
@ -117,112 +111,6 @@ def validate_after_init(cls, do_validate_forward: bool = True):
return cls return cls
def validate_args_are_self_contained(source_code):
"""Validates that all names in forward method are properly defined.
In particular it will check that all imports are done within the function."""
print("CODDDD", source_code)
tree = ast.parse(textwrap.dedent(source_code))
# Get function arguments
func_node = tree.body[0]
arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"}
builtin_names = set(vars(builtins))
class NameChecker(ast.NodeVisitor):
def __init__(self):
self.undefined_names = set()
self.imports = {}
self.from_imports = {}
self.assigned_names = set()
def visit_Import(self, node):
"""Handle simple imports like 'import datetime'."""
for name in node.names:
actual_name = name.asname or name.name
self.imports[actual_name] = (name.name, actual_name)
def visit_ImportFrom(self, node):
"""Handle from imports like 'from datetime import datetime'."""
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
self.from_imports[actual_name] = (module, name.name, actual_name)
def visit_Assign(self, node):
"""Track variable assignments."""
for target in node.targets:
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
self.visit(node.value)
def visit_AnnAssign(self, node):
"""Track annotated assignments."""
if isinstance(node.target, ast.Name):
self.assigned_names.add(node.target.id)
if node.value:
self.visit(node.value)
def _handle_for_target(self, target) -> Set[str]:
"""Extract all names from a for loop target."""
names = set()
if isinstance(target, ast.Name):
names.add(target.id)
elif isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name):
names.add(elt.id)
return names
def visit_For(self, node):
"""Track for-loop target variables and handle enumerate specially."""
# Add names from the target
target_names = self._handle_for_target(node.target)
self.assigned_names.update(target_names)
# Special handling for enumerate
if (
isinstance(node.iter, ast.Call)
and isinstance(node.iter.func, ast.Name)
and node.iter.func.id == "enumerate"
):
# For enumerate, if we have "for i, x in enumerate(...)",
# both i and x should be marked as assigned
if isinstance(node.target, ast.Tuple):
for elt in node.target.elts:
if isinstance(elt, ast.Name):
self.assigned_names.add(elt.id)
# Visit the rest of the node
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load) and not (
node.id == "tool"
or node.id in builtin_names
or node.id in arg_names
or node.id == "self"
or node.id in self.assigned_names
):
if node.id not in self.from_imports and node.id not in self.imports:
self.undefined_names.add(node.id)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
checker = NameChecker()
checker.visit(tree)
if checker.undefined_names:
raise ValueError(
f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
Make sure all imports and variables are self-contained within the method.
"""
)
AUTHORIZED_TYPES = [ AUTHORIZED_TYPES = [
"string", "string",
"boolean", "boolean",
@ -339,64 +227,79 @@ class Tool:
""" """
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
class_name = self.__class__.__name__ class_name = self.__class__.__name__
tool_file = os.path.join(output_dir, "tool.py")
# Save tool file # Save tool file
forward_source_code = inspect.getsource(self.forward) if type(self).__name__ == "SimpleTool":
validate_args_are_self_contained(forward_source_code) # Check that imports are self-contained
tool_code = f""" forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward)))
from agents import Tool # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
method_checker = MethodChecker(set())
method_checker.visit(forward_node)
if len(method_checker.errors) > 0:
raise(ValueError("\n".join(method_checker.errors)))
class {class_name}(Tool): forward_source_code = inspect.getsource(self.forward)
name = "{self.name}" tool_code = textwrap.dedent(f"""
description = \"\"\"{self.description}\"\"\" from agents import Tool
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
output_type = "{self.output_type}"
""".strip()
def add_self_argument(source_code: str) -> str: class {class_name}(Tool):
"""Add 'self' as first argument to a function definition if not present.""" name = "{self.name}"
pattern = r"def forward\(((?!self)[^)]*)\)" description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
output_type = "{self.output_type}"
""").strip()
import re
def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present."""
pattern = r'def forward\(((?!self)[^)]*)\)'
def replacement(match): def replacement(match):
args = match.group(1).strip() args = match.group(1).strip()
if args: # If there are other arguments if args: # If there are other arguments
return f"def forward(self, {args})" return f'def forward(self, {args})'
return "def forward(self)" return 'def forward(self)'
return re.sub(pattern, replacement, source_code) return re.sub(pattern, replacement, source_code)
forward_source_code = forward_source_code.replace(self.name, "forward") forward_source_code = forward_source_code.replace(self.name, "forward")
forward_source_code = add_self_argument(forward_source_code) forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip() forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
f.write(tool_code)
# Save config file with open(tool_file, "w", encoding="utf-8") as f:
config_file = os.path.join(output_dir, "tool_config.json") f.write(tool_code)
tool_config = { else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
"tool_class": self.__class__.__name__, if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
"description": self.description, raise ValueError(
"name": self.name, f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
"inputs": self.inputs, )
"output_type": str(self.output_type),
} validate_tool_attributes(self.__class__)
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n") tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f: with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(class_name=class_name)) f.write(textwrap.dedent(f"""
from agents import launch_gradio_demo
from tool import {class_name}
tool = {class_name}()
launch_gradio_demo(tool)
""").lstrip())
# Save requirements file # Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt") requirements_file = os.path.join(output_dir, "requirements.txt")
tree = ast.parse(forward_source_code) imports = []
import_finder = ImportFinder() for module in [tool_file]:
import_finder.visit(tree) imports.extend(get_imports(module))
imports = list(set(imports))
imports = list(set(import_finder.packages))
with open(requirements_file, "w", encoding="utf-8") as f: with open(requirements_file, "w", encoding="utf-8") as f:
f.write("agents_package\n" + "\n".join(imports) + "\n") f.write("agents_package\n" + "\n".join(imports) + "\n")
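The `add_self_argument` helper in this hunk can be tried on its own. The sketch below reuses the same regular expression on a few sample signatures. The sample strings are illustrative, and the pattern is assumed to see only simple signatures: a default value containing a closing parenthesis, such as `x=(1, 2)`, would end the match early.

```python
import re


def add_self_argument(source_code: str) -> str:
    """Add 'self' as first argument of `forward` if it is not already there."""
    # The negative lookahead skips signatures that already start with `self`.
    pattern = r"def forward\(((?!self)[^)]*)\)"

    def replacement(match):
        args = match.group(1).strip()
        if args:  # There are other arguments: put `self` in front of them
            return f"def forward(self, {args})"
        return "def forward(self)"

    return re.sub(pattern, replacement, source_code)


print(add_self_argument("def forward(cmd: str) -> str:"))
# def forward(self, cmd: str) -> str:
print(add_self_argument("def forward():"))
# def forward(self):
print(add_self_argument("def forward(self, cmd):"))
# def forward(self, cmd):
```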
@ -612,7 +515,6 @@ class {class_name}(Tool):
``` ```
""" """
from gradio_client import Client, handle_file from gradio_client import Client, handle_file
from gradio_client.utils import is_http_url_like
class SpaceToolWrapper(Tool): class SpaceToolWrapper(Tool):
def __init__( def __init__(
@ -665,6 +567,7 @@ class {class_name}(Tool):
self.is_initialized = True self.is_initialized = True
def sanitize_argument_for_prediction(self, arg): def sanitize_argument_for_prediction(self, arg):
from gradio_client.utils import is_http_url_like
if isinstance(arg, ImageType): if isinstance(arg, ImageType):
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name) arg.save(temp_file.name)
@ -793,13 +696,13 @@ def compile_jinja_template(template):
return jinja_env.from_string(template) return jinja_env.from_string(template)
def launch_gradio_demo(tool_class: Tool): def launch_gradio_demo(tool: Tool):
""" """
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
`inputs` and `output_type`. `inputs` and `output_type`.
Args: Args:
tool_class (`type`): The class of the tool for which to launch the demo. tool (`Tool`): The tool for which to launch the demo.
""" """
try: try:
import gradio as gr import gradio as gr
@ -808,11 +711,6 @@ def launch_gradio_demo(tool_class: Tool):
"Gradio should be installed in order to launch a gradio demo." "Gradio should be installed in order to launch a gradio demo."
) )
tool = tool_class()
def fn(*args, **kwargs):
return tool(*args, **kwargs)
TYPE_TO_COMPONENT_CLASS_MAPPING = { TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image, "image": gr.Image,
"audio": gr.Audio, "audio": gr.Audio,
@ -822,7 +720,7 @@ def launch_gradio_demo(tool_class: Tool):
} }
gradio_inputs = [] gradio_inputs = []
for input_name, input_details in tool_class.inputs.items(): for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"] input_details["type"]
] ]
@ -830,15 +728,15 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs.append(new_component) gradio_inputs.append(new_component)
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[ output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
tool_class.output_type tool.output_type
] ]
gradio_output = output_gradio_componentclass(label=input_name) gradio_output = output_gradio_componentclass(label="Output")
gr.Interface( gr.Interface(
fn=fn, fn=tool, # This works because `tool` has a __call__ method
inputs=gradio_inputs, inputs=gradio_inputs,
outputs=gradio_output, outputs=gradio_output,
title=tool_class.__name__, title=tool.name,
article=tool.description, article=tool.description,
).launch() ).launch()
@ -1027,29 +925,34 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException( raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!" "Tool return type not found: make sure your function has a return type hint!"
) )
class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object": if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any" parameters["return"]["type"] = "any"
class SpecificTool(Tool): class SimpleTool(Tool):
name = parameters["name"] def __init__(self, name, description, inputs, output_type, function):
description = parameters["description"] self.name = name
inputs = parameters["parameters"]["properties"] self.description = description
output_type = parameters["return"]["type"] self.inputs = inputs
self.output_type = output_type
@wraps(tool_function) self.forward = function
def forward(self, *args, **kwargs): self.is_initialized = True
return tool_function(*args, **kwargs)
simple_tool = SimpleTool(
parameters["name"],
parameters["description"],
parameters["parameters"]["properties"],
parameters["return"]["type"],
function=tool_function
)
original_signature = inspect.signature(tool_function) original_signature = inspect.signature(tool_function)
new_parameters = [ new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + list(original_signature.parameters.values()) ] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters) new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature simple_tool.forward.__signature__ = new_signature
SpecificTool.__name__ = class_name # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
return SpecificTool() return simple_tool
HUGGINGFACE_DEFAULT_TOOLS = {} HUGGINGFACE_DEFAULT_TOOLS = {}
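The signature handling at the end of the `tool` decorator can be seen in isolation with `inspect`. This is a minimal sketch under stated assumptions: `with_self_parameter` and `get_weather` are hypothetical names used only for illustration, and the sketch builds the new signature without attaching it to anything.

```python
import inspect


def with_self_parameter(function) -> inspect.Signature:
    """Build a copy of the function's signature with `self` prepended."""
    original_signature = inspect.signature(function)
    self_parameter = inspect.Parameter(
        "self", inspect.Parameter.POSITIONAL_OR_KEYWORD
    )
    new_parameters = [self_parameter] + list(original_signature.parameters.values())
    return original_signature.replace(parameters=new_parameters)


def get_weather(city: str, celsius: bool = True) -> str:
    """Hypothetical tool function."""
    return f"{city}: 20"


new_signature = with_self_parameter(get_weather)
print(list(new_signature.parameters))  # ['self', 'city', 'celsius']
```

The return annotation and the defaults of the original function are preserved, since `Signature.replace` only swaps the parameter list.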

@ -0,0 +1,191 @@
import ast
import inspect
import importlib.util
import builtins
from pathlib import Path
from typing import List, Set, Dict
import textwrap
_BUILTIN_NAMES = set(vars(builtins))
def is_local_import(module_name: str) -> bool:
"""
Check if an import is from a local file or a package.
Returns True if it's a local file import.
"""
try:
spec = importlib.util.find_spec(module_name)
if spec is None:
return True # If we can't find the module, assume it's local
# If the module is found and has a file path, check if it's in site-packages
if spec.origin and 'site-packages' not in spec.origin:
# Check if it's a .py file in the current directory or subdirectories
return spec.origin.endswith('.py')
return False
except ImportError:
return True # If there's an import error, assume it's local
class MethodChecker(ast.NodeVisitor):
"""
Checks that a method
- only uses defined names
- contains no local imports (e.g. numpy is ok but local_script is not)
"""
def __init__(self, class_attributes: Set[str]):
self.undefined_names = set()
self.imports = {}
self.from_imports = {}
self.assigned_names = set()
self.arg_names = set()
self.class_attributes = class_attributes
self.errors = []
def visit_arguments(self, node):
"""Collect function arguments"""
self.arg_names = {arg.arg for arg in node.args}
if node.kwarg:
self.arg_names.add(node.kwarg.arg)
if node.vararg:
self.arg_names.add(node.vararg.arg)
def visit_Import(self, node):
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(actual_name):
self.errors.append(f"Local import '{actual_name}'")
self.imports[actual_name] = name.name
def visit_ImportFrom(self, node):
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(module):
self.errors.append(f"Local import '{module}'")
self.from_imports[actual_name] = (module, name.name)
def visit_Assign(self, node):
for target in node.targets:
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
self.visit(node.value)
def visit_AnnAssign(self, node):
"""Track annotated assignments."""
if isinstance(node.target, ast.Name):
self.assigned_names.add(node.target.id)
if node.value:
self.visit(node.value)
def visit_For(self, node):
target = node.target
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
elif isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name):
self.assigned_names.add(elt.id)
self.generic_visit(node)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
if not (
node.id in _BUILTIN_NAMES
or node.id in self.arg_names
or node.id == "self"
or node.id in self.class_attributes
or node.id in self.imports
or node.id in self.from_imports
or node.id in self.assigned_names
):
self.errors.append(f"Name '{node.id}' is undefined.")
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
if not (
node.func.id in _BUILTIN_NAMES
or node.func.id in self.arg_names
or node.func.id == "self"
or node.func.id in self.class_attributes
or node.func.id in self.imports
or node.func.id in self.from_imports
or node.func.id in self.assigned_names
):
self.errors.append(f"Name '{node.func.id}' is undefined.")
self.generic_visit(node)
def validate_tool_attributes(cls) -> None:
"""
Validates that a Tool class follows the proper patterns:
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
1. About the class:
- Class attributes should only be strings or dicts
- Class attributes cannot be complex attributes
2. About all class methods:
- Imports must be from packages, not local files
- All methods must be self-contained
Raises all errors encountered, if no error returns None.
"""
errors = []
source = textwrap.dedent(inspect.getsource(cls))
tree = ast.parse(source)
if not isinstance(tree.body[0], ast.ClassDef):
raise ValueError("Source code must define a class")
# Check that __init__ method takes no arguments
if not cls.__init__.__qualname__ == 'Tool.__init__':
sig = inspect.signature(cls.__init__)
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
if len(non_self_params) > 0:
errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!")
class_node = tree.body[0]
class ClassLevelChecker(ast.NodeVisitor):
def __init__(self):
self.imported_names = set()
self.complex_attributes = set()
self.class_attributes = set()
def visit_Assign(self, node):
# Track class attributes
for target in node.targets:
if isinstance(target, ast.Name):
self.class_attributes.add(target.id)
# Check if the assignment is more complex than simple literals
if not all(isinstance(val, (ast.Constant, ast.Dict, ast.List, ast.Set))
for val in ast.walk(node.value)):
for target in node.targets:
if isinstance(target, ast.Name):
self.complex_attributes.add(target.id)
class_level_checker = ClassLevelChecker()
class_level_checker.visit(class_node)
if class_level_checker.complex_attributes:
errors.append(
f"Complex attributes should be defined in __init__, not as class attributes: "
f"{', '.join(class_level_checker.complex_attributes)}"
)
# Run checks on all methods
for node in class_node.body:
if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(class_level_checker.class_attributes)
method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
if errors:
raise ValueError("Tool validation failed:\n" + "\n".join(errors))
return
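The idea behind `MethodChecker` is that every name a method reads must be a builtin, an argument, or something assigned or imported inside the method. The sketch below is a simplified stand-in, not the class above: `undefined_names` is a hypothetical helper, it collects definitions in a first pass so it ignores statement order, and it does not check for local imports.

```python
import ast
import builtins
import textwrap

BUILTIN_NAMES = set(vars(builtins))


def undefined_names(function_source: str) -> set:
    """Return names that the function reads but never defines itself."""
    tree = ast.parse(textwrap.dedent(function_source))
    function = tree.body[0]
    known = set(BUILTIN_NAMES)

    # Arguments count as defined names.
    args = function.args
    known.update(a.arg for a in args.posonlyargs + args.args + args.kwonlyargs)
    if args.vararg:
        known.add(args.vararg.arg)
    if args.kwarg:
        known.add(args.kwarg.arg)

    # First pass: names bound by imports or assignments inside the function.
    for node in ast.walk(function):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                known.add((alias.asname or alias.name).split(".")[0])
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            known.add(node.id)

    # Second pass: every name that is read must be known.
    return {
        node.id
        for node in ast.walk(function)
        if isinstance(node, ast.Name)
        and isinstance(node.ctx, ast.Load)
        and node.id not in known
    }


not_self_contained = """
def forward(self):
    return str(datetime.now())
"""

self_contained = """
def forward(self):
    from datetime import datetime
    return str(datetime.now())
"""

print(undefined_names(not_self_contained))  # {'datetime'}
print(undefined_names(self_contained))      # set()
```

The two sample methods mirror the test cases in this commit: the import has to live inside the method for the saved tool file to work on its own.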

@ -19,7 +19,7 @@ import re
import requests import requests
from requests.exceptions import RequestException from requests.exceptions import RequestException
from .tools import Tool from ..tools import Tool
class DuckDuckGoSearchTool(Tool): class DuckDuckGoSearchTool(Tool):

@ -19,6 +19,9 @@ import re
from typing import Tuple, Dict, Union from typing import Tuple, Dict, Union
import ast import ast
from rich.console import Console from rich.console import Console
import ast
import inspect
import types
from transformers.utils.import_utils import _is_package_available from transformers.utils.import_utils import _is_package_available
@ -127,5 +130,114 @@ class ImportFinder(ast.NodeVisitor):
base_package = node.module.split(".")[0] base_package = node.module.split(".")[0]
self.packages.add(base_package) self.packages.add(base_package)
import ast
import builtins
from typing import Set, Dict, List
def get_method_source(method):
"""Get source code for a method, including bound methods."""
if isinstance(method, types.MethodType):
method = method.__func__
return inspect.getsource(method).strip()
def is_same_method(method1, method2):
"""Compare two methods by their source code."""
try:
source1 = get_method_source(method1)
source2 = get_method_source(method2)
# Remove method decorators if any
source1 = '\n'.join(line for line in source1.split('\n')
if not line.strip().startswith('@'))
source2 = '\n'.join(line for line in source2.split('\n')
if not line.strip().startswith('@'))
return source1 == source2
except (TypeError, OSError):
return False
def is_same_item(item1, item2):
"""Compare two class items (methods or attributes) for equality."""
if callable(item1) and callable(item2):
return is_same_method(item1, item2)
else:
return item1 == item2
def instance_to_source(instance, base_cls=None):
"""Convert an instance to its class source code representation."""
cls = instance.__class__
class_name = cls.__name__
# Start building class lines
class_lines = []
if base_cls:
class_lines.append(f"class {class_name}({base_cls.__name__}):")
else:
class_lines.append(f"class {class_name}:")
# Add docstring if it exists and differs from base
if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
class_lines.append(f' """{cls.__doc__}"""')
# Add class-level attributes
class_attrs = {
name: value for name, value in cls.__dict__.items()
if not name.startswith('__') and not callable(value) and
not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
}
for name, value in class_attrs.items():
if isinstance(value, str):
class_lines.append(f' {name} = "{value}"')
else:
class_lines.append(f' {name} = {repr(value)}')
if class_attrs:
class_lines.append("")
# Add methods
methods = {
name: func for name, func in cls.__dict__.items()
if callable(func) and
not (base_cls and hasattr(base_cls, name) and
getattr(base_cls, name).__code__.co_code == func.__code__.co_code)
}
for name, method in methods.items():
method_source = inspect.getsource(method)
# Clean up the indentation
method_lines = method_source.split('\n')
first_line = method_lines[0]
indent = len(first_line) - len(first_line.lstrip())
method_lines = [line[indent:] for line in method_lines]
method_source = '\n'.join([' ' + line if line.strip() else line
for line in method_lines])
class_lines.append(method_source)
class_lines.append("")
# Find required imports using ImportFinder
import_finder = ImportFinder()
import_finder.visit(ast.parse('\n'.join(class_lines)))
required_imports = import_finder.packages
# Build final code with imports
final_lines = []
# Add base class import if needed
if base_cls:
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
# Add discovered imports
final_lines.extend(f"import {package}" for package in required_imports)
if final_lines: # Add empty line after imports
final_lines.append("")
# Add the class code
final_lines.extend(class_lines)
return '\n'.join(final_lines)
__all__ = [] __all__ = []
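`instance_to_source` relies on `ImportFinder` to discover which packages the generated class needs. Only the tail of that visitor is visible in this diff, so the sketch below is an assumption about its general shape rather than a copy: `top_level_packages` is a hypothetical function that collects the top-level package of every absolute import in a piece of source code.

```python
import ast


def top_level_packages(source: str) -> set:
    """Collect the top-level package name of every absolute import."""
    packages = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                packages.add(alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            # level > 0 means a relative import, which has no package to install
            packages.add(node.module.split(".")[0])
    return packages


sample_source = """
import numpy as np
from huggingface_hub.utils import logging
from . import local_helper
"""

print(sorted(top_level_packages(sample_source)))  # ['huggingface_hub', 'numpy']
```

Each collected name still has to be turned into an `import` statement (or a requirements entry) before being written out.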

@ -29,7 +29,7 @@ from agents.agents import (
Toolbox, Toolbox,
ToolCall, ToolCall,
) )
from agents.tools import tool from agents.tool import tool
from agents.default_tools import PythonInterpreterTool from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir

@ -125,7 +125,7 @@ class TestDocs:
"from_langchain", "from_langchain",
] ]
code_blocks = [ code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token) block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace("{your_username}", "m-ric")
for block in code_blocks for block in code_blocks
if not any( if not any(
[snippet in block for snippet in excluded_snippets] [snippet in block for snippet in excluded_snippets]

@ -26,7 +26,7 @@ from agents.types import (
AgentImage, AgentImage,
AgentText, AgentText,
) )
from agents.tools import Tool, tool, AUTHORIZED_TYPES from agents.tool import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
@ -175,6 +175,8 @@ class ToolTests(unittest.TestCase):
""" """
return str(datetime.now()) return str(datetime.now())
get_current_time.save("output")
assert "datetime" in str(e) assert "datetime" in str(e)
# Also test with classic definition # Also test with classic definition
@ -189,6 +191,9 @@ class ToolTests(unittest.TestCase):
def forward(self): def forward(self):
return str(datetime.now()) return str(datetime.now())
get_current_time = GetCurrentTimeTool()
get_current_time.save("output")
assert "datetime" in str(e) assert "datetime" in str(e)
def test_tool_definition_raises_no_error_imports_in_function(self): def test_tool_definition_raises_no_error_imports_in_function(self):
@ -210,3 +215,63 @@ class ToolTests(unittest.TestCase):
def forward(self): def forward(self):
from datetime import datetime from datetime import datetime
return str(datetime.now()) return str(datetime.now())
def test_saving_tool_allows_no_arg_in_init(self):
# Test one cannot save tool with additional args in init
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def __init__(self, url):
super().__init__(self)
self.url = "none"
def forward(self, string_input):
return self.url + string_input
fail_tool = FailTool("dummy_url")
with pytest.raises(Exception) as e:
fail_tool.save('output')
assert '__init__' in str(e)
def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails
from numpy import random
class FailTool2(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
self.client = random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_2 = FailTool2()
with pytest.raises(Exception) as e:
fail_tool_2.save('output')
assert 'random' in str(e)
# Test that putting these imports inside functions works
class FailTool3(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
from numpy import random
self.client = random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_3 = FailTool3()
fail_tool_3.save('output')