Support optional arguments in tool calls
This commit is contained in:
		
							parent
							
								
									93569bd7c1
								
							
						
					
					
						commit
						e5ca0f0cb8
					
				|  | @ -44,7 +44,7 @@ This library offers: | |||
| 
 | ||||
| First install the package. | ||||
| ```bash | ||||
| pip install agents | ||||
| pip install smolagents | ||||
| ``` | ||||
| Then define your agent, give it the tools it needs and run it! | ||||
| ```py | ||||
|  |  | |||
|  | @ -1,9 +1,9 @@ | |||
| # docstyle-ignore | ||||
| INSTALL_CONTENT = """ | ||||
| # Transformers installation | ||||
| ! pip install agents | ||||
| # Installation | ||||
| ! pip install smolagents | ||||
| # To install from source instead of the last release, comment the command above and uncomment the following one. | ||||
| # ! pip install git+https://github.com/huggingface/agents.git | ||||
| # ! pip install git+https://github.com/huggingface/smolagents.git | ||||
| """ | ||||
| 
 | ||||
| notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] | ||||
|  |  | |||
|  | @ -1,19 +1,21 @@ | |||
| from smolagents.agents import ToolCallingAgent | ||||
| from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| # Choose which LLM engine to use! | ||||
| model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") | ||||
| model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") | ||||
| # model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") | ||||
| # model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") | ||||
| model = LiteLLMModel("gpt-4o") | ||||
| 
 | ||||
| @tool | ||||
| def get_weather(location: str) -> str: | ||||
| def get_weather(location: str, celsius: Optional[bool] = False) -> str: | ||||
|     """ | ||||
|     Get weather in the next days at given location. | ||||
|     Secretly this tool does not care about the location, it hates the weather everywhere. | ||||
| 
 | ||||
|     Args: | ||||
|         location: the location | ||||
|         celsius: the temperature | ||||
|     """ | ||||
|     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
| 
 | ||||
|  |  | |||
|  | @ -172,6 +172,7 @@ class GoogleSearchTool(Tool): | |||
|         "filter_year": { | ||||
|             "type": "integer", | ||||
|             "description": "Optionally restrict results to a certain year", | ||||
|             "nullable": True, | ||||
|         }, | ||||
|     } | ||||
|     output_type = "string" | ||||
|  | @ -209,9 +210,14 @@ class GoogleSearchTool(Tool): | |||
|             raise ValueError(response.json()) | ||||
| 
 | ||||
|         if "organic_results" not in results.keys(): | ||||
|             raise Exception( | ||||
|                 f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." | ||||
|             ) | ||||
|             if filter_year is not None: | ||||
|                 raise Exception( | ||||
|                     f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." | ||||
|                 ) | ||||
|             else: | ||||
|                 raise Exception( | ||||
|                     f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." | ||||
|                 ) | ||||
|         if len(results["organic_results"]) == 0: | ||||
|             year_filter_message = ( | ||||
|                 f" with filter year={filter_year}" if filter_year is not None else "" | ||||
|  |  | |||
|  | @ -67,9 +67,12 @@ tool_role_conversions = { | |||
| 
 | ||||
| def get_json_schema(tool: Tool) -> Dict: | ||||
|     properties = deepcopy(tool.inputs) | ||||
|     for value in properties.values(): | ||||
|     required = [] | ||||
|     for key, value in properties.items(): | ||||
|         if value["type"] == "any": | ||||
|             value["type"] = "string" | ||||
|         if not ("nullable" in value and value["nullable"]): | ||||
|             required.append(key) | ||||
|     return { | ||||
|         "type": "function", | ||||
|         "function": { | ||||
|  | @ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict: | |||
|             "parameters": { | ||||
|                 "type": "object", | ||||
|                 "properties": properties, | ||||
|                 "required": list(tool.inputs.keys()), | ||||
|                 "required": required, | ||||
|             }, | ||||
|         }, | ||||
|     } | ||||
|  |  | |||
|  | @ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?" | |||
| Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. | ||||
| Code: | ||||
| ```py | ||||
| population_guangzhou = search("Guangzhou population") | ||||
| print("Population Guangzhou:", population_guangzhou) | ||||
| population_shanghai = search("Shanghai population") | ||||
| print("Population Shanghai:", population_shanghai) | ||||
| for city in ["Guangzhou", "Shanghai"]: | ||||
|     print(f"Population {city}:", search(f"{city} population") | ||||
| ```<end_code> | ||||
| Observation: | ||||
| Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] | ||||
|  | @ -278,11 +276,13 @@ final_answer("Shanghai") | |||
| --- | ||||
| Task: "What is the current age of the pope, raised to the power 0.36?" | ||||
| 
 | ||||
| Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36. | ||||
| Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search. | ||||
| Code: | ||||
| ```py | ||||
| pope_age = wiki(query="current pope age") | ||||
| print("Pope age:", pope_age) | ||||
| pope_age_wiki = wiki(query="current pope age") | ||||
| print("Pope age as per wikipedia:", pope_age_wiki) | ||||
| pope_age_search = web_search(query="current pope age") | ||||
| print("Pope age as per google search:", pope_age_search) | ||||
| ```<end_code> | ||||
| Observation: | ||||
| Pope age: "The pope Francis is currently 85 years old." | ||||
|  |  | |||
|  | @ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor): | |||
|         self.generic_visit(node) | ||||
| 
 | ||||
|     def visit_Attribute(self, node): | ||||
|         # Skip self.something | ||||
|         if not (isinstance(node.value, ast.Name) and node.value.id == "self"): | ||||
|             self.generic_visit(node) | ||||
| 
 | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ import torch | |||
| import textwrap | ||||
| from functools import lru_cache, wraps | ||||
| from pathlib import Path | ||||
| from typing import Callable, Dict, List, Optional, Union | ||||
| from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||
| from huggingface_hub import ( | ||||
|     create_repo, | ||||
|     get_collection, | ||||
|  | @ -42,6 +42,8 @@ from transformers.utils import ( | |||
|     is_accelerate_available, | ||||
|     is_torch_available, | ||||
| ) | ||||
| from transformers.utils.chat_template_utils import _parse_type_hint | ||||
| 
 | ||||
| from transformers.dynamic_module_utils import get_imports | ||||
| from transformers import AutoProcessor | ||||
| 
 | ||||
|  | @ -95,17 +97,27 @@ def setup_default_tools(): | |||
|     return default_tools | ||||
| 
 | ||||
| 
 | ||||
| def validate_after_init(cls, do_validate_forward: bool = True): | ||||
| def validate_after_init(cls): | ||||
|     original_init = cls.__init__ | ||||
| 
 | ||||
|     @wraps(original_init) | ||||
|     def new_init(self, *args, **kwargs): | ||||
|         original_init(self, *args, **kwargs) | ||||
|         self.validate_arguments(do_validate_forward=do_validate_forward) | ||||
|         self.validate_arguments() | ||||
| 
 | ||||
|     cls.__init__ = new_init | ||||
|     return cls | ||||
| 
 | ||||
| def _convert_type_hints_to_json_schema(func: Callable) -> Dict: | ||||
|     type_hints = get_type_hints(func) | ||||
|     signature = inspect.signature(func) | ||||
|     properties = {} | ||||
|     for param_name, param_type in type_hints.items(): | ||||
|         if param_name != "return": | ||||
|             properties[param_name] = _parse_type_hint(param_type) | ||||
|             if signature.parameters[param_name].default != inspect.Parameter.empty: | ||||
|                 properties[param_name]["nullable"] = True | ||||
|     return properties | ||||
| 
 | ||||
| AUTHORIZED_TYPES = [ | ||||
|     "string", | ||||
|  | @ -145,7 +157,7 @@ class Tool: | |||
| 
 | ||||
|     name: str | ||||
|     description: str | ||||
|     inputs: Dict[str, Dict[str, Union[str, type]]] | ||||
|     inputs: Dict[str, Dict[str, Union[str, type, bool]]] | ||||
|     output_type: str | ||||
| 
 | ||||
|     def __init__(self, *args, **kwargs): | ||||
|  | @ -153,9 +165,9 @@ class Tool: | |||
| 
 | ||||
|     def __init_subclass__(cls, **kwargs): | ||||
|         super().__init_subclass__(**kwargs) | ||||
|         validate_after_init(cls, do_validate_forward=False) | ||||
|         validate_after_init(cls) | ||||
| 
 | ||||
|     def validate_arguments(self, do_validate_forward: bool = True): | ||||
|     def validate_arguments(self): | ||||
|         required_attributes = { | ||||
|             "description": str, | ||||
|             "name": str, | ||||
|  | @ -184,12 +196,21 @@ class Tool: | |||
|                 ) | ||||
| 
 | ||||
|         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES | ||||
|         if do_validate_forward: | ||||
|             signature = inspect.signature(self.forward) | ||||
|             if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||
|                 raise Exception( | ||||
|                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||
|                 ) | ||||
| 
 | ||||
|         # Validate forward function signature | ||||
|         signature = inspect.signature(self.forward) | ||||
| 
 | ||||
|         if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||
|             raise Exception( | ||||
|                 "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||
|             ) | ||||
| 
 | ||||
|         json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||
|         for key, value in self.inputs.items(): | ||||
|             if "nullable" in value: | ||||
|                 assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||
|             if key in json_schema and "nullable" in json_schema[key]: | ||||
|                 assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||
| 
 | ||||
|     def forward(self, *args, **kwargs): | ||||
|         return NotImplementedError("Write this method in your subclass of `Tool`.") | ||||
|  | @ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool: | |||
|             "Tool return type not found: make sure your function has a return type hint!" | ||||
|         ) | ||||
| 
 | ||||
|     if parameters["return"]["type"] == "object": | ||||
|         parameters["return"]["type"] = "any" | ||||
| 
 | ||||
|     class SimpleTool(Tool): | ||||
|         def __init__(self, name, description, inputs, output_type, function): | ||||
|             self.name = name | ||||
|  | @ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool: | |||
|     ) | ||||
|     original_signature = inspect.signature(tool_function) | ||||
|     new_parameters = [ | ||||
|         inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) | ||||
|         inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY) | ||||
|     ] + list(original_signature.parameters.values()) | ||||
|     new_signature = original_signature.replace(parameters=new_parameters) | ||||
|     simple_tool.forward.__signature__ = new_signature | ||||
|     # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")]) | ||||
|     return simple_tool | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING | |||
| 
 | ||||
| from smolagents.default_tools import FinalAnswerTool | ||||
| 
 | ||||
| from .test_tools_common import ToolTesterMixin | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ from smolagents.local_python_executor import ( | |||
|     evaluate_python_code, | ||||
| ) | ||||
| 
 | ||||
| from .test_tools_common import ToolTesterMixin | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| # Fake function we will use as tool | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ import unittest | |||
| 
 | ||||
| from smolagents import load_tool | ||||
| 
 | ||||
| from .test_tools_common import ToolTesterMixin | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ | |||
| # limitations under the License. | ||||
| import unittest | ||||
| from pathlib import Path | ||||
| from typing import Dict, Union | ||||
| from typing import Dict, Union, Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| import pytest | ||||
|  | @ -126,9 +126,9 @@ class ToolTests(unittest.TestCase): | |||
|                     "description": "the task category (such as text-classification, depth-estimation, etc)", | ||||
|                 } | ||||
|             } | ||||
|             output_type = "integer" | ||||
|             output_type = "string" | ||||
| 
 | ||||
|             def forward(self, task): | ||||
|             def forward(self, task: str) -> str: | ||||
|                 return "best model" | ||||
| 
 | ||||
|         tool = HFModelDownloadsTool() | ||||
|  | @ -223,7 +223,7 @@ class ToolTests(unittest.TestCase): | |||
|             name = "specific" | ||||
|             description = "test description" | ||||
|             inputs = { | ||||
|                 "input_str": {"type": "string", "description": "input description"} | ||||
|                 "string_input": {"type": "string", "description": "input description"} | ||||
|             } | ||||
|             output_type = "string" | ||||
| 
 | ||||
|  | @ -231,7 +231,7 @@ class ToolTests(unittest.TestCase): | |||
|                 super().__init__(self) | ||||
|                 self.url = "none" | ||||
| 
 | ||||
|             def forward(self, string_input): | ||||
|             def forward(self, string_input: str) -> str: | ||||
|                 return self.url + string_input | ||||
| 
 | ||||
|         fail_tool = FailTool("dummy_url") | ||||
|  | @ -241,46 +241,127 @@ class ToolTests(unittest.TestCase): | |||
| 
 | ||||
|     def test_saving_tool_allows_no_imports_from_outside_methods(self): | ||||
|         # Test that using imports from outside functions fails | ||||
|         from numpy import random | ||||
|         import numpy as np | ||||
| 
 | ||||
|         class FailTool2(Tool): | ||||
|         class FailTool(Tool): | ||||
|             name = "specific" | ||||
|             description = "test description" | ||||
|             inputs = { | ||||
|                 "input_str": {"type": "string", "description": "input description"} | ||||
|                 "string_input": {"type": "string", "description": "input description"} | ||||
|             } | ||||
|             output_type = "string" | ||||
| 
 | ||||
|             def useless_method(self): | ||||
|                 self.client = random.random() | ||||
|                 self.client = np.random.random() | ||||
|                 return "" | ||||
| 
 | ||||
|             def forward(self, string_input): | ||||
|                 return self.useless_method() + string_input | ||||
| 
 | ||||
|         fail_tool_2 = FailTool2() | ||||
|         fail_tool = FailTool() | ||||
|         with pytest.raises(Exception) as e: | ||||
|             fail_tool_2.save("output") | ||||
|         assert "random" in str(e) | ||||
|             fail_tool.save("output") | ||||
|         assert "'np' is undefined" in str(e) | ||||
| 
 | ||||
|         # Test that putting these imports inside functions works | ||||
| 
 | ||||
|         class FailTool3(Tool): | ||||
|         class SuccessTool(Tool): | ||||
|             name = "specific" | ||||
|             description = "test description" | ||||
|             inputs = { | ||||
|                 "input_str": {"type": "string", "description": "input description"} | ||||
|                 "string_input": {"type": "string", "description": "input description"} | ||||
|             } | ||||
|             output_type = "string" | ||||
| 
 | ||||
|             def useless_method(self): | ||||
|                 from numpy import random | ||||
|                 import numpy as np | ||||
| 
 | ||||
|                 self.client = random.random() | ||||
|                 self.client = np.random.random() | ||||
|                 return "" | ||||
| 
 | ||||
|             def forward(self, string_input): | ||||
|                 return self.useless_method() + string_input | ||||
| 
 | ||||
|         fail_tool_3 = FailTool3() | ||||
|         fail_tool_3.save("output") | ||||
|         success_tool = SuccessTool() | ||||
|         success_tool.save("output") | ||||
| 
 | ||||
|     def test_tool_missing_class_attributes_raises_error(self): | ||||
|         with pytest.raises(Exception) as e: | ||||
|             class GetWeatherTool(Tool): | ||||
|                 name = "get_weather" | ||||
|                 description = "Get weather in the next days at given location." | ||||
|                 inputs = { | ||||
|                     "location": {"type": "string", "description": "the location"}, | ||||
|                     "celsius": {"type": "string", "description": "the temperature type"} | ||||
|                 } | ||||
| 
 | ||||
|                 def forward(self, location: str, celsius: Optional[bool] = False) -> str: | ||||
|                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
|                  | ||||
|             tool = GetWeatherTool() | ||||
|         assert "You must set an attribute output_type" in str(e) | ||||
| 
 | ||||
|     def test_tool_from_decorator_optional_args(self): | ||||
|         @tool | ||||
|         def get_weather(location: str, celsius: Optional[bool] = False) -> str: | ||||
|             """ | ||||
|             Get weather in the next days at given location. | ||||
|             Secretly this tool does not care about the location, it hates the weather everywhere. | ||||
| 
 | ||||
|             Args: | ||||
|                 location: the location | ||||
|                 celsius: the temperature type | ||||
|             """ | ||||
|             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
|          | ||||
|         assert "nullable" in get_weather.inputs["celsius"] | ||||
|         assert get_weather.inputs["celsius"]["nullable"] == True | ||||
|         assert "nullable" not in get_weather.inputs["location"] | ||||
| 
 | ||||
|     def test_tool_mismatching_nullable_args_raises_error(self): | ||||
|         with pytest.raises(Exception) as e: | ||||
|             class GetWeatherTool(Tool): | ||||
|                 name = "get_weather" | ||||
|                 description = "Get weather in the next days at given location." | ||||
|                 inputs = { | ||||
|                     "location": {"type": "string", "description": "the location"}, | ||||
|                     "celsius": {"type": "string", "description": "the temperature type"} | ||||
|                 } | ||||
|                 output_type = "string" | ||||
| 
 | ||||
|                 def forward(self, location: str, celsius: Optional[bool] = False) -> str: | ||||
|                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
|                  | ||||
|             tool = GetWeatherTool() | ||||
|         assert "Nullable" in str(e) | ||||
| 
 | ||||
|         with pytest.raises(Exception) as e: | ||||
|             class GetWeatherTool2(Tool): | ||||
|                 name = "get_weather" | ||||
|                 description = "Get weather in the next days at given location." | ||||
|                 inputs = { | ||||
|                     "location": {"type": "string", "description": "the location"}, | ||||
|                     "celsius": {"type": "string", "description": "the temperature type"} | ||||
|                 } | ||||
|                 output_type = "string" | ||||
| 
 | ||||
|                 def forward(self, location: str, celsius: bool = False) -> str: | ||||
|                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
|                  | ||||
|             tool = GetWeatherTool2() | ||||
|         assert "Nullable" in str(e) | ||||
| 
 | ||||
|         with pytest.raises(Exception) as e: | ||||
|             class GetWeatherTool3(Tool): | ||||
|                 name = "get_weather" | ||||
|                 description = "Get weather in the next days at given location." | ||||
|                 inputs = { | ||||
|                     "location": {"type": "string", "description": "the location"}, | ||||
|                     "celsius": {"type": "string", "description": "the temperature type", "nullable": True} | ||||
|                 } | ||||
|                 output_type = "string" | ||||
| 
 | ||||
|                 def forward(self, location, celsius: str) -> str: | ||||
|                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
|                  | ||||
|             tool = GetWeatherTool3() | ||||
|         assert "Nullable" in str(e) | ||||
		Loading…
	
		Reference in New Issue