diff --git a/examples/open_deep_research/scripts/mdconvert.py b/examples/open_deep_research/scripts/mdconvert.py
index 72cb0a0..68f13a2 100644
--- a/examples/open_deep_research/scripts/mdconvert.py
+++ b/examples/open_deep_research/scripts/mdconvert.py
@@ -567,13 +567,13 @@ class WavConverter(MediaConverter):
class Mp3Converter(WavConverter):
"""
- Converts MP3 files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
+ Converts MP3 and M4A files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
"""
def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
# Bail if not a MP3
extension = kwargs.get("file_extension", "")
- if extension.lower() != ".mp3":
+ if extension.lower() not in [".mp3", ".m4a"]:
return None
md_content = ""
@@ -600,7 +600,10 @@ class Mp3Converter(WavConverter):
handle, temp_path = tempfile.mkstemp(suffix=".wav")
os.close(handle)
try:
- sound = pydub.AudioSegment.from_mp3(local_path)
+ if extension.lower() == ".mp3":
+ sound = pydub.AudioSegment.from_mp3(local_path)
+ else:
+ sound = pydub.AudioSegment.from_file(local_path, format="m4a")
sound.export(temp_path, format="wav")
_args = dict()
diff --git a/examples/open_deep_research/scripts/text_inspector_tool.py b/examples/open_deep_research/scripts/text_inspector_tool.py
index 09e7c11..056168c 100644
--- a/examples/open_deep_research/scripts/text_inspector_tool.py
+++ b/examples/open_deep_research/scripts/text_inspector_tool.py
@@ -10,7 +10,7 @@ class TextInspectorTool(Tool):
name = "inspect_file_as_text"
description = """
You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
-This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""
+This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""
inputs = {
"file_path": {
diff --git a/examples/open_deep_research/scripts/text_web_browser.py b/examples/open_deep_research/scripts/text_web_browser.py
index 18763c4..ef40f85 100644
--- a/examples/open_deep_research/scripts/text_web_browser.py
+++ b/examples/open_deep_research/scripts/text_web_browser.py
@@ -410,7 +410,7 @@ class VisitTool(Tool):
class DownloadTool(Tool):
name = "download_file"
description = """
-Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
+Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".png", ".docx"]
After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}}
diff --git a/src/smolagents/_function_type_hints_utils.py b/src/smolagents/_function_type_hints_utils.py
index 5eb9502..13c6a25 100644
--- a/src/smolagents/_function_type_hints_utils.py
+++ b/src/smolagents/_function_type_hints_utils.py
@@ -24,7 +24,6 @@ TODO: move them to `huggingface_hub` to avoid code duplication.
import inspect
import json
-import os
import re
import types
from copy import copy
@@ -46,34 +45,31 @@ from huggingface_hub.utils import is_torch_available
from .utils import _is_pillow_available
-def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
+def get_imports(code: str) -> List[str]:
"""
- Extracts all the libraries (not relative imports this time) that are imported in a file.
+    Extracts all the libraries (not relative imports) that are imported in the given code.
Args:
- filename (`str` or `os.PathLike`): The module file to inspect.
+ code (`str`): Code text to inspect.
Returns:
- `List[str]`: The list of all packages required to use the input module.
+ `list[str]`: List of all packages required to use the input code.
"""
- with open(filename, "r", encoding="utf-8") as f:
- content = f.read()
-
# filter out try/except block so in custom code we can have try/except imports
- content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
+ code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL)
# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
- content = re.sub(
+ code = re.sub(
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
"",
- content,
+ code,
flags=re.MULTILINE,
)
- # Imports of the form `import xxx`
- imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `import xxx` or `import xxx as yyy`
+ imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
- imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 753ae18..1ce54d1 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -14,44 +14,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
-
-import importlib.resources
+import importlib
import inspect
+import json
+import os
import re
+import tempfile
import textwrap
import time
from collections import deque
from logging import getLogger
+from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
+import jinja2
import yaml
+from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
from jinja2 import StrictUndefined, Template
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
-from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types
-from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
-from smolagents.monitoring import (
- YELLOW_HEX,
- AgentLogger,
- LogLevel,
-)
-from smolagents.utils import (
- AgentError,
- AgentExecutionError,
- AgentGenerationError,
- AgentMaxStepsError,
- AgentParsingError,
- parse_code_blobs,
- parse_json_tool_call,
- truncate_content,
-)
-
-from .agent_types import AgentType
+from .agent_types import AgentAudio, AgentImage, AgentType, handle_agent_output_types
from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import (
@@ -59,12 +44,30 @@ from .local_python_executor import (
LocalPythonInterpreter,
fix_final_answer_code,
)
+from .memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from .models import (
ChatMessage,
MessageRole,
+ Model,
+)
+from .monitoring import (
+ YELLOW_HEX,
+ AgentLogger,
+ LogLevel,
+ Monitor,
)
-from .monitoring import Monitor
from .tools import Tool
+from .utils import (
+ AgentError,
+ AgentExecutionError,
+ AgentGenerationError,
+ AgentMaxStepsError,
+ AgentParsingError,
+ make_init_file,
+ parse_code_blobs,
+ parse_json_tool_call,
+ truncate_content,
+)
logger = getLogger(__name__)
@@ -228,9 +231,19 @@ class MultiStepAgent:
)
self.managed_agents = {agent.name: agent for agent in managed_agents}
+ tool_and_managed_agent_names = [tool.name for tool in tools]
+ if managed_agents is not None:
+ tool_and_managed_agent_names += [agent.name for agent in managed_agents]
+ if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
+ raise ValueError(
+ "Each tool or managed_agent should have a unique name! You passed these duplicate names: "
+ f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
+ )
+
for tool in tools:
assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}"
self.tools = {tool.name: tool for tool in tools}
+
if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items():
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
@@ -709,6 +722,310 @@ You have been provided with these additional arguments, that you can access usin
answer += "\n"
return answer
+ def save(self, output_dir: str, relative_path: Optional[str] = None):
+ """
+ Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate:
+
+ - a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
+ - a `managed_agents` folder containing the logic for each of the managed agents.
+ - an `agent.json` file containing a dictionary representing your agent.
+        - a `prompts.yaml` file containing the prompt templates used by your agent.
+ - an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
+        - a `requirements.txt` containing the names of the modules required by your agent: the requirements of its
+          tools and managed agents, plus its authorized imports that are not base builtin modules
+
+ Args:
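+            output_dir (`str`): The folder in which you want to save your agent.
+            relative_path (`str`, *optional*): Dotted module prefix prepended to the `tools` and `managed_agents`
+                imports generated in `app.py`. It is set when a managed agent is saved inside its parent's folder.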
+ """
+ make_init_file(output_dir)
+
+ # Recursively save managed agents
+ if self.managed_agents:
+ make_init_file(os.path.join(output_dir, "managed_agents"))
+ for agent_name, agent in self.managed_agents.items():
+ agent_suffix = f"managed_agents.{agent_name}"
+ if relative_path:
+ agent_suffix = relative_path + "." + agent_suffix
+ agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)
+
+ class_name = self.__class__.__name__
+
+ # Save tools to different .py files
+ for tool in self.tools.values():
+ make_init_file(os.path.join(output_dir, "tools"))
+ tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)
+
+ # Save prompts to yaml
+ yaml_prompts = yaml.safe_dump(
+ self.prompt_templates,
+ default_style="|", # This forces block literals for all strings
+ default_flow_style=False,
+ width=float("inf"),
+ sort_keys=False,
+ allow_unicode=True,
+ indent=2,
+ )
+
+ with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
+ f.write(yaml_prompts)
+
+ # Save agent dictionary to json
+ agent_dict = self.to_dict()
+ agent_dict["tools"] = [tool.name for tool in self.tools.values()]
+ with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
+ json.dump(agent_dict, f, indent=4)
+
+ # Save requirements
+ with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
+ f.writelines(f"{r}\n" for r in agent_dict["requirements"])
+
+        # Make app.py file with Gradio UI
+ agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
+ managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
+ app_template = textwrap.dedent("""
+ import yaml
+ import os
+ from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}
+
+ # Get current directory path
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+ {% for tool in tools.values() -%}
+ from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
+ {% endfor %}
+ {% for managed_agent in managed_agents.values() -%}
+ from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
+ {% endfor %}
+
+ model = {{ agent_dict['model']['class'] }}(
+ {% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
+ {{ key }}={{ agent_dict['model']['data'][key]|repr }},
+ {% endfor %})
+
+ {% for tool in tools.values() -%}
+ {{ tool.name }} = {{ tool.name | camelcase }}()
+ {% endfor %}
+
+ with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
+ prompt_templates = yaml.safe_load(stream)
+
+ {{ agent_name }} = {{ class_name }}(
+ model=model,
+ tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
+ managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
+ {% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
+ {{ attribute_name }}={{ value|repr }},
+ {% endfor %}prompt_templates=prompt_templates
+ )
+ if __name__ == "__main__":
+ GradioUI({{ agent_name }}).launch()
+ """).strip()
+ template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined)
+ template_env.filters["repr"] = repr
+ template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_"))
+ template = template_env.from_string(app_template)
+
+ # Render the app.py file from Jinja2 template
+ app_text = template.render(
+ {
+ "agent_name": agent_name,
+ "class_name": class_name,
+ "agent_dict": agent_dict,
+ "tools": self.tools,
+ "managed_agents": self.managed_agents,
+ "managed_agent_relative_path": managed_agent_relative_path,
+ }
+ )
+
+ with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
+ f.write(app_text + "\n") # Append newline at the end
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Converts the agent into a dictionary describing its tools, model, managed agents and settings.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary with the keys `tools`, `model`, `managed_agents`, `prompt_templates`,
+            `max_steps`, `verbosity_level`, `grammar`, `planning_interval`, `name`, `description` and
+            `requirements`, plus `authorized_imports` when the agent defines that attribute.
+        """
+        # TODO: handle serializing step_callbacks and final_answer_checks
+        for attr in ["final_answer_checks", "step_callbacks"]:
+            if getattr(self, attr, None):
+                self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)
+
+        # Serialize each tool exactly once and reuse the result to collect the requirements.
+        tool_dicts = [tool.to_dict() for tool in self.tools.values()]
+        tool_requirements = {req for tool_dict in tool_dicts for req in tool_dict["requirements"]}
+        managed_agents_requirements = {
+            req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
+        }
+        requirements = tool_requirements | managed_agents_requirements
+        if hasattr(self, "authorized_imports"):
+            # Only the top-level package of an authorized import is kept as a requirement.
+            requirements.update(
+                {package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
+            )
+
+        agent_dict = {
+            "tools": tool_dicts,
+            "model": {
+                "class": self.model.__class__.__name__,
+                "data": self.model.to_dict(),
+            },
+            # Managed agents are referenced as a mapping from agent name to agent class name.
+            "managed_agents": {
+                managed_agent.name: managed_agent.__class__.__name__ for managed_agent in self.managed_agents.values()
+            },
+            "prompt_templates": self.prompt_templates,
+            "max_steps": self.max_steps,
+            "verbosity_level": int(self.logger.level),
+            "grammar": self.grammar,
+            "planning_interval": self.planning_interval,
+            "name": self.name,
+            "description": self.description,
+            "requirements": list(requirements),
+        }
+        if hasattr(self, "authorized_imports"):
+            agent_dict["authorized_imports"] = self.authorized_imports
+        return agent_dict
+
+ @classmethod
+ def from_hub(
+ cls,
+ repo_id: str,
+ token: Optional[str] = None,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ """
+ Loads an agent defined on the Hub.
+
+
+
+        Loading an agent from the Hub means that you'll download the agent's code and execute it locally.
+        ALWAYS inspect the agent you're downloading before loading it within your runtime, as you would do when
+        installing a package using pip/npm/apt.
+
+
+
+ Args:
+ repo_id (`str`):
+                The name of the repo on the Hub where your agent is defined.
+ token (`str`, *optional*):
+ The token to identify you on hf.co. If unset, will use the token generated when running
+ `huggingface-cli login` (stored in `~/.huggingface`).
+ trust_remote_code(`bool`, *optional*, defaults to False):
+                This flag marks that you understand the risk of running remote code and that you trust this agent.
+                If it is not set to True, loading the agent from the Hub will fail.
+ kwargs (additional keyword arguments, *optional*):
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
+ others will be passed along to its init.
+ """
+ if not trust_remote_code:
+ raise ValueError(
+ "Loading an agent from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
+ )
+
+ # Get the agent's Hub folder.
+ download_kwargs = {"token": token, "repo_type": "space"} | {
+ key: kwargs.pop(key)
+ for key in [
+ "cache_dir",
+ "force_download",
+ "resume_download",
+ "proxies",
+ "revision",
+ "subfolder",
+ "local_files_only",
+ ]
+ if key in kwargs
+ }
+
+ download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
+ return cls.from_folder(download_folder, **kwargs)
+
+    @classmethod
+    def from_folder(cls, folder: Union[str, Path], **kwargs):
+        """Loads an agent from a local folder.
+
+        Args:
+            folder (`str` or `Path`): Folder containing `agent.json`, a `tools` folder with one `{tool_name}.py`
+                file per tool listed in `agent.json`, and a `managed_agents` folder with one sub-folder per
+                managed agent listed in `agent.json`.
+            kwargs (additional keyword arguments, *optional*): Passed to the agent's init; they override the
+                values read from `agent.json`.
+
+        Returns:
+            An instance of `cls` built from the files found in `folder`.
+        """
+        folder = Path(folder)
+        # The files are written as UTF-8 when saving, so read them back as UTF-8 rather than with the
+        # locale's default encoding.
+        agent_dict = json.loads((folder / "agent.json").read_text(encoding="utf-8"))
+
+        # Recursively get managed agents
+        managed_agents = []
+        for managed_agent_name, managed_agent_class in agent_dict["managed_agents"].items():
+            agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class)
+            managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
+
+        tools = []
+        for tool_name in agent_dict["tools"]:
+            tool_code = (folder / "tools" / f"{tool_name}.py").read_text(encoding="utf-8")
+            tools.append(Tool.from_code(tool_code))
+
+        # The model class is looked up by name in `smolagents.models`.
+        model_class: Model = getattr(importlib.import_module("smolagents.models"), agent_dict["model"]["class"])
+        model = model_class.from_dict(agent_dict["model"]["data"])
+
+        args = dict(
+            model=model,
+            tools=tools,
+            managed_agents=managed_agents,
+            name=agent_dict["name"],
+            description=agent_dict["description"],
+            max_steps=agent_dict["max_steps"],
+            planning_interval=agent_dict["planning_interval"],
+            grammar=agent_dict["grammar"],
+            verbosity_level=agent_dict["verbosity_level"],
+        )
+        if cls.__name__ == "CodeAgent":
+            args["additional_authorized_imports"] = agent_dict["authorized_imports"]
+        # Explicit keyword arguments take precedence over the saved values.
+        args.update(kwargs)
+        return cls(**args)
+
+ def push_to_hub(
+ self,
+ repo_id: str,
+ commit_message: str = "Upload agent",
+ private: Optional[bool] = None,
+ token: Optional[Union[bool, str]] = None,
+ create_pr: bool = False,
+ ) -> str:
+ """
+ Upload the agent to the Hub.
+
+ Parameters:
+ repo_id (`str`):
+ The name of the repository you want to push to. It should contain your organization name when
+ pushing to a given organization.
+ commit_message (`str`, *optional*, defaults to `"Upload agent"`):
+ Message to commit while pushing.
+ private (`bool`, *optional*, defaults to `None`):
+ Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ token (`bool` or `str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ create_pr (`bool`, *optional*, defaults to `False`):
+ Whether to create a PR with the uploaded files or directly commit.
+ """
+ repo_url = create_repo(
+ repo_id=repo_id,
+ token=token,
+ private=private,
+ exist_ok=True,
+ repo_type="space",
+ space_sdk="gradio",
+ )
+ repo_id = repo_url.repo_id
+ metadata_update(
+ repo_id,
+ {"tags": ["smolagents", "agent"]},
+ repo_type="space",
+ token=token,
+ overwrite=True,
+ )
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ self.save(work_dir)
+ logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
+ return upload_folder(
+ repo_id=repo_id,
+ commit_message=commit_message,
+ folder_path=work_dir,
+ token=token,
+ create_pr=create_pr,
+ repo_type="space",
+ )
+
class ToolCallingAgent(MultiStepAgent):
"""
@@ -863,6 +1180,8 @@ class CodeAgent(MultiStepAgent):
):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
+ self.use_e2b_executor = use_e2b_executor
+ self.max_print_outputs_length = max_print_outputs_length
prompt_templates = prompt_templates or yaml.safe_load(
importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
)
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index f744251..4425f1e 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -141,7 +141,7 @@ def stream_to_gradio(
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
# Track tokens if model provides them
- if hasattr(agent.model, "last_input_token_count"):
+ if getattr(agent.model, "last_input_token_count", None):
total_input_tokens += agent.model.last_input_token_count
total_output_tokens += agent.model.last_output_token_count
if isinstance(step_log, ActionStep):
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 75dda78..b576846 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -331,6 +331,53 @@ class Model:
"""
pass # To be implemented in child classes!
+ def to_dict(self) -> Dict:
+ """
+ Converts the model into a JSON-compatible dictionary.
+ """
+ model_dictionary = {
+ **self.kwargs,
+ "last_input_token_count": self.last_input_token_count,
+ "last_output_token_count": self.last_output_token_count,
+ "model_id": self.model_id,
+ }
+ for attribute in [
+ "custom_role_conversion",
+ "temperature",
+ "max_tokens",
+ "provider",
+ "timeout",
+ "api_base",
+ "torch_dtype",
+ "device_map",
+ "organization",
+ "project",
+ "azure_endpoint",
+ ]:
+ if hasattr(self, attribute):
+ model_dictionary[attribute] = getattr(self, attribute)
+
+ dangerous_attributes = ["token", "api_key"]
+ for attribute_name in dangerous_attributes:
+ if hasattr(self, attribute_name):
+ print(
+ f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually."
+ )
+ return model_dictionary
+
+    @classmethod
+    def from_dict(cls, model_dictionary: Dict[str, Any]) -> "Model":
+        """
+        Builds a model from a dictionary such as the one returned by `to_dict`.
+
+        Args:
+            model_dictionary (`Dict[str, Any]`): Keyword arguments for the model's init, optionally with the
+                `last_input_token_count` and `last_output_token_count` entries. It is left unmodified.
+
+        Returns:
+            `Model`: The model instance, with both token counts set from the dictionary (`None` when absent).
+        """
+        # The token counts are filtered out of the init arguments and set on the instance afterwards.
+        model_instance = cls(
+            **{
+                k: v
+                for k, v in model_dictionary.items()
+                if k not in ["last_input_token_count", "last_output_token_count"]
+            }
+        )
+        # Use `get` rather than `pop` so that the caller's dictionary is not mutated.
+        model_instance.last_input_token_count = model_dictionary.get("last_input_token_count", None)
+        model_instance.last_output_token_count = model_dictionary.get("last_output_token_count", None)
+        return model_instance
+
class HfApiModel(Model):
"""A class to interact with Hugging Face's Inference API for language model interaction.
diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py
index 1f4d457..b8665c2 100644
--- a/src/smolagents/tool_validation.py
+++ b/src/smolagents/tool_validation.py
@@ -83,6 +83,31 @@ class MethodChecker(ast.NodeVisitor):
self.assigned_names.add(elt.id)
self.generic_visit(node)
+    def _handle_comprehension_generators(self, generators):
+        """Record every name bound by the `for` clauses of a comprehension as assigned.
+
+        Args:
+            generators: The `generators` attribute of a comprehension or generator expression node.
+        """
+        for generator in generators:
+            # Walk the whole target so that nested, list and starred unpacking (e.g. `for a, (b, *c) in ...`)
+            # is covered. Only names in `Store` context are bound by the loop; names that are merely read,
+            # such as `obj` in `for obj.attr in ...`, are skipped.
+            for target_node in ast.walk(generator.target):
+                if isinstance(target_node, ast.Name) and isinstance(target_node.ctx, ast.Store):
+                    self.assigned_names.add(target_node.id)
+
+    def visit_ListComp(self, node):
+        """Track variables in list comprehensions"""
+        self._handle_comprehension_generators(node.generators)
+        self.generic_visit(node)
+
+    def visit_DictComp(self, node):
+        """Track variables in dictionary comprehensions"""
+        self._handle_comprehension_generators(node.generators)
+        self.generic_visit(node)
+
+    def visit_SetComp(self, node):
+        """Track variables in set comprehensions"""
+        self._handle_comprehension_generators(node.generators)
+        self.generic_visit(node)
+
+    def visit_GeneratorExp(self, node):
+        """Track variables in generator expressions"""
+        self._handle_comprehension_generators(node.generators)
+        self.generic_visit(node)
+
def visit_Attribute(self, node):
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
@@ -121,7 +146,8 @@ class MethodChecker(ast.NodeVisitor):
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
"""
Validates that a Tool class follows the proper patterns:
- 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
+ 0. Any argument of __init__ should have a default.
+ Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
1. About the class:
- Class attributes should only be strings or dicts
- Class attributes cannot be complex attributes
@@ -140,13 +166,20 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
if not isinstance(tree.body[0], ast.ClassDef):
raise ValueError("Source code must define a class")
- # Check that __init__ method takes no arguments
+ # Check that __init__ method only has arguments with defaults
if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__)
- non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
- if len(non_self_params) > 0:
+ non_default_params = [
+ arg_name
+ for arg_name, param in sig.parameters.items()
+ if arg_name != "self"
+ and param.default == inspect.Parameter.empty
+ and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs
+ ]
+ if non_default_params:
errors.append(
- f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
+ f"This tool has required arguments in __init__: {non_default_params}. "
+ "All parameters of __init__ must have default values!"
)
class_node = tree.body[0]
@@ -198,5 +231,5 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
if errors:
- raise ValueError("Tool validation failed:\n" + "\n".join(errors))
+ raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
return
diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py
index d73fcce..15275e7 100644
--- a/src/smolagents/tools.py
+++ b/src/smolagents/tools.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
-import importlib
import inspect
import json
import logging
@@ -23,6 +22,7 @@ import os
import sys
import tempfile
import textwrap
+import types
from contextlib import contextmanager
from functools import wraps
from pathlib import Path
@@ -199,24 +199,9 @@ class Tool:
"""
self.is_initialized = True
- def save(self, output_dir):
- """
- Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
- tool in `output_dir` as well as autogenerate:
-
- - a `tool.py` file containing the logic for your tool.
- - an `app.py` file providing an UI for your tool when it is exported to a Space with `tool.push_to_hub()`
- - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
- code)
-
- Args:
- output_dir (`str`): The folder in which you want to save your tool.
- """
- os.makedirs(output_dir, exist_ok=True)
+ def to_dict(self) -> dict:
+ """Returns a dictionary representing the tool"""
class_name = self.__class__.__name__
- tool_file = os.path.join(output_dir, "tool.py")
-
- # Save tool file
if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained
source_code = get_source(self.forward).replace("@tool", "")
@@ -232,7 +217,7 @@ class Tool:
tool_code = textwrap.dedent(
f"""
from smolagents import Tool
- from typing import Optional
+ from typing import Any, Optional
class {class_name}(Tool):
name = "{self.name}"
@@ -272,33 +257,59 @@ class Tool:
validate_tool_attributes(self.__class__)
- tool_code = instance_to_source(self, base_cls=Tool)
+ tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool)
+
+ requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"}
+
+ return {"name": self.name, "code": tool_code, "requirements": requirements}
+
+ def save(self, output_dir: str, tool_file_name: str = "tool", make_gradio_app: bool = True):
+ """
+ Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
+ tool in `output_dir` as well as autogenerate:
+
+ - a `{tool_file_name}.py` file containing the logic for your tool.
+ If you pass `make_gradio_app=True`, this will also write:
+ - an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()`
+ - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
+ code)
+
+ Args:
+ output_dir (`str`): The folder in which you want to save your tool.
+ tool_file_name (`str`, *optional*): The file name in which you want to save your tool.
+ make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI.
+ """
+ os.makedirs(output_dir, exist_ok=True)
+ class_name = self.__class__.__name__
+ tool_file = os.path.join(output_dir, f"{tool_file_name}.py")
+
+ tool_dict = self.to_dict()
+ tool_code = tool_dict["code"]
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}"))
- # Save app file
- app_file = os.path.join(output_dir, "app.py")
- with open(app_file, "w", encoding="utf-8") as f:
- f.write(
- textwrap.dedent(
- f"""
- from smolagents import launch_gradio_demo
- from typing import Optional
- from tool import {class_name}
+ if make_gradio_app:
+ # Save app file
+ app_file = os.path.join(output_dir, "app.py")
+ with open(app_file, "w", encoding="utf-8") as f:
+ f.write(
+ textwrap.dedent(
+ f"""
+ from smolagents import launch_gradio_demo
+ from {tool_file_name} import {class_name}
- tool = {class_name}()
+ tool = {class_name}()
- launch_gradio_demo(tool)
- """
- ).lstrip()
- )
+ launch_gradio_demo(tool)
+ """
+ ).lstrip()
+ )
- # Save requirements file
- imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
- requirements_file = os.path.join(output_dir, "requirements.txt")
- with open(requirements_file, "w", encoding="utf-8") as f:
- f.write("\n".join(imports) + "\n")
+ # Save requirements file
+ requirements_file = os.path.join(output_dir, "requirements.txt")
+ with open(requirements_file, "w", encoding="utf-8") as f:
+ f.write("\n".join(tool_dict["requirements"]) + "\n")
def push_to_hub(
self,
@@ -311,14 +322,6 @@ class Tool:
"""
Upload the tool to the Hub.
- For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
- For instance:
- ```
- from my_tool_module import MyTool
- my_tool = MyTool()
- my_tool.push_to_hub("my-username/my-space")
- ```
-
Parameters:
repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when
@@ -342,13 +345,11 @@ class Tool:
space_sdk="gradio",
)
repo_id = repo_url.repo_id
- metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space", token=token)
+ metadata_update(repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token)
with tempfile.TemporaryDirectory() as work_dir:
# Save all files.
self.save(work_dir)
- with open(work_dir + "/tool.py", "r") as f:
- print("\n".join(f.readlines()))
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder(
repo_id=repo_id,
@@ -394,7 +395,7 @@ class Tool:
"""
if not trust_remote_code:
raise ValueError(
- "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
+ "Loading a tool from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
)
# Get the tool's tool.py file.
@@ -413,30 +414,26 @@ class Tool:
)
tool_code = Path(tool_file).read_text()
+ return Tool.from_code(tool_code, **kwargs)
- # Find the Tool subclass in the namespace
- with tempfile.TemporaryDirectory() as temp_dir:
- # Save the code to a file
- module_path = os.path.join(temp_dir, "tool.py")
- with open(module_path, "w") as f:
- f.write(tool_code)
+ @classmethod
+ def from_code(cls, tool_code: str, **kwargs):
+ module = types.ModuleType("dynamic_tool")
- print("TOOL CODE:\n", tool_code)
+ exec(tool_code, module.__dict__)
- # Load module from file path
- spec = importlib.util.spec_from_file_location("tool", module_path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ # Find the Tool subclass
+ tool_class = next(
+ (
+ obj
+ for _, obj in inspect.getmembers(module, inspect.isclass)
+ if issubclass(obj, Tool) and obj is not Tool
+ ),
+ None,
+ )
- # Find and instantiate the Tool class
- for item_name in dir(module):
- item = getattr(module, item_name)
- if isinstance(item, type) and issubclass(item, Tool) and item != Tool:
- tool_class = item
- break
-
- if tool_class is None:
- raise ValueError("No Tool subclass found in the code.")
+ if tool_class is None:
+ raise ValueError("No Tool subclass found in the code.")
if not isinstance(tool_class.inputs, dict):
tool_class.inputs = ast.literal_eval(tool_class.inputs)
diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py
index 6176269..b9868dc 100644
--- a/src/smolagents/utils.py
+++ b/src/smolagents/utils.py
@@ -20,6 +20,7 @@ import importlib.metadata
import importlib.util
import inspect
import json
+import os
import re
import textwrap
import types
@@ -414,3 +415,10 @@ def encode_image_base64(image):
def make_image_url(base64_image):
return f"data:image/png;base64,{base64_image}"
+
+
+def make_init_file(folder: str):
+ os.makedirs(folder, exist_ok=True)
+ # Create an empty __init__.py in the folder (an existing one is truncated by the "w" mode)
+ with open(os.path.join(folder, "__init__.py"), "w"):
+ pass
diff --git a/tests/test_agents.py b/tests/test_agents.py
index d4ce8d7..adcdb6a 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -31,12 +31,13 @@ from smolagents.agents import (
ToolCallingAgent,
populate_template,
)
-from smolagents.default_tools import PythonInterpreterTool
+from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
from smolagents.memory import PlanningStep
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
+ HfApiModel,
MessageRole,
TransformersModel,
)
@@ -436,10 +437,15 @@ class AgentTests(unittest.TestCase):
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
- agent = CodeAgent(tools=toolset_2, model=fake_code_model)
- assert (
- len(agent.tools) == 2
- ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
+ with pytest.raises(ValueError) as e:
+ agent = CodeAgent(tools=toolset_2, model=fake_code_model)
+ assert "Each tool or managed_agent should have a unique name!" in str(e)
+
+ with pytest.raises(ValueError) as e:
+ agent.name = "python_interpreter"
+ agent.description = "empty"
+ CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model, managed_agents=[agent])
+ assert "Each tool or managed_agent should have a unique name!" in str(e)
# check that python_interpreter base tool does not get added to CodeAgent
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
@@ -484,132 +490,6 @@ class AgentTests(unittest.TestCase):
str_output = capture.get()
assert "`additional_authorized_imports`" in str_output.replace("\n", "")
- def test_multiagents(self):
- class FakeModelMultiagentsManagerAgent:
- model_id = "fake_model"
-
- def __call__(
- self,
- messages,
- stop_sequences=None,
- grammar=None,
- tools_to_call_from=None,
- ):
- if tools_to_call_from is not None:
- if len(messages) < 3:
- return ChatMessage(
- role="assistant",
- content="",
- tool_calls=[
- ChatMessageToolCall(
- id="call_0",
- type="function",
- function=ChatMessageToolCallDefinition(
- name="search_agent",
- arguments="Who is the current US president?",
- ),
- )
- ],
- )
- else:
- assert "Report on the current US president" in str(messages)
- return ChatMessage(
- role="assistant",
- content="",
- tool_calls=[
- ChatMessageToolCall(
- id="call_0",
- type="function",
- function=ChatMessageToolCallDefinition(
- name="final_answer", arguments="Final report."
- ),
- )
- ],
- )
- else:
- if len(messages) < 3:
- return ChatMessage(
- role="assistant",
- content="""
-Thought: Let's call our search agent.
-Code:
-```py
-result = search_agent("Who is the current US president?")
-```
-""",
- )
- else:
- assert "Report on the current US president" in str(messages)
- return ChatMessage(
- role="assistant",
- content="""
-Thought: Let's return the report.
-Code:
-```py
-final_answer("Final report.")
-```
-""",
- )
-
- manager_model = FakeModelMultiagentsManagerAgent()
-
- class FakeModelMultiagentsManagedAgent:
- model_id = "fake_model"
-
- def __call__(
- self,
- messages,
- tools_to_call_from=None,
- stop_sequences=None,
- grammar=None,
- ):
- return ChatMessage(
- role="assistant",
- content="",
- tool_calls=[
- ChatMessageToolCall(
- id="call_0",
- type="function",
- function=ChatMessageToolCallDefinition(
- name="final_answer",
- arguments="Report on the current US president",
- ),
- )
- ],
- )
-
- managed_model = FakeModelMultiagentsManagedAgent()
-
- web_agent = ToolCallingAgent(
- tools=[],
- model=managed_model,
- max_steps=10,
- name="search_agent",
- description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
- )
-
- manager_code_agent = CodeAgent(
- tools=[],
- model=manager_model,
- managed_agents=[web_agent],
- additional_authorized_imports=["time", "numpy", "pandas"],
- )
-
- report = manager_code_agent.run("Fake question.")
- assert report == "Final report."
-
- manager_toolcalling_agent = ToolCallingAgent(
- tools=[],
- model=manager_model,
- managed_agents=[web_agent],
- )
-
- report = manager_toolcalling_agent.run("Fake question.")
- assert report == "Final report."
-
- # Test that visualization works
- manager_code_agent.visualize()
-
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return ChatMessage(
@@ -887,6 +767,191 @@ class TestCodeAgent:
assert result == expected_summary
+class MultiAgentsTests(unittest.TestCase):
+ def test_multiagents_save(self):
+ model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
+
+ web_agent = ToolCallingAgent(
+ model=model,
+ tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()],
+ name="web_agent",
+ description="does web searches",
+ )
+ code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular")
+
+ agent = CodeAgent(
+ model=model,
+ tools=[],
+ additional_authorized_imports=["pandas", "datetime"],
+ managed_agents=[web_agent, code_agent],
+ )
+ agent.save("agent_export")
+
+ expected_structure = {
+ "managed_agents": {
+ "useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]},
+ "web_agent": {
+ "tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]},
+ "files": ["agent.json", "prompts.yaml"],
+ },
+ },
+ "tools": {"files": ["final_answer.py"]},
+ "files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"],
+ }
+
+ def verify_structure(current_path: Path, structure: dict):
+ for dir_name, contents in structure.items():
+ if dir_name != "files":
+ # For directories, verify they exist and recurse into them
+ dir_path = current_path / dir_name
+ assert dir_path.exists(), f"Directory {dir_path} does not exist"
+ assert dir_path.is_dir(), f"{dir_path} is not a directory"
+ verify_structure(dir_path, contents)
+ else:
+ # For files, verify each exists in the current path
+ for file_name in contents:
+ file_path = current_path / file_name
+ assert file_path.exists(), f"File {file_path} does not exist"
+ assert file_path.is_file(), f"{file_path} is not a file"
+
+ verify_structure(Path("agent_export"), expected_structure)
+
+ # Test that re-loaded agents work as expected.
+ agent2 = CodeAgent.from_folder("agent_export", planning_interval=5)
+ assert agent2.planning_interval == 5 # Check that kwargs are used
+ assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
+ assert (
+ agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10
+ ) # For now tool init parameters are forgotten: the reloaded value is 10, not the max_results=2 passed above
+ assert agent2.model.kwargs["temperature"] == pytest.approx(0.5)
+
+ def test_multiagents(self):
+ class FakeModelMultiagentsManagerAgent:
+ model_id = "fake_model"
+
+ def __call__(
+ self,
+ messages,
+ stop_sequences=None,
+ grammar=None,
+ tools_to_call_from=None,
+ ):
+ if tools_to_call_from is not None:
+ if len(messages) < 3:
+ return ChatMessage(
+ role="assistant",
+ content="",
+ tool_calls=[
+ ChatMessageToolCall(
+ id="call_0",
+ type="function",
+ function=ChatMessageToolCallDefinition(
+ name="search_agent",
+ arguments="Who is the current US president?",
+ ),
+ )
+ ],
+ )
+ else:
+ assert "Report on the current US president" in str(messages)
+ return ChatMessage(
+ role="assistant",
+ content="",
+ tool_calls=[
+ ChatMessageToolCall(
+ id="call_0",
+ type="function",
+ function=ChatMessageToolCallDefinition(
+ name="final_answer", arguments="Final report."
+ ),
+ )
+ ],
+ )
+ else:
+ if len(messages) < 3:
+ return ChatMessage(
+ role="assistant",
+ content="""
+Thought: Let's call our search agent.
+Code:
+```py
+result = search_agent("Who is the current US president?")
+```
+""",
+ )
+ else:
+ assert "Report on the current US president" in str(messages)
+ return ChatMessage(
+ role="assistant",
+ content="""
+Thought: Let's return the report.
+Code:
+```py
+final_answer("Final report.")
+```
+""",
+ )
+
+ manager_model = FakeModelMultiagentsManagerAgent()
+
+ class FakeModelMultiagentsManagedAgent:
+ model_id = "fake_model"
+
+ def __call__(
+ self,
+ messages,
+ tools_to_call_from=None,
+ stop_sequences=None,
+ grammar=None,
+ ):
+ return ChatMessage(
+ role="assistant",
+ content="",
+ tool_calls=[
+ ChatMessageToolCall(
+ id="call_0",
+ type="function",
+ function=ChatMessageToolCallDefinition(
+ name="final_answer",
+ arguments="Report on the current US president",
+ ),
+ )
+ ],
+ )
+
+ managed_model = FakeModelMultiagentsManagedAgent()
+
+ web_agent = ToolCallingAgent(
+ tools=[],
+ model=managed_model,
+ max_steps=10,
+ name="search_agent",
+ description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
+ )
+
+ manager_code_agent = CodeAgent(
+ tools=[],
+ model=manager_model,
+ managed_agents=[web_agent],
+ additional_authorized_imports=["time", "numpy", "pandas"],
+ )
+
+ report = manager_code_agent.run("Fake question.")
+ assert report == "Final report."
+
+ manager_toolcalling_agent = ToolCallingAgent(
+ tools=[],
+ model=manager_model,
+ managed_agents=[web_agent],
+ )
+
+ report = manager_toolcalling_agent.run("Fake question.")
+ assert report == "Final report."
+
+ # Test that visualization works
+ manager_code_agent.visualize()
+
+
@pytest.fixture
def prompt_templates():
return {
diff --git a/tests/test_function_type_hints_utils.py b/tests/test_function_type_hints_utils.py
index 9e58985..3379237 100644
--- a/tests/test_function_type_hints_utils.py
+++ b/tests/test_function_type_hints_utils.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
-from smolagents._function_type_hints_utils import get_json_schema
+import pytest
+
+from smolagents._function_type_hints_utils import get_imports, get_json_schema
-class AgentTextTests(unittest.TestCase):
- def test_return_none(self):
+class TestJsonSchema(unittest.TestCase):
+ def test_get_json_schema(self):
def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
"""
Test function
@@ -52,3 +54,65 @@ class AgentTextTests(unittest.TestCase):
schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"]
)
self.assertEqual(schema["function"], expected_schema)
+
+
+class TestGetCode:
+ @pytest.mark.parametrize(
+ "code, expected",
+ [
+ (
+ """
+ import numpy
+ import pandas
+ """,
+ ["numpy", "pandas"],
+ ),
+ # From imports
+ (
+ """
+ from torch import nn
+ from transformers import AutoModel
+ """,
+ ["torch", "transformers"],
+ ),
+ # Mixed case with nested imports
+ (
+ """
+ import numpy as np
+ from torch.nn import Linear
+ import os.path
+ """,
+ ["numpy", "torch", "os"],
+ ),
+ # Try/except block (should be filtered)
+ (
+ """
+ try:
+ import torch
+ except ImportError:
+ pass
+ import numpy
+ """,
+ ["numpy"],
+ ),
+ # Flash attention block (should be filtered)
+ (
+ """
+ if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func
+ import transformers
+ """,
+ ["transformers"],
+ ),
+ # Relative imports (should be excluded)
+ (
+ """
+ from .utils import helper
+ from ..models import transformer
+ """,
+ [],
+ ),
+ ],
+ )
+ def test_get_imports(self, code: str, expected: List[str]):
+ assert sorted(get_imports(code)) == sorted(expected)
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 4df4b4d..fcc05d5 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -215,8 +215,9 @@ class ToolTests(unittest.TestCase):
return str(datetime.now())
- def test_saving_tool_allows_no_arg_in_init(self):
- # Test one cannot save tool with additional args in init
+ def test_tool_to_dict_allows_no_arg_in_init(self):
+ """Test that `to_dict` rejects a tool whose `__init__` has required args and accepts one whose args all have defaults"""
+
class FailTool(Tool):
name = "specific"
description = "test description"
@@ -225,15 +226,31 @@ class ToolTests(unittest.TestCase):
def __init__(self, url):
super().__init__(self)
- self.url = "none"
+ self.url = url
def forward(self, string_input: str) -> str:
return self.url + string_input
fail_tool = FailTool("dummy_url")
with pytest.raises(Exception) as e:
- fail_tool.save("output")
- assert "__init__" in str(e)
+ fail_tool.to_dict()
+ assert "All parameters of __init__ must have default values!" in str(e)
+
+ class PassTool(Tool):
+ name = "specific"
+ description = "test description"
+ inputs = {"string_input": {"type": "string", "description": "input description"}}
+ output_type = "string"
+
+ def __init__(self, url: Optional[str] = "none"):
+ super().__init__(self)
+ self.url = url
+
+ def forward(self, string_input: str) -> str:
+ return self.url + string_input
+
+ fail_tool = PassTool()
+ fail_tool.to_dict()
def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 31a8a68..25f1632 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -146,11 +146,12 @@ def test_e2e_class_tool_save():
test_tool = TestTool()
with tempfile.TemporaryDirectory() as tmp_dir:
- test_tool.save(tmp_dir)
+ test_tool.save(tmp_dir, make_gradio_app=True)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
- == """from smolagents.tools import Tool
+ == """from typing import Any, Optional
+from smolagents.tools import Tool
import IPython
class TestTool(Tool):
@@ -173,7 +174,6 @@ class TestTool(Tool):
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
-from typing import Optional
from tool import TestTool
tool = TestTool()
@@ -201,13 +201,14 @@ def test_e2e_ipython_class_tool_save():
import IPython # noqa: F401
return task
- TestTool().save("{tmp_dir}")
+ TestTool().save("{tmp_dir}", make_gradio_app=True)
""")
assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
- == """from smolagents.tools import Tool
+ == """from typing import Any, Optional
+from smolagents.tools import Tool
import IPython
class TestTool(Tool):
@@ -230,7 +231,6 @@ class TestTool(Tool):
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
-from typing import Optional
from tool import TestTool
tool = TestTool()
@@ -254,12 +254,12 @@ def test_e2e_function_tool_save():
return task
with tempfile.TemporaryDirectory() as tmp_dir:
- test_tool.save(tmp_dir)
+ test_tool.save(tmp_dir, make_gradio_app=True)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool
-from typing import Optional
+from typing import Any, Optional
class SimpleTool(Tool):
name = "test_tool"
@@ -283,7 +283,6 @@ class SimpleTool(Tool):
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
-from typing import Optional
from tool import SimpleTool
tool = SimpleTool()
@@ -311,14 +310,14 @@ def test_e2e_ipython_function_tool_save():
return task
- test_tool.save("{tmp_dir}")
+ test_tool.save("{tmp_dir}", make_gradio_app=True)
""")
assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert (
pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool
-from typing import Optional
+from typing import Any, Optional
class SimpleTool(Tool):
name = "test_tool"
@@ -342,7 +341,6 @@ class SimpleTool(Tool):
assert (
pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo
-from typing import Optional
from tool import SimpleTool
tool = SimpleTool()