diff --git a/examples/open_deep_research/scripts/mdconvert.py b/examples/open_deep_research/scripts/mdconvert.py index 72cb0a0..68f13a2 100644 --- a/examples/open_deep_research/scripts/mdconvert.py +++ b/examples/open_deep_research/scripts/mdconvert.py @@ -567,13 +567,13 @@ class WavConverter(MediaConverter): class Mp3Converter(WavConverter): """ - Converts MP3 files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed). + Converts MP3 and M4A files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed). """ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]: # Bail if not a MP3 extension = kwargs.get("file_extension", "") - if extension.lower() != ".mp3": + if extension.lower() not in [".mp3", ".m4a"]: return None md_content = "" @@ -600,7 +600,10 @@ class Mp3Converter(WavConverter): handle, temp_path = tempfile.mkstemp(suffix=".wav") os.close(handle) try: - sound = pydub.AudioSegment.from_mp3(local_path) + if extension.lower() == ".mp3": + sound = pydub.AudioSegment.from_mp3(local_path) + else: + sound = pydub.AudioSegment.from_file(local_path, format="m4a") sound.export(temp_path, format="wav") _args = dict() diff --git a/examples/open_deep_research/scripts/text_inspector_tool.py b/examples/open_deep_research/scripts/text_inspector_tool.py index 09e7c11..056168c 100644 --- a/examples/open_deep_research/scripts/text_inspector_tool.py +++ b/examples/open_deep_research/scripts/text_inspector_tool.py @@ -10,7 +10,7 @@ class TextInspectorTool(Tool): name = "inspect_file_as_text" description = """ You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it. -This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.""" +This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.""" inputs = { "file_path": { diff --git a/examples/open_deep_research/scripts/text_web_browser.py b/examples/open_deep_research/scripts/text_web_browser.py index 18763c4..ef40f85 100644 --- a/examples/open_deep_research/scripts/text_web_browser.py +++ b/examples/open_deep_research/scripts/text_web_browser.py @@ -410,7 +410,7 @@ class VisitTool(Tool): class DownloadTool(Tool): name = "download_file" description = """ -Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"] +Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".png", ".docx"] After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it. DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead.""" inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}} diff --git a/src/smolagents/_function_type_hints_utils.py b/src/smolagents/_function_type_hints_utils.py index 5eb9502..13c6a25 100644 --- a/src/smolagents/_function_type_hints_utils.py +++ b/src/smolagents/_function_type_hints_utils.py @@ -24,7 +24,6 @@ TODO: move them to `huggingface_hub` to avoid code duplication. import inspect import json -import os import re import types from copy import copy @@ -46,34 +45,31 @@ from huggingface_hub.utils import is_torch_available from .utils import _is_pillow_available -def get_imports(filename: Union[str, os.PathLike]) -> List[str]: +def get_imports(code: str) -> List[str]: """ - Extracts all the libraries (not relative imports this time) that are imported in a file. + Extracts all the libraries (not relative imports) that are imported in a code. Args: - filename (`str` or `os.PathLike`): The module file to inspect. + code (`str`): Code text to inspect. Returns: - `List[str]`: The list of all packages required to use the input module. + `list[str]`: List of all packages required to use the input code. """ - with open(filename, "r", encoding="utf-8") as f: - content = f.read() - # filter out try/except block so in custom code we can have try/except imports - content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL) + code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL) # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment - content = re.sub( + code = re.sub( r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", - content, + code, flags=re.MULTILINE, ) - # Imports of the form `import xxx` - imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `import xxx` or `import xxx as yyy` + imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE) # Imports of the form `from xxx import yyy` - imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE) # Only keep the top-level module imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] return list(set(imports)) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 753ae18..1ce54d1 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -14,44 +14,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"] - -import importlib.resources +import importlib import inspect +import json +import os import re +import tempfile import textwrap import time from collections import deque from logging import getLogger +from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union +import jinja2 import yaml +from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder from jinja2 import StrictUndefined, Template from rich.console import Group from rich.panel import Panel from rich.rule import Rule from rich.text import Text -from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types -from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall -from smolagents.monitoring import ( - YELLOW_HEX, - AgentLogger, - LogLevel, -) -from smolagents.utils import ( - AgentError, - AgentExecutionError, - AgentGenerationError, - AgentMaxStepsError, - AgentParsingError, - parse_code_blobs, - parse_json_tool_call, - truncate_content, -) - -from .agent_types import AgentType +from .agent_types import AgentAudio, AgentImage, AgentType, handle_agent_output_types from .default_tools import TOOL_MAPPING, FinalAnswerTool from .e2b_executor import E2BExecutor from .local_python_executor import ( @@ -59,12 +44,30 @@ from .local_python_executor import ( LocalPythonInterpreter, fix_final_answer_code, ) +from .memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall from .models import ( ChatMessage, MessageRole, + Model, +) +from .monitoring import ( + YELLOW_HEX, + AgentLogger, + LogLevel, + Monitor, ) -from .monitoring import Monitor from .tools import Tool +from .utils import ( + AgentError, + AgentExecutionError, + AgentGenerationError, + AgentMaxStepsError, + AgentParsingError, + make_init_file, + parse_code_blobs, + parse_json_tool_call, + truncate_content, +) logger = getLogger(__name__) @@ -228,9 +231,19 @@ class MultiStepAgent: ) self.managed_agents = {agent.name: agent for agent in managed_agents} + tool_and_managed_agent_names = [tool.name for tool in tools] + if managed_agents is not None: + tool_and_managed_agent_names += [agent.name for agent in managed_agents] + if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)): + raise ValueError( + "Each tool or managed_agent should have a unique name! You passed these duplicate names: " + f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}" + ) + for tool in tools: assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}" self.tools = {tool.name: tool for tool in tools} + if add_base_tools: for tool_name, tool_class in TOOL_MAPPING.items(): if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent": @@ -709,6 +722,310 @@ You have been provided with these additional arguments, that you can access usin answer += "\n" return answer + def save(self, output_dir: str, relative_path: Optional[str] = None): + """ + Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate: + + - a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`. + - a `managed_agents` folder containing the logic for each of the managed agents. + - an `agent.json` file containing a dictionary representing your agent. + - a `prompt.yaml` file containing the prompt templates used by your agent. + - an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()` + - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its + code) + + Args: + output_dir (`str`): The folder in which you want to save your tool. + """ + make_init_file(output_dir) + + # Recursively save managed agents + if self.managed_agents: + make_init_file(os.path.join(output_dir, "managed_agents")) + for agent_name, agent in self.managed_agents.items(): + agent_suffix = f"managed_agents.{agent_name}" + if relative_path: + agent_suffix = relative_path + "." + agent_suffix + agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix) + + class_name = self.__class__.__name__ + + # Save tools to different .py files + for tool in self.tools.values(): + make_init_file(os.path.join(output_dir, "tools")) + tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False) + + # Save prompts to yaml + yaml_prompts = yaml.safe_dump( + self.prompt_templates, + default_style="|", # This forces block literals for all strings + default_flow_style=False, + width=float("inf"), + sort_keys=False, + allow_unicode=True, + indent=2, + ) + + with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f: + f.write(yaml_prompts) + + # Save agent dictionary to json + agent_dict = self.to_dict() + agent_dict["tools"] = [tool.name for tool in self.tools.values()] + with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f: + json.dump(agent_dict, f, indent=4) + + # Save requirements + with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f: + f.writelines(f"{r}\n" for r in agent_dict["requirements"]) + + # Make agent.py file with Gradio UI + agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent" + managed_agent_relative_path = relative_path + "." if relative_path is not None else "" + app_template = textwrap.dedent(""" + import yaml + import os + from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }} + + # Get current directory path + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + + {% for tool in tools.values() -%} + from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }} + {% endfor %} + {% for managed_agent in managed_agents.values() -%} + from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }} + {% endfor %} + + model = {{ agent_dict['model']['class'] }}( + {% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%} + {{ key }}={{ agent_dict['model']['data'][key]|repr }}, + {% endfor %}) + + {% for tool in tools.values() -%} + {{ tool.name }} = {{ tool.name | camelcase }}() + {% endfor %} + + with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream: + prompt_templates = yaml.safe_load(stream) + + {{ agent_name }} = {{ class_name }}( + model=model, + tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}], + managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}], + {% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%} + {{ attribute_name }}={{ value|repr }}, + {% endfor %}prompt_templates=prompt_templates + ) + if __name__ == "__main__": + GradioUI({{ agent_name }}).launch() + """).strip() + template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined) + template_env.filters["repr"] = repr + template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_")) + template = template_env.from_string(app_template) + + # Render the app.py file from Jinja2 template + app_text = template.render( + { + "agent_name": agent_name, + "class_name": class_name, + "agent_dict": agent_dict, + "tools": self.tools, + "managed_agents": self.managed_agents, + "managed_agent_relative_path": managed_agent_relative_path, + } + ) + + with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f: + f.write(app_text + "\n") # Append newline at the end + + def to_dict(self) -> Dict[str, Any]: + """Converts agent into a dictionary.""" + # TODO: handle serializing step_callbacks and final_answer_checks + for attr in ["final_answer_checks", "step_callbacks"]: + if getattr(self, attr, None): + self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO) + + tool_dicts = [tool.to_dict() for tool in self.tools.values()] + tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]} + managed_agents_requirements = { + req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"] + } + requirements = tool_requirements | managed_agents_requirements + if hasattr(self, "authorized_imports"): + requirements.update( + {package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES} + ) + + agent_dict = { + "tools": tool_dicts, + "model": { + "class": self.model.__class__.__name__, + "data": self.model.to_dict(), + }, + "managed_agents": { + managed_agent.name: managed_agent.__class__.__name__ for managed_agent in self.managed_agents.values() + }, + "prompt_templates": self.prompt_templates, + "max_steps": self.max_steps, + "verbosity_level": int(self.logger.level), + "grammar": self.grammar, + "planning_interval": self.planning_interval, + "name": self.name, + "description": self.description, + "requirements": list(requirements), + } + if hasattr(self, "authorized_imports"): + agent_dict["authorized_imports"] = self.authorized_imports + return agent_dict + + @classmethod + def from_hub( + cls, + repo_id: str, + token: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + Loads an agent defined on the Hub. + + + + Loading a tool from the Hub means that you'll download the tool and execute it locally. + ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when + installing a package using pip/npm/apt. + + + + Args: + repo_id (`str`): + The name of the repo on the Hub where your tool is defined. + token (`str`, *optional*): + The token to identify you on hf.co. If unset, will use the token generated when running + `huggingface-cli login` (stored in `~/.huggingface`). + trust_remote_code(`bool`, *optional*, defaults to False): + This flags marks that you understand the risk of running remote code and that you trust this tool. + If not setting this to True, loading the tool from Hub will fail. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as + `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the + others will be passed along to its init. + """ + if not trust_remote_code: + raise ValueError( + "Loading an agent from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`." + ) + + # Get the agent's Hub folder. + download_kwargs = {"token": token, "repo_type": "space"} | { + key: kwargs.pop(key) + for key in [ + "cache_dir", + "force_download", + "resume_download", + "proxies", + "revision", + "subfolder", + "local_files_only", + ] + if key in kwargs + } + + download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs)) + return cls.from_folder(download_folder, **kwargs) + + @classmethod + def from_folder(cls, folder: Union[str, Path], **kwargs): + """Loads an agent from a local folder""" + folder = Path(folder) + agent_dict = json.loads((folder / "agent.json").read_text()) + + # Recursively get managed agents + managed_agents = [] + for managed_agent_name, managed_agent_class in agent_dict["managed_agents"].items(): + agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class) + managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name)) + + tools = [] + for tool_name in agent_dict["tools"]: + tool_code = (folder / "tools" / f"{tool_name}.py").read_text() + tools.append(Tool.from_code(tool_code)) + + model_class: Model = getattr(importlib.import_module("smolagents.models"), agent_dict["model"]["class"]) + model = model_class.from_dict(agent_dict["model"]["data"]) + + args = dict( + model=model, + tools=tools, + managed_agents=managed_agents, + name=agent_dict["name"], + description=agent_dict["description"], + max_steps=agent_dict["max_steps"], + planning_interval=agent_dict["planning_interval"], + grammar=agent_dict["grammar"], + verbosity_level=agent_dict["verbosity_level"], + ) + if cls.__name__ == "CodeAgent": + args["additional_authorized_imports"] = agent_dict["authorized_imports"] + args.update(kwargs) + return cls(**args) + + def push_to_hub( + self, + repo_id: str, + commit_message: str = "Upload agent", + private: Optional[bool] = None, + token: Optional[Union[bool, str]] = None, + create_pr: bool = False, + ) -> str: + """ + Upload the agent to the Hub. + + Parameters: + repo_id (`str`): + The name of the repository you want to push to. It should contain your organization name when + pushing to a given organization. + commit_message (`str`, *optional*, defaults to `"Upload agent"`): + Message to commit while pushing. + private (`bool`, *optional*, defaults to `None`): + Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + create_pr (`bool`, *optional*, defaults to `False`): + Whether to create a PR with the uploaded files or directly commit. + """ + repo_url = create_repo( + repo_id=repo_id, + token=token, + private=private, + exist_ok=True, + repo_type="space", + space_sdk="gradio", + ) + repo_id = repo_url.repo_id + metadata_update( + repo_id, + {"tags": ["smolagents", "agent"]}, + repo_type="space", + token=token, + overwrite=True, + ) + + with tempfile.TemporaryDirectory() as work_dir: + self.save(work_dir) + logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") + return upload_folder( + repo_id=repo_id, + commit_message=commit_message, + folder_path=work_dir, + token=token, + create_pr=create_pr, + repo_type="space", + ) + class ToolCallingAgent(MultiStepAgent): """ @@ -863,6 +1180,8 @@ class CodeAgent(MultiStepAgent): ): self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) + self.use_e2b_executor = use_e2b_executor + self.max_print_outputs_length = max_print_outputs_length prompt_templates = prompt_templates or yaml.safe_load( importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text() ) diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index f744251..4425f1e 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -141,7 +141,7 @@ def stream_to_gradio( for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): # Track tokens if model provides them - if hasattr(agent.model, "last_input_token_count"): + if getattr(agent.model, "last_input_token_count", None): total_input_tokens += agent.model.last_input_token_count total_output_tokens += agent.model.last_output_token_count if isinstance(step_log, ActionStep): diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 75dda78..b576846 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -331,6 +331,53 @@ class Model: """ pass # To be implemented in child classes! + def to_dict(self) -> Dict: + """ + Converts the model into a JSON-compatible dictionary. + """ + model_dictionary = { + **self.kwargs, + "last_input_token_count": self.last_input_token_count, + "last_output_token_count": self.last_output_token_count, + "model_id": self.model_id, + } + for attribute in [ + "custom_role_conversion", + "temperature", + "max_tokens", + "provider", + "timeout", + "api_base", + "torch_dtype", + "device_map", + "organization", + "project", + "azure_endpoint", + ]: + if hasattr(self, attribute): + model_dictionary[attribute] = getattr(self, attribute) + + dangerous_attributes = ["token", "api_key"] + for attribute_name in dangerous_attributes: + if hasattr(self, attribute_name): + print( + f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually." + ) + return model_dictionary + + @classmethod + def from_dict(cls, model_dictionary: Dict[str, Any]) -> "Model": + model_instance = cls( + **{ + k: v + for k, v in model_dictionary.items() + if k not in ["last_input_token_count", "last_output_token_count"] + } + ) + model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None) + model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None) + return model_instance + class HfApiModel(Model): """A class to interact with Hugging Face's Inference API for language model interaction. diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index 1f4d457..b8665c2 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -83,6 +83,31 @@ class MethodChecker(ast.NodeVisitor): self.assigned_names.add(elt.id) self.generic_visit(node) + def _handle_comprehension_generators(self, generators): + """Helper method to handle generators in all types of comprehensions""" + for generator in generators: + if isinstance(generator.target, ast.Name): + self.assigned_names.add(generator.target.id) + elif isinstance(generator.target, ast.Tuple): + for elt in generator.target.elts: + if isinstance(elt, ast.Name): + self.assigned_names.add(elt.id) + + def visit_ListComp(self, node): + """Track variables in list comprehensions""" + self._handle_comprehension_generators(node.generators) + self.generic_visit(node) + + def visit_DictComp(self, node): + """Track variables in dictionary comprehensions""" + self._handle_comprehension_generators(node.generators) + self.generic_visit(node) + + def visit_SetComp(self, node): + """Track variables in set comprehensions""" + self._handle_comprehension_generators(node.generators) + self.generic_visit(node) + def visit_Attribute(self, node): if not (isinstance(node.value, ast.Name) and node.value.id == "self"): self.generic_visit(node) @@ -121,7 +146,8 @@ class MethodChecker(ast.NodeVisitor): def validate_tool_attributes(cls, check_imports: bool = True) -> None: """ Validates that a Tool class follows the proper patterns: - 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!). + 0. Any argument of __init__ should have a default. + Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute. 1. About the class: - Class attributes should only be strings or dicts - Class attributes cannot be complex attributes @@ -140,13 +166,20 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: if not isinstance(tree.body[0], ast.ClassDef): raise ValueError("Source code must define a class") - # Check that __init__ method takes no arguments + # Check that __init__ method only has arguments with defaults if not cls.__init__.__qualname__ == "Tool.__init__": sig = inspect.signature(cls.__init__) - non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]) - if len(non_self_params) > 0: + non_default_params = [ + arg_name + for arg_name, param in sig.parameters.items() + if arg_name != "self" + and param.default == inspect.Parameter.empty + and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs + ] + if non_default_params: errors.append( - f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!" + f"This tool has required arguments in __init__: {non_default_params}. " + "All parameters of __init__ must have default values!" ) class_node = tree.body[0] @@ -198,5 +231,5 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: errors += [f"- {node.name}: {error}" for error in method_checker.errors] if errors: - raise ValueError("Tool validation failed:\n" + "\n".join(errors)) + raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors)) return diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index d73fcce..15275e7 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast -import importlib import inspect import json import logging @@ -23,6 +22,7 @@ import os import sys import tempfile import textwrap +import types from contextlib import contextmanager from functools import wraps from pathlib import Path @@ -199,24 +199,9 @@ class Tool: """ self.is_initialized = True - def save(self, output_dir): - """ - Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your - tool in `output_dir` as well as autogenerate: - - - a `tool.py` file containing the logic for your tool. - - an `app.py` file providing an UI for your tool when it is exported to a Space with `tool.push_to_hub()` - - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its - code) - - Args: - output_dir (`str`): The folder in which you want to save your tool. - """ - os.makedirs(output_dir, exist_ok=True) + def to_dict(self) -> dict: + """Returns a dictionary representing the tool""" class_name = self.__class__.__name__ - tool_file = os.path.join(output_dir, "tool.py") - - # Save tool file if type(self).__name__ == "SimpleTool": # Check that imports are self-contained source_code = get_source(self.forward).replace("@tool", "") @@ -232,7 +217,7 @@ class Tool: tool_code = textwrap.dedent( f""" from smolagents import Tool - from typing import Optional + from typing import Any, Optional class {class_name}(Tool): name = "{self.name}" @@ -272,33 +257,59 @@ class Tool: validate_tool_attributes(self.__class__) - tool_code = instance_to_source(self, base_cls=Tool) + tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool) + + requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"} + + return {"name": self.name, "code": tool_code, "requirements": requirements} + + def save(self, output_dir: str, tool_file_name: str = "tool", make_gradio_app: bool = True): + """ + Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your + tool in `output_dir` as well as autogenerate: + + - a `{tool_file_name}.py` file containing the logic for your tool. + If you pass `make_gradio_app=True`, this will also write: + - an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()` + - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its + code) + + Args: + output_dir (`str`): The folder in which you want to save your tool. + tool_file_name (`str`, *optional*): The file name in which you want to save your tool. + make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI. + """ + os.makedirs(output_dir, exist_ok=True) + class_name = self.__class__.__name__ + tool_file = os.path.join(output_dir, f"{tool_file_name}.py") + + tool_dict = self.to_dict() + tool_code = tool_dict["code"] with open(tool_file, "w", encoding="utf-8") as f: f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}")) - # Save app file - app_file = os.path.join(output_dir, "app.py") - with open(app_file, "w", encoding="utf-8") as f: - f.write( - textwrap.dedent( - f""" - from smolagents import launch_gradio_demo - from typing import Optional - from tool import {class_name} + if make_gradio_app: + # Save app file + app_file = os.path.join(output_dir, "app.py") + with open(app_file, "w", encoding="utf-8") as f: + f.write( + textwrap.dedent( + f""" + from smolagents import launch_gradio_demo + from {tool_file_name} import {class_name} - tool = {class_name}() + tool = {class_name}() - launch_gradio_demo(tool) - """ - ).lstrip() - ) + launch_gradio_demo(tool) + """ + ).lstrip() + ) - # Save requirements file - imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"} - requirements_file = os.path.join(output_dir, "requirements.txt") - with open(requirements_file, "w", encoding="utf-8") as f: - f.write("\n".join(imports) + "\n") + # Save requirements file + requirements_file = os.path.join(output_dir, "requirements.txt") + with open(requirements_file, "w", encoding="utf-8") as f: + f.write("\n".join(tool_dict["requirements"]) + "\n") def push_to_hub( self, @@ -311,14 +322,6 @@ class Tool: """ Upload the tool to the Hub. - For this method to work properly, your tool must have been defined in a separate module (not `__main__`). - For instance: - ``` - from my_tool_module import MyTool - my_tool = MyTool() - my_tool.push_to_hub("my-username/my-space") - ``` - Parameters: repo_id (`str`): The name of the repository you want to push your tool to. It should contain your organization name when @@ -342,13 +345,11 @@ class Tool: space_sdk="gradio", ) repo_id = repo_url.repo_id - metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space", token=token) + metadata_update(repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token) with tempfile.TemporaryDirectory() as work_dir: # Save all files. self.save(work_dir) - with open(work_dir + "/tool.py", "r") as f: - print("\n".join(f.readlines())) logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") return upload_folder( repo_id=repo_id, @@ -394,7 +395,7 @@ class Tool: """ if not trust_remote_code: raise ValueError( - "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." + "Loading a tool from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`." ) # Get the tool's tool.py file. @@ -413,30 +414,26 @@ class Tool: ) tool_code = Path(tool_file).read_text() + return Tool.from_code(tool_code, **kwargs) - # Find the Tool subclass in the namespace - with tempfile.TemporaryDirectory() as temp_dir: - # Save the code to a file - module_path = os.path.join(temp_dir, "tool.py") - with open(module_path, "w") as f: - f.write(tool_code) + @classmethod + def from_code(cls, tool_code: str, **kwargs): + module = types.ModuleType("dynamic_tool") - print("TOOL CODE:\n", tool_code) + exec(tool_code, module.__dict__) - # Load module from file path - spec = importlib.util.spec_from_file_location("tool", module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + # Find the Tool subclass + tool_class = next( + ( + obj + for _, obj in inspect.getmembers(module, inspect.isclass) + if issubclass(obj, Tool) and obj is not Tool + ), + None, + ) - # Find and instantiate the Tool class - for item_name in dir(module): - item = getattr(module, item_name) - if isinstance(item, type) and issubclass(item, Tool) and item != Tool: - tool_class = item - break - - if tool_class is None: - raise ValueError("No Tool subclass found in the code.") + if tool_class is None: + raise ValueError("No Tool subclass found in the code.") if not isinstance(tool_class.inputs, dict): tool_class.inputs = ast.literal_eval(tool_class.inputs) diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 6176269..b9868dc 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -20,6 +20,7 @@ import importlib.metadata import importlib.util import inspect import json +import os import re import textwrap import types @@ -414,3 +415,10 @@ def encode_image_base64(image): def make_image_url(base64_image): return f"data:image/png;base64,{base64_image}" + + +def make_init_file(folder: str): + os.makedirs(folder, exist_ok=True) + # Create __init__ + with open(os.path.join(folder, "__init__.py"), "w"): + pass diff --git a/tests/test_agents.py b/tests/test_agents.py index d4ce8d7..adcdb6a 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -31,12 +31,13 @@ from smolagents.agents import ( ToolCallingAgent, populate_template, ) -from smolagents.default_tools import PythonInterpreterTool +from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool from smolagents.memory import PlanningStep from smolagents.models import ( ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, + HfApiModel, MessageRole, TransformersModel, ) @@ -436,10 +437,15 @@ class AgentTests(unittest.TestCase): assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] - agent = CodeAgent(tools=toolset_2, model=fake_code_model) - assert ( - len(agent.tools) == 2 - ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer + with pytest.raises(ValueError) as e: + agent = CodeAgent(tools=toolset_2, model=fake_code_model) + assert "Each tool or managed_agent should have a unique name!" in str(e) + + with pytest.raises(ValueError) as e: + agent.name = "python_interpreter" + agent.description = "empty" + CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model, managed_agents=[agent]) + assert "Each tool or managed_agent should have a unique name!" in str(e) # check that python_interpreter base tool does not get added to CodeAgent agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True) @@ -484,132 +490,6 @@ class AgentTests(unittest.TestCase): str_output = capture.get() assert "`additional_authorized_imports`" in str_output.replace("\n", "") - def test_multiagents(self): - class FakeModelMultiagentsManagerAgent: - model_id = "fake_model" - - def __call__( - self, - messages, - stop_sequences=None, - grammar=None, - tools_to_call_from=None, - ): - if tools_to_call_from is not None: - if len(messages) < 3: - return ChatMessage( - role="assistant", - content="", - tool_calls=[ - ChatMessageToolCall( - id="call_0", - type="function", - function=ChatMessageToolCallDefinition( - name="search_agent", - arguments="Who is the current US president?", - ), - ) - ], - ) - else: - assert "Report on the current US president" in str(messages) - return ChatMessage( - role="assistant", - content="", - tool_calls=[ - ChatMessageToolCall( - id="call_0", - type="function", - function=ChatMessageToolCallDefinition( - name="final_answer", arguments="Final report." - ), - ) - ], - ) - else: - if len(messages) < 3: - return ChatMessage( - role="assistant", - content=""" -Thought: Let's call our search agent. -Code: -```py -result = search_agent("Who is the current US president?") -``` -""", - ) - else: - assert "Report on the current US president" in str(messages) - return ChatMessage( - role="assistant", - content=""" -Thought: Let's return the report. -Code: -```py -final_answer("Final report.") -``` -""", - ) - - manager_model = FakeModelMultiagentsManagerAgent() - - class FakeModelMultiagentsManagedAgent: - model_id = "fake_model" - - def __call__( - self, - messages, - tools_to_call_from=None, - stop_sequences=None, - grammar=None, - ): - return ChatMessage( - role="assistant", - content="", - tool_calls=[ - ChatMessageToolCall( - id="call_0", - type="function", - function=ChatMessageToolCallDefinition( - name="final_answer", - arguments="Report on the current US president", - ), - ) - ], - ) - - managed_model = FakeModelMultiagentsManagedAgent() - - web_agent = ToolCallingAgent( - tools=[], - model=managed_model, - max_steps=10, - name="search_agent", - description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports", - ) - - manager_code_agent = CodeAgent( - tools=[], - model=manager_model, - managed_agents=[web_agent], - additional_authorized_imports=["time", "numpy", "pandas"], - ) - - report = manager_code_agent.run("Fake question.") - assert report == "Final report." - - manager_toolcalling_agent = ToolCallingAgent( - tools=[], - model=manager_model, - managed_agents=[web_agent], - ) - - report = manager_toolcalling_agent.run("Fake question.") - assert report == "Final report." - - # Test that visualization works - manager_code_agent.visualize() - def test_code_nontrivial_final_answer_works(self): def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): return ChatMessage( @@ -887,6 +767,191 @@ class TestCodeAgent: assert result == expected_summary +class MultiAgentsTests(unittest.TestCase): + def test_multiagents_save(self): + model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5) + + web_agent = ToolCallingAgent( + model=model, + tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()], + name="web_agent", + description="does web searches", + ) + code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular") + + agent = CodeAgent( + model=model, + tools=[], + additional_authorized_imports=["pandas", "datetime"], + managed_agents=[web_agent, code_agent], + ) + agent.save("agent_export") + + expected_structure = { + "managed_agents": { + "useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]}, + "web_agent": { + "tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]}, + "files": ["agent.json", "prompts.yaml"], + }, + }, + "tools": {"files": ["final_answer.py"]}, + "files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"], + } + + def verify_structure(current_path: Path, structure: dict): + for dir_name, contents in structure.items(): + if dir_name != "files": + # For directories, verify they exist and recurse into them + dir_path = current_path / dir_name + assert dir_path.exists(), f"Directory {dir_path} does not exist" + assert dir_path.is_dir(), f"{dir_path} is not a directory" + verify_structure(dir_path, contents) + else: + # For files, verify each exists in the current path + for file_name in contents: + file_path = current_path / file_name + assert file_path.exists(), f"File {file_path} does not exist" + assert file_path.is_file(), f"{file_path} is not a file" + + verify_structure(Path("agent_export"), expected_structure) + + # Test that re-loaded agents work as expected. + agent2 = CodeAgent.from_folder("agent_export", planning_interval=5) + assert agent2.planning_interval == 5 # Check that kwargs are used + assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES) + assert ( + agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10 + ) # For now tool init parameters are forgotten + assert agent2.model.kwargs["temperature"] == pytest.approx(0.5) + + def test_multiagents(self): + class FakeModelMultiagentsManagerAgent: + model_id = "fake_model" + + def __call__( + self, + messages, + stop_sequences=None, + grammar=None, + tools_to_call_from=None, + ): + if tools_to_call_from is not None: + if len(messages) < 3: + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="search_agent", + arguments="Who is the current US president?", + ), + ) + ], + ) + else: + assert "Report on the current US president" in str(messages) + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="final_answer", arguments="Final report." + ), + ) + ], + ) + else: + if len(messages) < 3: + return ChatMessage( + role="assistant", + content=""" +Thought: Let's call our search agent. +Code: +```py +result = search_agent("Who is the current US president?") +``` +""", + ) + else: + assert "Report on the current US president" in str(messages) + return ChatMessage( + role="assistant", + content=""" +Thought: Let's return the report. +Code: +```py +final_answer("Final report.") +``` +""", + ) + + manager_model = FakeModelMultiagentsManagerAgent() + + class FakeModelMultiagentsManagedAgent: + model_id = "fake_model" + + def __call__( + self, + messages, + tools_to_call_from=None, + stop_sequences=None, + grammar=None, + ): + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + ChatMessageToolCall( + id="call_0", + type="function", + function=ChatMessageToolCallDefinition( + name="final_answer", + arguments="Report on the current US president", + ), + ) + ], + ) + + managed_model = FakeModelMultiagentsManagedAgent() + + web_agent = ToolCallingAgent( + tools=[], + model=managed_model, + max_steps=10, + name="search_agent", + description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports", + ) + + manager_code_agent = CodeAgent( + tools=[], + model=manager_model, + managed_agents=[web_agent], + additional_authorized_imports=["time", "numpy", "pandas"], + ) + + report = manager_code_agent.run("Fake question.") + assert report == "Final report." + + manager_toolcalling_agent = ToolCallingAgent( + tools=[], + model=manager_model, + managed_agents=[web_agent], + ) + + report = manager_toolcalling_agent.run("Fake question.") + assert report == "Final report." + + # Test that visualization works + manager_code_agent.visualize() + + @pytest.fixture def prompt_templates(): return { diff --git a/tests/test_function_type_hints_utils.py b/tests/test_function_type_hints_utils.py index 9e58985..3379237 100644 --- a/tests/test_function_type_hints_utils.py +++ b/tests/test_function_type_hints_utils.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -from typing import Optional, Tuple +from typing import List, Optional, Tuple -from smolagents._function_type_hints_utils import get_json_schema +import pytest + +from smolagents._function_type_hints_utils import get_imports, get_json_schema -class AgentTextTests(unittest.TestCase): - def test_return_none(self): +class TestJsonSchema(unittest.TestCase): + def test_get_json_schema(self): def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None: """ Test function @@ -52,3 +54,65 @@ class AgentTextTests(unittest.TestCase): schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"] ) self.assertEqual(schema["function"], expected_schema) + + +class TestGetCode: + @pytest.mark.parametrize( + "code, expected", + [ + ( + """ + import numpy + import pandas + """, + ["numpy", "pandas"], + ), + # From imports + ( + """ + from torch import nn + from transformers import AutoModel + """, + ["torch", "transformers"], + ), + # Mixed case with nested imports + ( + """ + import numpy as np + from torch.nn import Linear + import os.path + """, + ["numpy", "torch", "os"], + ), + # Try/except block (should be filtered) + ( + """ + try: + import torch + except ImportError: + pass + import numpy + """, + ["numpy"], + ), + # Flash attention block (should be filtered) + ( + """ + if is_flash_attn_2_available(): + from flash_attn import flash_attn_func + import transformers + """, + ["transformers"], + ), + # Relative imports (should be excluded) + ( + """ + from .utils import helper + from ..models import transformer + """, + [], + ), + ], + ) + def test_get_imports(self, code: str, expected: List[str]): + assert sorted(get_imports(code)) == sorted(expected) diff --git a/tests/test_tools.py b/tests/test_tools.py index 4df4b4d..fcc05d5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -215,8 +215,9 @@ class ToolTests(unittest.TestCase): return str(datetime.now()) - def test_saving_tool_allows_no_arg_in_init(self): - # Test one cannot save tool with additional args in init + def test_tool_to_dict_allows_no_arg_in_init(self): + """Test that a tool cannot be saved with required args in init""" + class FailTool(Tool): name = "specific" description = "test description" @@ -225,15 +226,31 @@ class ToolTests(unittest.TestCase): def __init__(self, url): super().__init__(self) - self.url = "none" + self.url = url def forward(self, string_input: str) -> str: return self.url + string_input fail_tool = FailTool("dummy_url") with pytest.raises(Exception) as e: - fail_tool.save("output") - assert "__init__" in str(e) + fail_tool.to_dict() + assert "All parameters of __init__ must have default values!" in str(e) + + class PassTool(Tool): + name = "specific" + description = "test description" + inputs = {"string_input": {"type": "string", "description": "input description"}} + output_type = "string" + + def __init__(self, url: Optional[str] = "none"): + super().__init__(self) + self.url = url + + def forward(self, string_input: str) -> str: + return self.url + string_input + + fail_tool = PassTool() + fail_tool.to_dict() def test_saving_tool_allows_no_imports_from_outside_methods(self): # Test that using imports from outside functions fails diff --git a/tests/test_utils.py b/tests/test_utils.py index 31a8a68..25f1632 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -146,11 +146,12 @@ def test_e2e_class_tool_save(): test_tool = TestTool() with tempfile.TemporaryDirectory() as tmp_dir: - test_tool.save(tmp_dir) + test_tool.save(tmp_dir, make_gradio_app=True) assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert ( pathlib.Path(tmp_dir, "tool.py").read_text() - == """from smolagents.tools import Tool + == """from typing import Any, Optional +from smolagents.tools import Tool import IPython class TestTool(Tool): @@ -173,7 +174,6 @@ class TestTool(Tool): assert ( pathlib.Path(tmp_dir, "app.py").read_text() == """from smolagents import launch_gradio_demo -from typing import Optional from tool import TestTool tool = TestTool() @@ -201,13 +201,14 @@ def test_e2e_ipython_class_tool_save(): import IPython # noqa: F401 return task - TestTool().save("{tmp_dir}") + TestTool().save("{tmp_dir}", make_gradio_app=True) """) assert shell.run_cell(code_blob, store_history=True).success assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert ( pathlib.Path(tmp_dir, "tool.py").read_text() - == """from smolagents.tools import Tool + == """from typing import Any, Optional +from smolagents.tools import Tool import IPython class TestTool(Tool): @@ -230,7 +231,6 @@ class TestTool(Tool): assert ( pathlib.Path(tmp_dir, "app.py").read_text() == """from smolagents import launch_gradio_demo -from typing import Optional from tool import TestTool tool = TestTool() @@ -254,12 +254,12 @@ def test_e2e_function_tool_save(): return task with tempfile.TemporaryDirectory() as tmp_dir: - test_tool.save(tmp_dir) + test_tool.save(tmp_dir, make_gradio_app=True) assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert ( pathlib.Path(tmp_dir, "tool.py").read_text() == """from smolagents import Tool -from typing import Optional +from typing import Any, Optional class SimpleTool(Tool): name = "test_tool" @@ -283,7 +283,6 @@ class SimpleTool(Tool): assert ( pathlib.Path(tmp_dir, "app.py").read_text() == """from smolagents import launch_gradio_demo -from typing import Optional from tool import SimpleTool tool = SimpleTool() @@ -311,14 +310,14 @@ def test_e2e_ipython_function_tool_save(): return task - test_tool.save("{tmp_dir}") + test_tool.save("{tmp_dir}", make_gradio_app=True) """) assert shell.run_cell(code_blob, store_history=True).success assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert ( pathlib.Path(tmp_dir, "tool.py").read_text() == """from smolagents import Tool -from typing import Optional +from typing import Any, Optional class SimpleTool(Tool): name = "test_tool" @@ -342,7 +341,6 @@ class SimpleTool(Tool): assert ( pathlib.Path(tmp_dir, "app.py").read_text() == """from smolagents import launch_gradio_demo -from typing import Optional from tool import SimpleTool tool = SimpleTool()