Small fixes in tool pushing
This commit is contained in:
		
							parent
							
								
									6f13e78aac
								
							
						
					
					
						commit
						4ea3c00b45
					
				|  | @ -33,7 +33,7 @@ Note that with this definition, "agent" is not a discrete, 0 or 1 definition: in | ||||||
| Then it can get more agentic. | Then it can get more agentic. | ||||||
| 
 | 
 | ||||||
| - If you use an LLM output to determine which function is run and with which arguments, that's tool calling. | - If you use an LLM output to determine which function is run and with which arguments, that's tool calling. | ||||||
| - If you use an LLM output to determine if you should keep iterating in a while loop, you get a multi-step agent. | - If you use an LLM output to determine if you should keep iterating in a while loop, you have a multi-step agent. | ||||||
| 
 | 
 | ||||||
| | Agency Level | Description | How that's called | Example Pattern | | | Agency Level | Description | How that's called | Example Pattern | | ||||||
| |-------------|-------------|-------------|-----------------| | |-------------|-------------|-------------|-----------------| | ||||||
|  | @ -52,9 +52,9 @@ One type of agentic system is quite simple: the multi-step agent. It has this st | ||||||
| ```python | ```python | ||||||
| memory = [user_defined_task] | memory = [user_defined_task] | ||||||
| while llm_should_continue(memory): # this loop is the multi-step part | while llm_should_continue(memory): # this loop is the multi-step part | ||||||
| 		action = llm_get_next_action(memory) # this is the tool-calling part |     action = llm_get_next_action(memory) # this is the tool-calling part | ||||||
| 		observations = execute_action(action) |     observations = execute_action(action) | ||||||
| 		memory += [action, observations] |     memory += [action, observations] | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| This agentic system just runs in a loop, execution a new action at each step (the action can involve calling some pre-determined *tools* that are just functions), until its observations make it apparent that a satisfactory state has been reached to solve the given task. | This agentic system just runs in a loop, execution a new action at each step (the action can involve calling some pre-determined *tools* that are just functions), until its observations make it apparent that a satisfactory state has been reached to solve the given task. | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ import importlib | ||||||
| import inspect | import inspect | ||||||
| import json | import json | ||||||
| import os | import os | ||||||
|  | import sys | ||||||
| import tempfile | import tempfile | ||||||
| import torch | import torch | ||||||
| import textwrap | import textwrap | ||||||
|  | @ -268,16 +269,19 @@ class Tool: | ||||||
|         # Save tool file |         # Save tool file | ||||||
|         if type(self).__name__ == "SimpleTool": |         if type(self).__name__ == "SimpleTool": | ||||||
|             # Check that imports are self-contained |             # Check that imports are self-contained | ||||||
|             forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward))) |             source_code = inspect.getsource(self.forward).replace("@tool", "") | ||||||
|  |             forward_node = ast.parse(textwrap.dedent(source_code)) | ||||||
|             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code |             # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code | ||||||
|             method_checker = MethodChecker(set()) |             method_checker = MethodChecker(set()) | ||||||
|             method_checker.visit(forward_node) |             method_checker.visit(forward_node) | ||||||
|  | 
 | ||||||
|             if len(method_checker.errors) > 0: |             if len(method_checker.errors) > 0: | ||||||
|                 raise (ValueError("\n".join(method_checker.errors))) |                 raise (ValueError("\n".join(method_checker.errors))) | ||||||
| 
 | 
 | ||||||
|             forward_source_code = inspect.getsource(self.forward) |             forward_source_code = inspect.getsource(self.forward) | ||||||
|             tool_code = textwrap.dedent(f""" |             tool_code = textwrap.dedent(f""" | ||||||
|             from smolagents import Tool |             from smolagents import Tool | ||||||
|  |             from typing import Optional | ||||||
| 
 | 
 | ||||||
|             class {class_name}(Tool): |             class {class_name}(Tool): | ||||||
|                 name = "{self.name}" |                 name = "{self.name}" | ||||||
|  | @ -319,7 +323,7 @@ class Tool: | ||||||
|             tool_code = instance_to_source(self, base_cls=Tool) |             tool_code = instance_to_source(self, base_cls=Tool) | ||||||
| 
 | 
 | ||||||
|         with open(tool_file, "w", encoding="utf-8") as f: |         with open(tool_file, "w", encoding="utf-8") as f: | ||||||
|             f.write(tool_code) |             f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}")) | ||||||
| 
 | 
 | ||||||
|         # Save app file |         # Save app file | ||||||
|         app_file = os.path.join(output_dir, "app.py") |         app_file = os.path.join(output_dir, "app.py") | ||||||
|  | @ -327,6 +331,7 @@ class Tool: | ||||||
|             f.write( |             f.write( | ||||||
|                 textwrap.dedent(f""" |                 textwrap.dedent(f""" | ||||||
|             from smolagents import launch_gradio_demo |             from smolagents import launch_gradio_demo | ||||||
|  |             from typing import Optional | ||||||
|             from tool import {class_name} |             from tool import {class_name} | ||||||
| 
 | 
 | ||||||
|             tool = {class_name}() |             tool = {class_name}() | ||||||
|  | @ -341,9 +346,17 @@ class Tool: | ||||||
|         imports = [] |         imports = [] | ||||||
|         for module in [tool_file]: |         for module in [tool_file]: | ||||||
|             imports.extend(get_imports(module)) |             imports.extend(get_imports(module)) | ||||||
|         imports = list(set(imports)) |         imports = list( | ||||||
|  |             set( | ||||||
|  |                 [ | ||||||
|  |                     el | ||||||
|  |                     for el in imports + ["smolagents"] | ||||||
|  |                     if el not in sys.stdlib_module_names | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|         with open(requirements_file, "w", encoding="utf-8") as f: |         with open(requirements_file, "w", encoding="utf-8") as f: | ||||||
|             f.write("agents_package\n" + "\n".join(imports) + "\n") |             f.write("\n".join(imports) + "\n") | ||||||
| 
 | 
 | ||||||
|     def push_to_hub( |     def push_to_hub( | ||||||
|         self, |         self, | ||||||
|  |  | ||||||
|  | @ -34,15 +34,16 @@ def is_pygments_available(): | ||||||
| console = Console() | console = Console() | ||||||
| 
 | 
 | ||||||
| BASE_BUILTIN_MODULES = [ | BASE_BUILTIN_MODULES = [ | ||||||
|     "random", |  | ||||||
|     "collections", |     "collections", | ||||||
|     "math", |     "datetime", | ||||||
|     "time", |  | ||||||
|     "queue", |  | ||||||
|     "itertools", |     "itertools", | ||||||
|  |     "math", | ||||||
|  |     "queue", | ||||||
|  |     "random", | ||||||
|     "re", |     "re", | ||||||
|     "stat", |     "stat", | ||||||
|     "statistics", |     "statistics", | ||||||
|  |     "time", | ||||||
|     "unicodedata", |     "unicodedata", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue