Support optional arguments in tool calls
This commit is contained in:
		
							parent
							
								
									93569bd7c1
								
							
						
					
					
						commit
						e5ca0f0cb8
					
				|  | @ -44,7 +44,7 @@ This library offers: | ||||||
| 
 | 
 | ||||||
| First install the package. | First install the package. | ||||||
| ```bash | ```bash | ||||||
| pip install agents | pip install smolagents | ||||||
| ``` | ``` | ||||||
| Then define your agent, give it the tools it needs and run it! | Then define your agent, give it the tools it needs and run it! | ||||||
| ```py | ```py | ||||||
|  |  | ||||||
|  | @ -1,9 +1,9 @@ | ||||||
| # docstyle-ignore | # docstyle-ignore | ||||||
| INSTALL_CONTENT = """ | INSTALL_CONTENT = """ | ||||||
| # Transformers installation | # Installation | ||||||
| ! pip install agents | ! pip install smolagents | ||||||
| # To install from source instead of the last release, comment the command above and uncomment the following one. | # To install from source instead of the last release, comment the command above and uncomment the following one. | ||||||
| # ! pip install git+https://github.com/huggingface/agents.git | # ! pip install git+https://github.com/huggingface/smolagents.git | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] | notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] | ||||||
|  |  | ||||||
|  | @ -1,19 +1,21 @@ | ||||||
| from smolagents.agents import ToolCallingAgent | from smolagents.agents import ToolCallingAgent | ||||||
| from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel | from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel | ||||||
|  | from typing import Optional | ||||||
| 
 | 
 | ||||||
| # Choose which LLM engine to use! | # Choose which LLM engine to use! | ||||||
| model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") | # model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") | ||||||
| model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") | # model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") | ||||||
| model = LiteLLMModel("gpt-4o") | model = LiteLLMModel("gpt-4o") | ||||||
| 
 | 
 | ||||||
| @tool | @tool | ||||||
| def get_weather(location: str) -> str: | def get_weather(location: str, celsius: Optional[bool] = False) -> str: | ||||||
|     """ |     """ | ||||||
|     Get weather in the next days at given location. |     Get weather in the next days at given location. | ||||||
|     Secretly this tool does not care about the location, it hates the weather everywhere. |     Secretly this tool does not care about the location, it hates the weather everywhere. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         location: the location |         location: the location | ||||||
|  |         celsius: the temperature | ||||||
|     """ |     """ | ||||||
|     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" |     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -172,6 +172,7 @@ class GoogleSearchTool(Tool): | ||||||
|         "filter_year": { |         "filter_year": { | ||||||
|             "type": "integer", |             "type": "integer", | ||||||
|             "description": "Optionally restrict results to a certain year", |             "description": "Optionally restrict results to a certain year", | ||||||
|  |             "nullable": True, | ||||||
|         }, |         }, | ||||||
|     } |     } | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
|  | @ -209,6 +210,11 @@ class GoogleSearchTool(Tool): | ||||||
|             raise ValueError(response.json()) |             raise ValueError(response.json()) | ||||||
| 
 | 
 | ||||||
