parent
cf04285cc1
commit
cb9830a554
|
@ -39,10 +39,6 @@ contains the API docs for the underlying classes.
|
|||
|
||||
[[autodoc]] Tool
|
||||
|
||||
### Toolbox
|
||||
|
||||
[[autodoc]] Toolbox
|
||||
|
||||
### launch_gradio_demo
|
||||
|
||||
[[autodoc]] launch_gradio_demo
|
||||
|
|
|
@ -187,7 +187,7 @@ from smolagents import HfApiModel
|
|||
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
|
||||
|
||||
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
||||
agent.toolbox.add_tool(model_download_tool)
|
||||
agent.tools.append(model_download_tool)
|
||||
```
|
||||
Now we can leverage the new tool:
|
||||
|
||||
|
@ -202,11 +202,6 @@ agent.run(
|
|||
> Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines.
|
||||
|
||||
|
||||
Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
|
||||
This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
|
||||
Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
|
||||
|
||||
|
||||
### Use a collection of tools
|
||||
|
||||
You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.
|
||||
|
|
|
@ -18,13 +18,14 @@ import time
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from rich import box
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
|
||||
from .default_tools import FinalAnswerTool
|
||||
from .default_tools import FinalAnswerTool, TOOL_MAPPING
|
||||
from .e2b_executor import E2BExecutor
|
||||
from .local_python_executor import (
|
||||
BASE_BUILTIN_MODULES,
|
||||
|
@ -49,7 +50,6 @@ from .prompts import (
|
|||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
Toolbox,
|
||||
get_tool_description_with_args,
|
||||
)
|
||||
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
||||
|
@ -107,18 +107,27 @@ class SystemPromptStep(AgentStep):
|
|||
system_prompt: str
|
||||
|
||||
|
||||
def format_prompt_with_tools(
|
||||
toolbox: Toolbox, prompt_template: str, tool_description_template: str
|
||||
def get_tool_descriptions(
|
||||
tools: Dict[str, Tool], tool_description_template: str
|
||||
) -> str:
|
||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||
return "\n".join(
|
||||
[
|
||||
get_tool_description_with_args(tool, tool_description_template)
|
||||
for tool in tools.values()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def format_prompt_with_tools(
|
||||
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
|
||||
) -> str:
|
||||
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
|
||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||
if "{{tool_names}}" in prompt:
|
||||
prompt = prompt.replace(
|
||||
"{{tool_names}}",
|
||||
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
|
||||
", ".join([f"'{tool.name}'" for tool in tools.values()]),
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
|
@ -163,7 +172,7 @@ class MultiStepAgent:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Union[List[Tool], Toolbox],
|
||||
tools: List[Tool],
|
||||
model: Callable[[List[Dict[str, str]]], str],
|
||||
system_prompt: Optional[str] = None,
|
||||
tool_description_template: Optional[str] = None,
|
||||
|
@ -172,7 +181,7 @@ class MultiStepAgent:
|
|||
add_base_tools: bool = False,
|
||||
verbose: bool = False,
|
||||
grammar: Optional[Dict[str, str]] = None,
|
||||
managed_agents: Optional[Dict] = None,
|
||||
managed_agents: Optional[List] = None,
|
||||
step_callbacks: Optional[List[Callable]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
):
|
||||
|
@ -196,17 +205,18 @@ class MultiStepAgent:
|
|||
|
||||
self.managed_agents = {}
|
||||
if managed_agents is not None:
|
||||
print("NOTNONE")
|
||||
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
||||
|
||||
if isinstance(tools, Toolbox):
|
||||
self._toolbox = tools
|
||||
self.tools = {tool.name: tool for tool in tools}
|
||||
if add_base_tools:
|
||||
self._toolbox.add_base_tools(
|
||||
add_python_interpreter=(self.__class__ == ToolCallingAgent)
|
||||
)
|
||||
else:
|
||||
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
||||
self._toolbox.add_tool(FinalAnswerTool())
|
||||
for tool_name, tool_class in TOOL_MAPPING.items():
|
||||
if (
|
||||
tool_name != "python_interpreter"
|
||||
or self.__class__.__name__ == "ToolCallingAgent"
|
||||
):
|
||||
self.tools[tool_name] = tool_class()
|
||||
self.tools["final_answer"] = FinalAnswerTool()
|
||||
|
||||
self.system_prompt = self.initialize_system_prompt()
|
||||
self.input_messages = None
|
||||
|
@ -217,14 +227,9 @@ class MultiStepAgent:
|
|||
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
||||
self.step_callbacks.append(self.monitor.update_metrics)
|
||||
|
||||
@property
|
||||
def toolbox(self) -> Toolbox:
|
||||
"""Get the toolbox currently available to the agent"""
|
||||
return self._toolbox
|
||||
|
||||
def initialize_system_prompt(self):
|
||||
self.system_prompt = format_prompt_with_tools(
|
||||
self._toolbox,
|
||||
self.tools,
|
||||
self.system_prompt_template,
|
||||
self.tool_description_template,
|
||||
)
|
||||
|
@ -384,10 +389,10 @@ class MultiStepAgent:
|
|||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
||||
Args:
|
||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||
tool_name (`str`): Name of the Tool to execute (should be one from self.tools).
|
||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||
"""
|
||||
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||
available_tools = {**self.tools, **self.managed_agents}
|
||||
if tool_name not in available_tools:
|
||||
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||
raise AgentExecutionError(error_msg)
|
||||
|
@ -415,7 +420,7 @@ class MultiStepAgent:
|
|||
raise AgentExecutionError(error_msg)
|
||||
return observation
|
||||
except Exception as e:
|
||||
if tool_name in self.toolbox.tools:
|
||||
if tool_name in self.tools:
|
||||
tool_description = get_tool_description_with_args(
|
||||
available_tools[tool_name]
|
||||
)
|
||||
|
@ -512,20 +517,26 @@ You have been provided with these additional arguments, that you can access usin
|
|||
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
step_number = 0
|
||||
while final_answer is None and step_number < self.max_steps:
|
||||
self.step_number = 0
|
||||
while final_answer is None and self.step_number < self.max_steps:
|
||||
step_start_time = time.time()
|
||||
step_log = ActionStep(step=step_number, start_time=step_start_time)
|
||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||
try:
|
||||
if (
|
||||
self.planning_interval is not None
|
||||
and step_number % self.planning_interval == 0
|
||||
and self.step_number % self.planning_interval == 0
|
||||
):
|
||||
self.planning_step(
|
||||
task, is_first_step=(step_number == 0), step=step_number
|
||||
task,
|
||||
is_first_step=(self.step_number == 0),
|
||||
step=self.step_number,
|
||||
)
|
||||
console.print(
|
||||
Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX)
|
||||
Rule(
|
||||
f"[bold]Step {self.step_number}",
|
||||
characters="━",
|
||||
style=YELLOW_HEX,
|
||||
)
|
||||
)
|
||||
|
||||
# Run one step!
|
||||
|
@ -538,10 +549,10 @@ You have been provided with these additional arguments, that you can access usin
|
|||
self.logs.append(step_log)
|
||||
for callback in self.step_callbacks:
|
||||
callback(step_log)
|
||||
step_number += 1
|
||||
self.step_number += 1
|
||||
yield step_log
|
||||
|
||||
if final_answer is None and step_number == self.max_steps:
|
||||
if final_answer is None and self.step_number == self.max_steps:
|
||||
error_message = "Reached max steps."
|
||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||
self.logs.append(final_step_log)
|
||||
|
@ -561,20 +572,26 @@ You have been provided with these additional arguments, that you can access usin
|
|||
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
step_number = 0
|
||||
while final_answer is None and step_number < self.max_steps:
|
||||
self.step_number = 0
|
||||
while final_answer is None and self.step_number < self.max_steps:
|
||||
step_start_time = time.time()
|
||||
step_log = ActionStep(step=step_number, start_time=step_start_time)
|
||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||
try:
|
||||
if (
|
||||
self.planning_interval is not None
|
||||
and step_number % self.planning_interval == 0
|
||||
and self.step_number % self.planning_interval == 0
|
||||
):
|
||||
self.planning_step(
|
||||
task, is_first_step=(step_number == 0), step=step_number
|
||||
task,
|
||||
is_first_step=(self.step_number == 0),
|
||||
step=self.step_number,
|
||||
)
|
||||
console.print(
|
||||
Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX)
|
||||
Rule(
|
||||
f"[bold]Step {self.step_number}",
|
||||
characters="━",
|
||||
style=YELLOW_HEX,
|
||||
)
|
||||
)
|
||||
|
||||
# Run one step!
|
||||
|
@ -589,9 +606,9 @@ You have been provided with these additional arguments, that you can access usin
|
|||
self.logs.append(step_log)
|
||||
for callback in self.step_callbacks:
|
||||
callback(step_log)
|
||||
step_number += 1
|
||||
self.step_number += 1
|
||||
|
||||
if final_answer is None and step_number == self.max_steps:
|
||||
if final_answer is None and self.step_number == self.max_steps:
|
||||
error_message = "Reached max steps."
|
||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||
self.logs.append(final_step_log)
|
||||
|
@ -637,8 +654,8 @@ Now begin!""",
|
|||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN.format(
|
||||
task=task,
|
||||
tool_descriptions=self._toolbox.show_tool_descriptions(
|
||||
self.tool_description_template
|
||||
tool_descriptions=get_tool_descriptions(
|
||||
self.tools, self.tool_description_template
|
||||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
|
@ -692,8 +709,8 @@ Now begin!""",
|
|||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||
task=task,
|
||||
tool_descriptions=self._toolbox.show_tool_descriptions(
|
||||
self.tool_description_template
|
||||
tool_descriptions=get_tool_descriptions(
|
||||
self.tools, self.tool_description_template
|
||||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
|
@ -761,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
try:
|
||||
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
|
||||
self.input_messages,
|
||||
available_tools=list(self.toolbox._tools.values()),
|
||||
available_tools=list(self.tools.values()),
|
||||
stop_sequences=["Observation:"],
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -856,7 +873,7 @@ class CodeAgent(MultiStepAgent):
|
|||
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
||||
)
|
||||
|
||||
all_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||
all_tools = {**self.tools, **self.managed_agents}
|
||||
if use_e2b_executor:
|
||||
self.python_executor = E2BExecutor(
|
||||
self.additional_authorized_imports, list(all_tools.values())
|
||||
|
@ -941,10 +958,10 @@ class CodeAgent(MultiStepAgent):
|
|||
lexer="python",
|
||||
theme="monokai",
|
||||
word_wrap=True,
|
||||
line_numbers=True,
|
||||
),
|
||||
title="[bold]Executing this code:",
|
||||
title_align="left",
|
||||
box=box.HORIZONTALS,
|
||||
)
|
||||
)
|
||||
observation = ""
|
||||
|
@ -1045,5 +1062,4 @@ __all__ = [
|
|||
"MultiStepAgent",
|
||||
"CodeAgent",
|
||||
"ToolCallingAgent",
|
||||
"Toolbox",
|
||||
]
|
||||
|
|
|
@ -322,6 +322,15 @@ class SpeechToTextTool(PipelineTool):
|
|||
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
TOOL_MAPPING = {
|
||||
tool_class.name: tool_class
|
||||
for tool_class in [
|
||||
PythonInterpreterTool,
|
||||
DuckDuckGoSearchTool,
|
||||
VisitWebpageTool,
|
||||
]
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"PythonInterpreterTool",
|
||||
"FinalAnswerTool",
|
||||
|
|
|
@ -157,6 +157,14 @@ class Model:
|
|||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_tool_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
available_tools: List[Tool],
|
||||
stop_sequences,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
|
|
@ -25,7 +25,7 @@ import tempfile
|
|||
import textwrap
|
||||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||
from typing import Callable, Dict, Optional, Union, get_type_hints
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
|
@ -85,18 +85,6 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
|||
return "space"
|
||||
|
||||
|
||||
def setup_default_tools():
|
||||
default_tools = {}
|
||||
main_module = importlib.import_module("smolagents")
|
||||
|
||||
for task_name, tool_class_name in TOOL_MAPPING.items():
|
||||
tool_class = getattr(main_module, tool_class_name)
|
||||
tool_instance = tool_class()
|
||||
default_tools[tool_class.name] = tool_instance
|
||||
|
||||
return default_tools
|
||||
|
||||
|
||||
def validate_after_init(cls):
|
||||
original_init = cls.__init__
|
||||
|
||||
|
@ -727,10 +715,10 @@ def get_tool_description_with_args(
|
|||
if description_template is None:
|
||||
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
compiled_template = compile_jinja_template(description_template)
|
||||
rendered = compiled_template.render(
|
||||
tool_description = compiled_template.render(
|
||||
tool=tool,
|
||||
)
|
||||
return rendered
|
||||
return tool_description
|
||||
|
||||
|
||||
@lru_cache
|
||||
|
@ -806,13 +794,6 @@ def launch_gradio_demo(tool: Tool):
|
|||
).launch()
|
||||
|
||||
|
||||
TOOL_MAPPING = {
|
||||
"python_interpreter": "PythonInterpreterTool",
|
||||
"web_search": "DuckDuckGoSearchTool",
|
||||
"transcriber": "SpeechToTextTool",
|
||||
}
|
||||
|
||||
|
||||
def load_tool(
|
||||
task_or_repo_id,
|
||||
model_repo_id: Optional[str] = None,
|
||||
|
@ -821,7 +802,7 @@ def load_tool(
|
|||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
|
||||
Main function to quickly load a tool from the Hub.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
@ -854,13 +835,6 @@ def load_tool(
|
|||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
||||
will be passed along to its init.
|
||||
"""
|
||||
if task_or_repo_id in TOOL_MAPPING:
|
||||
tool_class_name = TOOL_MAPPING[task_or_repo_id]
|
||||
main_module = importlib.import_module("smolagents")
|
||||
tools_module = main_module
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
return tool_class(token=token, **kwargs)
|
||||
else:
|
||||
return Tool.from_hub(
|
||||
task_or_repo_id,
|
||||
model_repo_id=model_repo_id,
|
||||
|
@ -961,107 +935,6 @@ def tool(tool_function: Callable) -> Tool:
|
|||
return simple_tool
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
||||
|
||||
|
||||
class Toolbox:
|
||||
"""
|
||||
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
||||
manage them.
|
||||
|
||||
Args:
|
||||
tools (`List[Tool]`):
|
||||
The list of tools to instantiate the toolbox with
|
||||
add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to add the tools available within `transformers` to the toolbox.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: List[Tool], add_base_tools: bool = False):
|
||||
self._tools = {tool.name: tool for tool in tools}
|
||||
if add_base_tools:
|
||||
self.add_base_tools()
|
||||
|
||||
def add_base_tools(self, add_python_interpreter: bool = False):
|
||||
global HUGGINGFACE_DEFAULT_TOOLS
|
||||
if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
|
||||
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
|
||||
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
||||
if tool.name != "python_interpreter" or add_python_interpreter:
|
||||
self.add_tool(tool)
|
||||
|
||||
@property
|
||||
def tools(self) -> Dict[str, Tool]:
|
||||
"""Get all tools currently in the toolbox"""
|
||||
return self._tools
|
||||
|
||||
def show_tool_descriptions(
|
||||
self, tool_description_template: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Returns the description of all tools in the toolbox
|
||||
|
||||
Args:
|
||||
tool_description_template (`str`, *optional*):
|
||||
The template to use to describe the tools. If not provided, the default template will be used.
|
||||
"""
|
||||
return "\n".join(
|
||||
[
|
||||
get_tool_description_with_args(tool, tool_description_template)
|
||||
for tool in self._tools.values()
|
||||
]
|
||||
)
|
||||
|
||||
def add_tool(self, tool: Tool):
|
||||
"""
|
||||
Adds a tool to the toolbox
|
||||
|
||||
Args:
|
||||
tool (`Tool`):
|
||||
The tool to add to the toolbox.
|
||||
"""
|
||||
if tool.name in self._tools:
|
||||
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def remove_tool(self, tool_name: str):
|
||||
"""
|
||||
Removes a tool from the toolbox
|
||||
|
||||
Args:
|
||||
tool_name (`str`):
|
||||
The tool to remove from the toolbox.
|
||||
"""
|
||||
if tool_name not in self._tools:
|
||||
raise KeyError(
|
||||
f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
|
||||
)
|
||||
del self._tools[tool_name]
|
||||
|
||||
def update_tool(self, tool: Tool):
|
||||
"""
|
||||
Updates a tool in the toolbox according to its name.
|
||||
|
||||
Args:
|
||||
tool (`Tool`):
|
||||
The tool to update to the toolbox.
|
||||
"""
|
||||
if tool.name not in self._tools:
|
||||
raise KeyError(
|
||||
f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
|
||||
)
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def clear_toolbox(self):
|
||||
"""Clears the toolbox"""
|
||||
self._tools = {}
|
||||
|
||||
def __repr__(self):
|
||||
toolbox_description = "Toolbox contents:\n"
|
||||
for tool in self._tools.values():
|
||||
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
||||
return toolbox_description
|
||||
|
||||
|
||||
class PipelineTool(Tool):
|
||||
"""
|
||||
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
|
||||
|
@ -1234,6 +1107,5 @@ __all__ = [
|
|||
"tool",
|
||||
"load_tool",
|
||||
"launch_gradio_demo",
|
||||
"Toolbox",
|
||||
"ToolCollection",
|
||||
]
|
||||
|
|
|
@ -18,14 +18,12 @@ import unittest
|
|||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents.agents import (
|
||||
AgentMaxStepsError,
|
||||
CodeAgent,
|
||||
ManagedAgent,
|
||||
Toolbox,
|
||||
ToolCall,
|
||||
ToolCallingAgent,
|
||||
)
|
||||
|
@ -289,37 +287,35 @@ class AgentTests(unittest.TestCase):
|
|||
assert len(agent.logs) == 8
|
||||
assert type(agent.logs[-1].error) is AgentMaxStepsError
|
||||
|
||||
def test_tool_descriptions_get_baked_in_system_prompt(self):
|
||||
tool = PythonInterpreterTool()
|
||||
tool.name = "fake_tool_name"
|
||||
tool.description = "fake_tool_description"
|
||||
agent = CodeAgent(tools=[tool], model=fake_code_model)
|
||||
agent.run("Empty task")
|
||||
assert tool.name in agent.system_prompt
|
||||
assert tool.description in agent.system_prompt
|
||||
|
||||
def test_init_agent_with_different_toolsets(self):
|
||||
toolset_1 = []
|
||||
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 1
|
||||
len(agent.tools) == 1
|
||||
) # when no tools are provided, only the final_answer tool is added by default
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 2
|
||||
len(agent.tools) == 2
|
||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||
|
||||
toolset_3 = Toolbox(toolset_2)
|
||||
agent = CodeAgent(tools=toolset_3, model=fake_code_model)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 2
|
||||
) # same as previous one, where toolset_3 is an instantiation of previous one
|
||||
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ToolCallingAgent(
|
||||
tools=toolset_3, model=FakeToolCallModel(), add_base_tools=True
|
||||
)
|
||||
assert "already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
# check that python_interpreter base tool does not get added to CodeAgent
|
||||
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 3
|
||||
) # added final_answer tool + search + transcribe
|
||||
assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage
|
||||
|
||||
# check that python_interpreter base tool gets added to ToolCallingAgent
|
||||
agent = ToolCallingAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
||||
assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage
|
||||
|
||||
def test_function_persistence_across_steps(self):
|
||||
agent = CodeAgent(
|
||||
|
|
|
@ -18,8 +18,7 @@ import unittest
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from smolagents import load_tool
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
|
||||
from smolagents.local_python_executor import (
|
||||
InterpreterError,
|
||||
evaluate_python_code,
|
||||
|
@ -37,7 +36,7 @@ def add_two(x):
|
|||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
||||
self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
|
|
@ -15,14 +15,14 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from smolagents import load_tool
|
||||
from smolagents import DuckDuckGoSearchTool
|
||||
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("web_search")
|
||||
self.tool = DuckDuckGoSearchTool()
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
|
Loading…
Reference in New Issue