Small fixes in tool pushing

This commit is contained in:
Aymeric 2024-12-30 15:05:06 +01:00
parent 6f13e78aac
commit 4ea3c00b45
3 changed files with 26 additions and 12 deletions

View File

@ -33,7 +33,7 @@ Note that with this definition, "agent" is not a discrete, 0 or 1 definition: in
Then it can get more agentic.
- If you use an LLM output to determine which function is run and with which arguments, that's tool calling.
- If you use an LLM output to determine if you should keep iterating in a while loop, you get a multi-step agent.
- If you use an LLM output to determine if you should keep iterating in a while loop, you have a multi-step agent.
| Agency Level | Description | How that's called | Example Pattern |
|-------------|-------------|-------------|-----------------|
@ -52,9 +52,9 @@ One type of agentic system is quite simple: the multi-step agent. It has this st
```python
memory = [user_defined_task]
while llm_should_continue(memory): # this loop is the multi-step part
action = llm_get_next_action(memory) # this is the tool-calling part
observations = execute_action(action)
memory += [action, observations]
action = llm_get_next_action(memory) # this is the tool-calling part
observations = execute_action(action)
memory += [action, observations]
```
This agentic system just runs in a loop, execution a new action at each step (the action can involve calling some pre-determined *tools* that are just functions), until its observations make it apparent that a satisfactory state has been reached to solve the given task.

View File

@ -19,6 +19,7 @@ import importlib
import inspect
import json
import os
import sys
import tempfile
import torch
import textwrap
@ -268,16 +269,19 @@ class Tool:
# Save tool file
if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained
forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward)))
source_code = inspect.getsource(self.forward).replace("@tool", "")
forward_node = ast.parse(textwrap.dedent(source_code))
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
method_checker = MethodChecker(set())
method_checker.visit(forward_node)
if len(method_checker.errors) > 0:
raise (ValueError("\n".join(method_checker.errors)))
forward_source_code = inspect.getsource(self.forward)
tool_code = textwrap.dedent(f"""
from smolagents import Tool
from typing import Optional
class {class_name}(Tool):
name = "{self.name}"
@ -319,7 +323,7 @@ class Tool:
tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}"))
# Save app file
app_file = os.path.join(output_dir, "app.py")
@ -327,6 +331,7 @@ class Tool:
f.write(
textwrap.dedent(f"""
from smolagents import launch_gradio_demo
from typing import Optional
from tool import {class_name}
tool = {class_name}()
@ -341,9 +346,17 @@ class Tool:
imports = []
for module in [tool_file]:
imports.extend(get_imports(module))
imports = list(set(imports))
imports = list(
set(
[
el
for el in imports + ["smolagents"]
if el not in sys.stdlib_module_names
]
)
)
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("agents_package\n" + "\n".join(imports) + "\n")
f.write("\n".join(imports) + "\n")
def push_to_hub(
self,

View File

@ -34,15 +34,16 @@ def is_pygments_available():
console = Console()
BASE_BUILTIN_MODULES = [
"random",
"collections",
"math",
"time",
"queue",
"datetime",
"itertools",
"math",
"queue",
"random",
"re",
"stat",
"statistics",
"time",
"unicodedata",
]