Support OpenTelemetry (#136)

* Support OpenTelemetry
This commit is contained in:
Aymeric Roucher 2025-01-09 23:08:17 +01:00 committed by GitHub
parent cf04285cc1
commit cb9830a554
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 118 additions and 227 deletions

View File

@ -39,10 +39,6 @@ contains the API docs for the underlying classes.
[[autodoc]] Tool [[autodoc]] Tool
### Toolbox
[[autodoc]] Toolbox
### launch_gradio_demo ### launch_gradio_demo
[[autodoc]] launch_gradio_demo [[autodoc]] launch_gradio_demo

View File

@ -187,7 +187,7 @@ from smolagents import HfApiModel
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct") model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[], model=model, add_base_tools=True) agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.toolbox.add_tool(model_download_tool) agent.tools.append(model_download_tool)
``` ```
Now we can leverage the new tool: Now we can leverage the new tool:
@ -202,11 +202,6 @@ agent.run(
> Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines. > Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines.
Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
### Use a collection of tools ### Use a collection of tools
You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use. You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.

View File

@ -18,13 +18,14 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from rich import box
from rich.console import Group from rich.console import Group
from rich.panel import Panel from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.syntax import Syntax from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
from .default_tools import FinalAnswerTool from .default_tools import FinalAnswerTool, TOOL_MAPPING
from .e2b_executor import E2BExecutor from .e2b_executor import E2BExecutor
from .local_python_executor import ( from .local_python_executor import (
BASE_BUILTIN_MODULES, BASE_BUILTIN_MODULES,
@ -49,7 +50,6 @@ from .prompts import (
from .tools import ( from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool, Tool,
Toolbox,
get_tool_description_with_args, get_tool_description_with_args,
) )
from .types import AgentAudio, AgentImage, handle_agent_output_types from .types import AgentAudio, AgentImage, handle_agent_output_types
@ -107,18 +107,27 @@ class SystemPromptStep(AgentStep):
system_prompt: str system_prompt: str
def format_prompt_with_tools( def get_tool_descriptions(
toolbox: Toolbox, prompt_template: str, tool_description_template: str tools: Dict[str, Tool], tool_description_template: str
) -> str: ) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) return "\n".join(
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) [
get_tool_description_with_args(tool, tool_description_template)
for tool in tools.values()
]
)
def format_prompt_with_tools(
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
) -> str:
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "{{tool_names}}" in prompt: if "{{tool_names}}" in prompt:
prompt = prompt.replace( prompt = prompt.replace(
"{{tool_names}}", "{{tool_names}}",
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]), ", ".join([f"'{tool.name}'" for tool in tools.values()]),
) )
return prompt return prompt
@ -163,7 +172,7 @@ class MultiStepAgent:
def __init__( def __init__(
self, self,
tools: Union[List[Tool], Toolbox], tools: List[Tool],
model: Callable[[List[Dict[str, str]]], str], model: Callable[[List[Dict[str, str]]], str],
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None, tool_description_template: Optional[str] = None,
@ -172,7 +181,7 @@ class MultiStepAgent:
add_base_tools: bool = False, add_base_tools: bool = False,
verbose: bool = False, verbose: bool = False,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[Dict] = None, managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None, step_callbacks: Optional[List[Callable]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
): ):
@ -196,17 +205,18 @@ class MultiStepAgent:
self.managed_agents = {} self.managed_agents = {}
if managed_agents is not None: if managed_agents is not None:
print("NOTNONE")
self.managed_agents = {agent.name: agent for agent in managed_agents} self.managed_agents = {agent.name: agent for agent in managed_agents}
if isinstance(tools, Toolbox): self.tools = {tool.name: tool for tool in tools}
self._toolbox = tools if add_base_tools:
if add_base_tools: for tool_name, tool_class in TOOL_MAPPING.items():
self._toolbox.add_base_tools( if (
add_python_interpreter=(self.__class__ == ToolCallingAgent) tool_name != "python_interpreter"
) or self.__class__.__name__ == "ToolCallingAgent"
else: ):
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) self.tools[tool_name] = tool_class()
self._toolbox.add_tool(FinalAnswerTool()) self.tools["final_answer"] = FinalAnswerTool()
self.system_prompt = self.initialize_system_prompt() self.system_prompt = self.initialize_system_prompt()
self.input_messages = None self.input_messages = None
@ -217,14 +227,9 @@ class MultiStepAgent:
self.step_callbacks = step_callbacks if step_callbacks is not None else [] self.step_callbacks = step_callbacks if step_callbacks is not None else []
self.step_callbacks.append(self.monitor.update_metrics) self.step_callbacks.append(self.monitor.update_metrics)
@property
def toolbox(self) -> Toolbox:
"""Get the toolbox currently available to the agent"""
return self._toolbox
def initialize_system_prompt(self): def initialize_system_prompt(self):
self.system_prompt = format_prompt_with_tools( self.system_prompt = format_prompt_with_tools(
self._toolbox, self.tools,
self.system_prompt_template, self.system_prompt_template,
self.tool_description_template, self.tool_description_template,
) )
@ -384,10 +389,10 @@ class MultiStepAgent:
This method replaces arguments with the actual values from the state if they refer to state variables. This method replaces arguments with the actual values from the state if they refer to state variables.
Args: Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). tool_name (`str`): Name of the Tool to execute (should be one from self.tools).
arguments (Dict[str, str]): Arguments passed to the Tool. arguments (Dict[str, str]): Arguments passed to the Tool.
""" """
available_tools = {**self.toolbox.tools, **self.managed_agents} available_tools = {**self.tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
@ -415,7 +420,7 @@ class MultiStepAgent:
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
return observation return observation
except Exception as e: except Exception as e:
if tool_name in self.toolbox.tools: if tool_name in self.tools:
tool_description = get_tool_description_with_args( tool_description = get_tool_description_with_args(
available_tools[tool_name] available_tools[tool_name]
) )
@ -512,20 +517,26 @@ You have been provided with these additional arguments, that you can access usin
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method. Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
""" """
final_answer = None final_answer = None
step_number = 0 self.step_number = 0
while final_answer is None and step_number < self.max_steps: while final_answer is None and self.step_number < self.max_steps:
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(step=step_number, start_time=step_start_time) step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try: try:
if ( if (
self.planning_interval is not None self.planning_interval is not None
and step_number % self.planning_interval == 0 and self.step_number % self.planning_interval == 0
): ):
self.planning_step( self.planning_step(
task, is_first_step=(step_number == 0), step=step_number task,
is_first_step=(self.step_number == 0),
step=self.step_number,
) )
console.print( console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX) Rule(
f"[bold]Step {self.step_number}",
characters="",
style=YELLOW_HEX,
)
) )
# Run one step! # Run one step!
@ -538,10 +549,10 @@ You have been provided with these additional arguments, that you can access usin
self.logs.append(step_log) self.logs.append(step_log)
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(step_log) callback(step_log)
step_number += 1 self.step_number += 1
yield step_log yield step_log
if final_answer is None and step_number == self.max_steps: if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps." error_message = "Reached max steps."
final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log) self.logs.append(final_step_log)
@ -561,20 +572,26 @@ You have been provided with these additional arguments, that you can access usin
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method. Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
""" """
final_answer = None final_answer = None
step_number = 0 self.step_number = 0
while final_answer is None and step_number < self.max_steps: while final_answer is None and self.step_number < self.max_steps:
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(step=step_number, start_time=step_start_time) step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try: try:
if ( if (
self.planning_interval is not None self.planning_interval is not None
and step_number % self.planning_interval == 0 and self.step_number % self.planning_interval == 0
): ):
self.planning_step( self.planning_step(
task, is_first_step=(step_number == 0), step=step_number task,
is_first_step=(self.step_number == 0),
step=self.step_number,
) )
console.print( console.print(
Rule(f"[bold]Step {step_number}", characters="", style=YELLOW_HEX) Rule(
f"[bold]Step {self.step_number}",
characters="",
style=YELLOW_HEX,
)
) )
# Run one step! # Run one step!
@ -589,9 +606,9 @@ You have been provided with these additional arguments, that you can access usin
self.logs.append(step_log) self.logs.append(step_log)
for callback in self.step_callbacks: for callback in self.step_callbacks:
callback(step_log) callback(step_log)
step_number += 1 self.step_number += 1
if final_answer is None and step_number == self.max_steps: if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps." error_message = "Reached max steps."
final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log) self.logs.append(final_step_log)
@ -637,8 +654,8 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format( "content": USER_PROMPT_PLAN.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions( tool_descriptions=get_tool_descriptions(
self.tool_description_template self.tools, self.tool_description_template
), ),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) show_agents_descriptions(self.managed_agents)
@ -692,8 +709,8 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format( "content": USER_PROMPT_PLAN_UPDATE.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions( tool_descriptions=get_tool_descriptions(
self.tool_description_template self.tools, self.tool_description_template
), ),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) show_agents_descriptions(self.managed_agents)
@ -761,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent):
try: try:
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call( tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
self.input_messages, self.input_messages,
available_tools=list(self.toolbox._tools.values()), available_tools=list(self.tools.values()),
stop_sequences=["Observation:"], stop_sequences=["Observation:"],
) )
except Exception as e: except Exception as e:
@ -856,7 +873,7 @@ class CodeAgent(MultiStepAgent):
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution." f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
) )
all_tools = {**self.toolbox.tools, **self.managed_agents} all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor: if use_e2b_executor:
self.python_executor = E2BExecutor( self.python_executor = E2BExecutor(
self.additional_authorized_imports, list(all_tools.values()) self.additional_authorized_imports, list(all_tools.values())
@ -941,10 +958,10 @@ class CodeAgent(MultiStepAgent):
lexer="python", lexer="python",
theme="monokai", theme="monokai",
word_wrap=True, word_wrap=True,
line_numbers=True,
), ),
title="[bold]Executing this code:", title="[bold]Executing this code:",
title_align="left", title_align="left",
box=box.HORIZONTALS,
) )
) )
observation = "" observation = ""
@ -1045,5 +1062,4 @@ __all__ = [
"MultiStepAgent", "MultiStepAgent",
"CodeAgent", "CodeAgent",
"ToolCallingAgent", "ToolCallingAgent",
"Toolbox",
] ]

