Small fixes in tool pushing

This commit is contained in:
Aymeric 2024-12-30 15:05:06 +01:00
parent 6f13e78aac
commit 4ea3c00b45
3 changed files with 26 additions and 12 deletions

View File

@ -33,7 +33,7 @@ Note that with this definition, "agent" is not a discrete, 0 or 1 definition: in
Then it can get more agentic. Then it can get more agentic.
- If you use an LLM output to determine which function is run and with which arguments, that's tool calling. - If you use an LLM output to determine which function is run and with which arguments, that's tool calling.
- If you use an LLM output to determine if you should keep iterating in a while loop, you get a multi-step agent. - If you use an LLM output to determine if you should keep iterating in a while loop, you have a multi-step agent.
| Agency Level | Description | How that's called | Example Pattern | | Agency Level | Description | How that's called | Example Pattern |
|-------------|-------------|-------------|-----------------| |-------------|-------------|-------------|-----------------|

View File

@ -19,6 +19,7 @@ import importlib
import inspect import inspect
import json import json
import os import os
import sys
import tempfile import tempfile
import torch import torch
import textwrap import textwrap
@ -268,16 +269,19 @@ class Tool:
# Save tool file # Save tool file
if type(self).__name__ == "SimpleTool": if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained # Check that imports are self-contained
forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward))) source_code = inspect.getsource(self.forward).replace("@tool", "")
forward_node = ast.parse(textwrap.dedent(source_code))
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
method_checker = MethodChecker(set()) method_checker = MethodChecker(set())
method_checker.visit(forward_node) method_checker.visit(forward_node)
if len(method_checker.errors) > 0: if len(method_checker.errors) > 0:
raise (ValueError("\n".join(method_checker.errors))) raise (ValueError("\n".join(method_checker.errors)))
forward_source_code = inspect.getsource(self.forward) forward_source_code = inspect.getsource(self.forward)
tool_code = textwrap.dedent(f""" tool_code = textwrap.dedent(f"""
from smolagents import Tool from smolagents import Tool
from typing import Optional
class {class_name}(Tool): class {class_name}(Tool):
name = "{self.name}" name = "{self.name}"
@ -319,7 +323,7 @@ class Tool:
tool_code = instance_to_source(self, base_cls=Tool) tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f: with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code) f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}"))
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
@ -327,6 +331,7 @@ class Tool:
f.write( f.write(
textwrap.dedent(f""" textwrap.dedent(f"""
from smolagents import launch_gradio_demo from smolagents import launch_gradio_demo
from typing import Optional
from tool import {class_name} from tool import {class_name}
tool = {class_name}() tool = {class_name}()
@ -341,9 +346,17 @@ class Tool:
imports = [] imports = []
for module in [tool_file]: for module in [tool_file]:
imports.extend(get_imports(module)) imports.extend(get_imports(module))
imports = list(set(imports)) imports = list(
set(
[
el
for el in imports + ["smolagents"]
if el not in sys.stdlib_module_names
]
)
)
with open(requirements_file, "w", encoding="utf-8") as f: with open(requirements_file, "w", encoding="utf-8") as f:
f.write("agents_package\n" + "\n".join(imports) + "\n") f.write("\n".join(imports) + "\n")
def push_to_hub( def push_to_hub(
self, self,

View File

@ -34,15 +34,16 @@ def is_pygments_available():
console = Console() console = Console()
BASE_BUILTIN_MODULES = [ BASE_BUILTIN_MODULES = [
"random",
"collections", "collections",
"math", "datetime",
"time",
"queue",
"itertools", "itertools",
"math",
"queue",
"random",
"re", "re",
"stat", "stat",
"statistics", "statistics",
"time",
"unicodedata", "unicodedata",
] ]