|         if "organic_results" not in results.keys(): |         if "organic_results" not in results.keys(): | ||||||
|  |             if filter_year is not None: | ||||||
|  |                 raise Exception( | ||||||
|  |                     f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|                 raise Exception( |                 raise Exception( | ||||||
|                     f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." |                     f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|  | @ -67,9 +67,12 @@ tool_role_conversions = { | ||||||
| 
 | 
 | ||||||
| def get_json_schema(tool: Tool) -> Dict: | def get_json_schema(tool: Tool) -> Dict: | ||||||
|     properties = deepcopy(tool.inputs) |     properties = deepcopy(tool.inputs) | ||||||
|     for value in properties.values(): |     required = [] | ||||||
|  |     for key, value in properties.items(): | ||||||
|         if value["type"] == "any": |         if value["type"] == "any": | ||||||
|             value["type"] = "string" |             value["type"] = "string" | ||||||
|  |         if not ("nullable" in value and value["nullable"]): | ||||||
|  |             required.append(key) | ||||||
|     return { |     return { | ||||||
|         "type": "function", |         "type": "function", | ||||||
|         "function": { |         "function": { | ||||||
|  | @ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict: | ||||||
|             "parameters": { |             "parameters": { | ||||||
|                 "type": "object", |                 "type": "object", | ||||||
|                 "properties": properties, |                 "properties": properties, | ||||||
|                 "required": list(tool.inputs.keys()), |                 "required": required, | ||||||
|             }, |             }, | ||||||
|         }, |         }, | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?" | ||||||
| Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. | Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
| population_guangzhou = search("Guangzhou population") | for city in ["Guangzhou", "Shanghai"]: | ||||||
| print("Population Guangzhou:", population_guangzhou) |     print(f"Population {city}:", search(f"{city} population") | ||||||
| population_shanghai = search("Shanghai population") |  | ||||||
| print("Population Shanghai:", population_shanghai) |  | ||||||
| ```<end_code> | ```<end_code> | ||||||
| Observation: | Observation: | ||||||
| Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] | Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] | ||||||
|  | @ -278,11 +276,13 @@ final_answer("Shanghai") | ||||||
| --- | --- | ||||||
| Task: "What is the current age of the pope, raised to the power 0.36?" | Task: "What is the current age of the pope, raised to the power 0.36?" | ||||||
| 
 | 
 | ||||||
| Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36. | Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search. | ||||||
| Code: | Code: | ||||||
| ```py | ```py | ||||||
| pope_age = wiki(query="current pope age") | pope_age_wiki = wiki(query="current pope age") | ||||||
| print("Pope age:", pope_age) | print("Pope age as per wikipedia:", pope_age_wiki) | ||||||
|  | pope_age_search = web_search(query="current pope age") | ||||||
|  | print("Pope age as per google search:", pope_age_search) | ||||||
| ```<end_code> | ```<end_code> | ||||||
| Observation: | Observation: | ||||||
| Pope age: "The pope Francis is currently 85 years old." | Pope age: "The pope Francis is currently 85 years old." | ||||||
|  |  | ||||||
|  | @ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor): | ||||||
|         self.generic_visit(node) |         self.generic_visit(node) | ||||||
| 
 | 
 | ||||||
|     def visit_Attribute(self, node): |     def visit_Attribute(self, node): | ||||||
|         # Skip self.something |  | ||||||
|         if not (isinstance(node.value, ast.Name) and node.value.id == "self"): |         if not (isinstance(node.value, ast.Name) and node.value.id == "self"): | ||||||
|             self.generic_visit(node) |             self.generic_visit(node) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -24,7 +24,7 @@ import torch | ||||||
| import textwrap | import textwrap | ||||||
| from functools import lru_cache, wraps | from functools import lru_cache, wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable, Dict, List, Optional, Union | from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||||
| from huggingface_hub import ( | from huggingface_hub import ( | ||||||
|     create_repo, |     create_repo, | ||||||
|     get_collection, |     get_collection, | ||||||
|  | @ -42,6 +42,8 @@ from transformers.utils import ( | ||||||
|     is_accelerate_available, |     is_accelerate_available, | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
| ) | ) | ||||||
|  | from transformers.utils.chat_template_utils import _parse_type_hint | ||||||
|  | 
 | ||||||
| from transformers.dynamic_module_utils import get_imports | from transformers.dynamic_module_utils import get_imports | ||||||
| from transformers import AutoProcessor | from transformers import AutoProcessor | ||||||
| 
 | 
 | ||||||
|  | @ -95,17 +97,27 @@ def setup_default_tools(): | ||||||
|     return default_tools |     return default_tools | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def validate_after_init(cls, do_validate_forward: bool = True): | def validate_after_init(cls): | ||||||
|     original_init = cls.__init__ |     original_init = cls.__init__ | ||||||
| 
 | 
 | ||||||
|     @wraps(original_init) |     @wraps(original_init) | ||||||
|     def new_init(self, *args, **kwargs): |     def new_init(self, *args, **kwargs): | ||||||
|         original_init(self, *args, **kwargs) |         original_init(self, *args, **kwargs) | ||||||
|         self.validate_arguments(do_validate_forward=do_validate_forward) |         self.validate_arguments() | ||||||
| 
 | 
 | ||||||
|     cls.__init__ = new_init |     cls.__init__ = new_init | ||||||
|     return cls |     return cls | ||||||
| 
 | 
 | ||||||
|  | def _convert_type_hints_to_json_schema(func: Callable) -> Dict: | ||||||
|  |     type_hints = get_type_hints(func) | ||||||
|  |     signature = inspect.signature(func) | ||||||
|  |     properties = {} | ||||||
|  |     for param_name, param_type in type_hints.items(): | ||||||
|  |         if param_name != "return": | ||||||
|  |             properties[param_name] = _parse_type_hint(param_type) | ||||||
|  |             if signature.parameters[param_name].default != inspect.Parameter.empty: | ||||||
|  |                 properties[param_name]["nullable"] = True | ||||||
|  |     return properties | ||||||
| 
 | 
 | ||||||
| AUTHORIZED_TYPES = [ | AUTHORIZED_TYPES = [ | ||||||
|     "string", |     "string", | ||||||
|  | @ -145,7 +157,7 @@ class Tool: | ||||||
| 
 | 
 | ||||||
|     name: str |     name: str | ||||||
|     description: str |     description: str | ||||||
|     inputs: Dict[str, Dict[str, Union[str, type]]] |     inputs: Dict[str, Dict[str, Union[str, type, bool]]] | ||||||
|     output_type: str |     output_type: str | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|  | @ -153,9 +165,9 @@ class Tool: | ||||||
| 
 | 
 | ||||||
|     def __init_subclass__(cls, **kwargs): |     def __init_subclass__(cls, **kwargs): | ||||||
|         super().__init_subclass__(**kwargs) |         super().__init_subclass__(**kwargs) | ||||||
|         validate_after_init(cls, do_validate_forward=False) |         validate_after_init(cls) | ||||||
| 
 | 
 | ||||||
|     def validate_arguments(self, do_validate_forward: bool = True): |     def validate_arguments(self): | ||||||
|         required_attributes = { |         required_attributes = { | ||||||
|             "description": str, |             "description": str, | ||||||
|             "name": str, |             "name": str, | ||||||
|  | @ -184,13 +196,22 @@ class Tool: | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES |         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES | ||||||
|         if do_validate_forward: | 
 | ||||||
|  |         # Validate forward function signature | ||||||
|         signature = inspect.signature(self.forward) |         signature = inspect.signature(self.forward) | ||||||
|  | 
 | ||||||
|         if not set(signature.parameters.keys()) == set(self.inputs.keys()): |         if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||||
|             raise Exception( |             raise Exception( | ||||||
|                 "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." |                 "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||||
|  |         for key, value in self.inputs.items(): | ||||||
|  |             if "nullable" in value: | ||||||
|  |                 assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||||
|  |             if key in json_schema and "nullable" in json_schema[key]: | ||||||
|  |                 assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||||
|  | 
 | ||||||
|     def forward(self, *args, **kwargs): |     def forward(self, *args, **kwargs): | ||||||
|         return NotImplementedError("Write this method in your subclass of `Tool`.") |         return NotImplementedError("Write this method in your subclass of `Tool`.") | ||||||
| 
 | 
 | ||||||
|  | @ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|             "Tool return type not found: make sure your function has a return type hint!" |             "Tool return type not found: make sure your function has a return type hint!" | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     if parameters["return"]["type"] == "object": |  | ||||||
|         parameters["return"]["type"] = "any" |  | ||||||
| 
 |  | ||||||
|     class SimpleTool(Tool): |     class SimpleTool(Tool): | ||||||
|         def __init__(self, name, description, inputs, output_type, function): |         def __init__(self, name, description, inputs, output_type, function): | ||||||
|             self.name = name |             self.name = name | ||||||
|  | @ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|     ) |     ) | ||||||
|     original_signature = inspect.signature(tool_function) |     original_signature = inspect.signature(tool_function) | ||||||
|     new_parameters = [ |     new_parameters = [ | ||||||
|         inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) |         inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY) | ||||||
|     ] + list(original_signature.parameters.values()) |     ] + list(original_signature.parameters.values()) | ||||||
|     new_signature = original_signature.replace(parameters=new_parameters) |     new_signature = original_signature.replace(parameters=new_parameters) | ||||||
|     simple_tool.forward.__signature__ = new_signature |     simple_tool.forward.__signature__ = new_signature | ||||||
|     # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")]) |  | ||||||
|     return simple_tool |     return simple_tool | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING | ||||||
| 
 | 
 | ||||||
| from smolagents.default_tools import FinalAnswerTool | from smolagents.default_tools import FinalAnswerTool | ||||||
| 
 | 
 | ||||||
| from .test_tools_common import ToolTesterMixin | from .test_tools import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ from smolagents.local_python_executor import ( | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from .test_tools_common import ToolTesterMixin | from .test_tools import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Fake function we will use as tool | # Fake function we will use as tool | ||||||
|  |  | ||||||
|  | @ -17,7 +17,7 @@ import unittest | ||||||
| 
 | 
 | ||||||
| from smolagents import load_tool | from smolagents import load_tool | ||||||
| 
 | 
 | ||||||
| from .test_tools_common import ToolTesterMixin | from .test_tools import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): | class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Dict, Union | from typing import Dict, Union, Optional | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| import pytest | import pytest | ||||||
|  | @ -126,9 +126,9 @@ class ToolTests(unittest.TestCase): | ||||||
|                     "description": "the task category (such as text-classification, depth-estimation, etc)", |                     "description": "the task category (such as text-classification, depth-estimation, etc)", | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             output_type = "integer" |             output_type = "string" | ||||||
| 
 | 
 | ||||||
|             def forward(self, task): |             def forward(self, task: str) -> str: | ||||||
|                 return "best model" |                 return "best model" | ||||||
| 
 | 
 | ||||||
|         tool = HFModelDownloadsTool() |         tool = HFModelDownloadsTool() | ||||||
|  | @ -223,7 +223,7 @@ class ToolTests(unittest.TestCase): | ||||||
|             name = "specific" |             name = "specific" | ||||||
|             description = "test description" |             description = "test description" | ||||||
|             inputs = { |             inputs = { | ||||||
|                 "input_str": {"type": "string", "description": "input description"} |                 "string_input": {"type": "string", "description": "input description"} | ||||||
|             } |             } | ||||||
|             output_type = "string" |             output_type = "string" | ||||||
| 
 | 
 | ||||||
|  | @ -231,7 +231,7 @@ class ToolTests(unittest.TestCase): | ||||||
|                 super().__init__(self) |                 super().__init__(self) | ||||||
|                 self.url = "none" |                 self.url = "none" | ||||||
| 
 | 
 | ||||||
|             def forward(self, string_input): |             def forward(self, string_input: str) -> str: | ||||||
|                 return self.url + string_input |                 return self.url + string_input | ||||||
| 
 | 
 | ||||||
|         fail_tool = FailTool("dummy_url") |         fail_tool = FailTool("dummy_url") | ||||||
|  | @ -241,46 +241,127 @@ class ToolTests(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|     def test_saving_tool_allows_no_imports_from_outside_methods(self): |     def test_saving_tool_allows_no_imports_from_outside_methods(self): | ||||||
|         # Test that using imports from outside functions fails |         # Test that using imports from outside functions fails | ||||||
|         from numpy import random |         import numpy as np | ||||||
| 
 | 
 | ||||||
|         class FailTool2(Tool): |         class FailTool(Tool): | ||||||
|             name = "specific" |             name = "specific" | ||||||
|             description = "test description" |             description = "test description" | ||||||
|             inputs = { |             inputs = { | ||||||
|                 "input_str": {"type": "string", "description": "input description"} |                 "string_input": {"type": "string", "description": "input description"} | ||||||
|             } |             } | ||||||
|             output_type = "string" |             output_type = "string" | ||||||
| 
 | 
 | ||||||
|             def useless_method(self): |             def useless_method(self): | ||||||
|                 self.client = random.random() |                 self.client = np.random.random() | ||||||
|                 return "" |                 return "" | ||||||
| 
 | 
 | ||||||
|             def forward(self, string_input): |             def forward(self, string_input): | ||||||
|                 return self.useless_method() + string_input |                 return self.useless_method() + string_input | ||||||
| 
 | 
 | ||||||
|         fail_tool_2 = FailTool2() |         fail_tool = FailTool() | ||||||
|         with pytest.raises(Exception) as e: |         with pytest.raises(Exception) as e: | ||||||
|             fail_tool_2.save("output") |             fail_tool.save("output") | ||||||
|         assert "random" in str(e) |         assert "'np' is undefined" in str(e) | ||||||
| 
 | 
 | ||||||
|         # Test that putting these imports inside functions works |         # Test that putting these imports inside functions works | ||||||
| 
 |         class SuccessTool(Tool): | ||||||
|         class FailTool3(Tool): |  | ||||||
|             name = "specific" |             name = "specific" | ||||||
|             description = "test description" |             description = "test description" | ||||||
|             inputs = { |             inputs = { | ||||||
|                 "input_str": {"type": "string", "description": "input description"} |                 "string_input": {"type": "string", "description": "input description"} | ||||||
|             } |             } | ||||||
|             output_type = "string" |             output_type = "string" | ||||||
| 
 | 
 | ||||||
|             def useless_method(self): |             def useless_method(self): | ||||||
|                 from numpy import random |                 import numpy as np | ||||||
| 
 | 
 | ||||||
|                 self.client = random.random() |                 self.client = np.random.random() | ||||||
|                 return "" |                 return "" | ||||||
| 
 | 
 | ||||||
|             def forward(self, string_input): |             def forward(self, string_input): | ||||||
|                 return self.useless_method() + string_input |                 return self.useless_method() + string_input | ||||||
| 
 | 
 | ||||||
|         fail_tool_3 = FailTool3() |         success_tool = SuccessTool() | ||||||
|         fail_tool_3.save("output") |         success_tool.save("output") | ||||||
|  | 
 | ||||||
|  |     def test_tool_missing_class_attributes_raises_error(self): | ||||||
|  |         with pytest.raises(Exception) as e: | ||||||
|  |             class GetWeatherTool(Tool): | ||||||
|  |                 name = "get_weather" | ||||||
|  |                 description = "Get weather in the next days at given location." | ||||||
|  |                 inputs = { | ||||||
|  |                     "location": {"type": "string", "description": "the location"}, | ||||||
|  |                     "celsius": {"type": "string", "description": "the temperature type"} | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 def forward(self, location: str, celsius: Optional[bool] = False) -> str: | ||||||
|  |                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  |                  | ||||||
|  |             tool = GetWeatherTool() | ||||||
|  |         assert "You must set an attribute output_type" in str(e) | ||||||
|  | 
 | ||||||
|  |     def test_tool_from_decorator_optional_args(self): | ||||||
|  |         @tool | ||||||
|  |         def get_weather(location: str, celsius: Optional[bool] = False) -> str: | ||||||
|  |             """ | ||||||
|  |             Get weather in the next days at given location. | ||||||
|  |             Secretly this tool does not care about the location, it hates the weather everywhere. | ||||||
|  | 
 | ||||||
|  |             Args: | ||||||
|  |                 location: the location | ||||||
|  |                 celsius: the temperature type | ||||||
|  |             """ | ||||||
|  |             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  |          | ||||||
|  |         assert "nullable" in get_weather.inputs["celsius"] | ||||||
|  |         assert get_weather.inputs["celsius"]["nullable"] == True | ||||||
|  |         assert "nullable" not in get_weather.inputs["location"] | ||||||
|  | 
 | ||||||
|  |     def test_tool_mismatching_nullable_args_raises_error(self): | ||||||
|  |         with pytest.raises(Exception) as e: | ||||||
|  |             class GetWeatherTool(Tool): | ||||||
|  |                 name = "get_weather" | ||||||
|  |                 description = "Get weather in the next days at given location." | ||||||
|  |                 inputs = { | ||||||
|  |                     "location": {"type": "string", "description": "the location"}, | ||||||
|  |                     "celsius": {"type": "string", "description": "the temperature type"} | ||||||
|  |                 } | ||||||
|  |                 output_type = "string" | ||||||
|  | 
 | ||||||
|  |                 def forward(self, location: str, celsius: Optional[bool] = False) -> str: | ||||||
|  |                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  |                  | ||||||
|  |             tool = GetWeatherTool() | ||||||
|  |         assert "Nullable" in str(e) | ||||||
|  | 
 | ||||||
|  |         with pytest.raises(Exception) as e: | ||||||
|  |             class GetWeatherTool2(Tool): | ||||||
|  |                 name = "get_weather" | ||||||
|  |                 description = "Get weather in the next days at given location." | ||||||
|  |                 inputs = { | ||||||
|  |                     "location": {"type": "string", "description": "the location"}, | ||||||
|  |                     "celsius": {"type": "string", "description": "the temperature type"} | ||||||
|  |                 } | ||||||
|  |                 output_type = "string" | ||||||
|  | 
 | ||||||
|  |                 def forward(self, location: str, celsius: bool = False) -> str: | ||||||
|  |                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  |                  | ||||||
|  |             tool = GetWeatherTool2() | ||||||
|  |         assert "Nullable" in str(e) | ||||||
|  | 
 | ||||||
|  |         with pytest.raises(Exception) as e: | ||||||
|  |             class GetWeatherTool3(Tool): | ||||||
|  |                 name = "get_weather" | ||||||
|  |                 description = "Get weather in the next days at given location." | ||||||
|  |                 inputs = { | ||||||
|  |                     "location": {"type": "string", "description": "the location"}, | ||||||
|  |                     "celsius": {"type": "string", "description": "the temperature type", "nullable": True} | ||||||
|  |                 } | ||||||
|  |                 output_type = "string" | ||||||
|  | 
 | ||||||
|  |                 def forward(self, location, celsius: str) -> str: | ||||||
|  |                     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  |                  | ||||||
|  |             tool = GetWeatherTool3() | ||||||
|  |         assert "Nullable" in str(e) | ||||||
		Loading…
	
		Reference in New Issue