View File

@ -322,6 +322,15 @@ class SpeechToTextTool(PipelineTool):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0] return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
TOOL_MAPPING = {
tool_class.name: tool_class
for tool_class in [
PythonInterpreterTool,
DuckDuckGoSearchTool,
VisitWebpageTool,
]
}
__all__ = [ __all__ = [
"PythonInterpreterTool", "PythonInterpreterTool",
"FinalAnswerTool", "FinalAnswerTool",

View File

@ -157,6 +157,14 @@ class Model:
): ):
raise NotImplementedError raise NotImplementedError
def get_tool_call(
self,
messages: List[Dict[str, str]],
available_tools: List[Tool],
stop_sequences,
):
raise NotImplementedError
def __call__( def __call__(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],

View File

@ -25,7 +25,7 @@ import tempfile
import textwrap import textwrap
from functools import lru_cache, wraps from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union, get_type_hints from typing import Callable, Dict, Optional, Union, get_type_hints
import torch import torch
from huggingface_hub import ( from huggingface_hub import (
@ -85,18 +85,6 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
return "space" return "space"
def setup_default_tools():
default_tools = {}
main_module = importlib.import_module("smolagents")
for task_name, tool_class_name in TOOL_MAPPING.items():
tool_class = getattr(main_module, tool_class_name)
tool_instance = tool_class()
default_tools[tool_class.name] = tool_instance
return default_tools
def validate_after_init(cls): def validate_after_init(cls):
original_init = cls.__init__ original_init = cls.__init__
@ -727,10 +715,10 @@ def get_tool_description_with_args(
if description_template is None: if description_template is None:
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
compiled_template = compile_jinja_template(description_template) compiled_template = compile_jinja_template(description_template)
rendered = compiled_template.render( tool_description = compiled_template.render(
tool=tool, tool=tool,
) )
return rendered return tool_description
@lru_cache @lru_cache
@ -806,13 +794,6 @@ def launch_gradio_demo(tool: Tool):
).launch() ).launch()
TOOL_MAPPING = {
"python_interpreter": "PythonInterpreterTool",
"web_search": "DuckDuckGoSearchTool",
"transcriber": "SpeechToTextTool",
}
def load_tool( def load_tool(
task_or_repo_id, task_or_repo_id,
model_repo_id: Optional[str] = None, model_repo_id: Optional[str] = None,
@ -821,7 +802,7 @@ def load_tool(
**kwargs, **kwargs,
): ):
""" """
Main function to quickly load a tool, be it on the Hub or in the Transformers library. Main function to quickly load a tool from the Hub.
<Tip warning={true}> <Tip warning={true}>
@ -854,20 +835,13 @@ def load_tool(
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
will be passed along to its init. will be passed along to its init.
""" """
if task_or_repo_id in TOOL_MAPPING: return Tool.from_hub(
tool_class_name = TOOL_MAPPING[task_or_repo_id] task_or_repo_id,
main_module = importlib.import_module("smolagents") model_repo_id=model_repo_id,
tools_module = main_module token=token,
tool_class = getattr(tools_module, tool_class_name) trust_remote_code=trust_remote_code,
return tool_class(token=token, **kwargs) **kwargs,
else: )
return Tool.from_hub(
task_or_repo_id,
model_repo_id=model_repo_id,
token=token,
trust_remote_code=trust_remote_code,
**kwargs,
)
def add_description(description): def add_description(description):
@ -961,107 +935,6 @@ def tool(tool_function: Callable) -> Tool:
return simple_tool return simple_tool
HUGGINGFACE_DEFAULT_TOOLS = {}
class Toolbox:
"""
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
manage them.
Args:
tools (`List[Tool]`):
The list of tools to instantiate the toolbox with
add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to add the tools available within `transformers` to the toolbox.
"""
def __init__(self, tools: List[Tool], add_base_tools: bool = False):
self._tools = {tool.name: tool for tool in tools}
if add_base_tools:
self.add_base_tools()
def add_base_tools(self, add_python_interpreter: bool = False):
global HUGGINGFACE_DEFAULT_TOOLS
if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
if tool.name != "python_interpreter" or add_python_interpreter:
self.add_tool(tool)
@property
def tools(self) -> Dict[str, Tool]:
"""Get all tools currently in the toolbox"""
return self._tools
def show_tool_descriptions(
self, tool_description_template: Optional[str] = None
) -> str:
"""
Returns the description of all tools in the toolbox
Args:
tool_description_template (`str`, *optional*):
The template to use to describe the tools. If not provided, the default template will be used.
"""
return "\n".join(
[
get_tool_description_with_args(tool, tool_description_template)
for tool in self._tools.values()
]
)
def add_tool(self, tool: Tool):
"""
Adds a tool to the toolbox
Args:
tool (`Tool`):
The tool to add to the toolbox.
"""
if tool.name in self._tools:
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
self._tools[tool.name] = tool
def remove_tool(self, tool_name: str):
"""
Removes a tool from the toolbox
Args:
tool_name (`str`):
The tool to remove from the toolbox.
"""
if tool_name not in self._tools:
raise KeyError(
f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
)
del self._tools[tool_name]
def update_tool(self, tool: Tool):
"""
Updates a tool in the toolbox according to its name.
Args:
tool (`Tool`):
The tool to update to the toolbox.
"""
if tool.name not in self._tools:
raise KeyError(
f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
)
self._tools[tool.name] = tool
def clear_toolbox(self):
"""Clears the toolbox"""
self._tools = {}
def __repr__(self):
toolbox_description = "Toolbox contents:\n"
for tool in self._tools.values():
toolbox_description += f"\t{tool.name}: {tool.description}\n"
return toolbox_description
class PipelineTool(Tool): class PipelineTool(Tool):
""" """
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
@ -1234,6 +1107,5 @@ __all__ = [
"tool", "tool",
"load_tool", "load_tool",
"launch_gradio_demo", "launch_gradio_demo",
"Toolbox",
"ToolCollection", "ToolCollection",
] ]

View File

@ -18,14 +18,12 @@ import unittest
import uuid import uuid
from pathlib import Path from pathlib import Path
import pytest
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
from smolagents.agents import ( from smolagents.agents import (
AgentMaxStepsError, AgentMaxStepsError,
CodeAgent, CodeAgent,
ManagedAgent, ManagedAgent,
Toolbox,
ToolCall, ToolCall,
ToolCallingAgent, ToolCallingAgent,
) )
@ -289,37 +287,35 @@ class AgentTests(unittest.TestCase):
assert len(agent.logs) == 8 assert len(agent.logs) == 8
assert type(agent.logs[-1].error) is AgentMaxStepsError assert type(agent.logs[-1].error) is AgentMaxStepsError
def test_tool_descriptions_get_baked_in_system_prompt(self):
tool = PythonInterpreterTool()
tool.name = "fake_tool_name"
tool.description = "fake_tool_description"
agent = CodeAgent(tools=[tool], model=fake_code_model)
agent.run("Empty task")
assert tool.name in agent.system_prompt
assert tool.description in agent.system_prompt
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
toolset_1 = [] toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model) agent = CodeAgent(tools=toolset_1, model=fake_code_model)
assert ( assert (
len(agent.toolbox.tools) == 1 len(agent.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default ) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, model=fake_code_model) agent = CodeAgent(tools=toolset_2, model=fake_code_model)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2) # check that python_interpreter base tool does not get added to CodeAgent
agent = CodeAgent(tools=toolset_3, model=fake_code_model)
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ToolCallingAgent(
tools=toolset_3, model=FakeToolCallModel(), add_base_tools=True
)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True) agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
assert ( assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage
len(agent.toolbox.tools) == 3
) # added final_answer tool + search + transcribe # check that python_interpreter base tool gets added to ToolCallingAgent
agent = ToolCallingAgent(tools=[], model=fake_code_model, add_base_tools=True)
assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage
def test_function_persistence_across_steps(self): def test_function_persistence_across_steps(self):
agent = CodeAgent( agent = CodeAgent(

View File

@ -18,8 +18,7 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
from smolagents import load_tool from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import ( from smolagents.local_python_executor import (
InterpreterError, InterpreterError,
evaluate_python_code, evaluate_python_code,
@ -37,7 +36,7 @@ def add_two(x):
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):

View File

@ -15,14 +15,14 @@
import unittest import unittest
from smolagents import load_tool from smolagents import DuckDuckGoSearchTool
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("web_search") self.tool = DuckDuckGoSearchTool()
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):