Support any and none tool types (#280)
* Support any and none tool types
This commit is contained in:
		
							parent
							
								
									83ecd572fc
								
							
						
					
					
						commit
						43904f32c7
					
				|  | @ -87,3 +87,8 @@ jobs: | |||
|         run: | | ||||
|           uv run pytest ./tests/test_utils.py | ||||
|         if: ${{ success() || failure() }} | ||||
| 
 | ||||
|       - name: Function type hints utils tests | ||||
|         run: | | ||||
|           uv run pytest ./tests/test_function_type_hints_utils.py | ||||
|         if: ${{ success() || failure() }} | ||||
|  |  | |||
|  | @ -276,20 +276,26 @@ def _parse_google_format_docstring( | |||
|     return description, args_dict, returns | ||||
| 
 | ||||
| 
 | ||||
| def _convert_type_hints_to_json_schema(func: Callable) -> Dict: | ||||
| def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict: | ||||
|     type_hints = get_type_hints(func) | ||||
|     signature = inspect.signature(func) | ||||
|     required = [] | ||||
|     for param_name, param in signature.parameters.items(): | ||||
|         if param.annotation == inspect.Parameter.empty: | ||||
|             raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") | ||||
|         if param.default == inspect.Parameter.empty: | ||||
|             required.append(param_name) | ||||
| 
 | ||||
|     properties = {} | ||||
|     for param_name, param_type in type_hints.items(): | ||||
|         properties[param_name] = _parse_type_hint(param_type) | ||||
| 
 | ||||
|     required = [] | ||||
|     for param_name, param in signature.parameters.items(): | ||||
|         if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints: | ||||
|             raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") | ||||
|         if param_name not in properties: | ||||
|             properties[param_name] = {} | ||||
| 
 | ||||
|         if param.default == inspect.Parameter.empty: | ||||
|             required.append(param_name) | ||||
|         else: | ||||
|             properties[param_name]["nullable"] = True | ||||
| 
 | ||||
|     schema = {"type": "object", "properties": properties} | ||||
|     if required: | ||||
|         schema["required"] = required | ||||
|  | @ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = { | |||
|     float: {"type": "number"}, | ||||
|     str: {"type": "string"}, | ||||
|     bool: {"type": "boolean"}, | ||||
|     Any: {}, | ||||
|     Any: {"type": "any"}, | ||||
|     types.NoneType: {"type": "null"}, | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -16,7 +16,7 @@ | |||
| # limitations under the License. | ||||
| import re | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, Optional | ||||
| from typing import Any, Dict, Optional | ||||
| 
 | ||||
| from .local_python_executor import ( | ||||
|     BASE_BUILTIN_MODULES, | ||||
|  | @ -85,7 +85,7 @@ class FinalAnswerTool(Tool): | |||
|     inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} | ||||
|     output_type = "any" | ||||
| 
 | ||||
|     def forward(self, answer): | ||||
|     def forward(self, answer: Any) -> Any: | ||||
|         return answer | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -136,7 +136,7 @@ tool_role_conversions = { | |||
| } | ||||
| 
 | ||||
| 
 | ||||
| def get_json_schema(tool: Tool) -> Dict: | ||||
| def get_tool_json_schema(tool: Tool) -> Dict: | ||||
|     properties = deepcopy(tool.inputs) | ||||
|     required = [] | ||||
|     for key, value in properties.items(): | ||||
|  | @ -240,7 +240,7 @@ class Model: | |||
|         if tools_to_call_from: | ||||
|             completion_kwargs.update( | ||||
|                 { | ||||
|                     "tools": [get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                     "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from], | ||||
|                     "tool_choice": "required", | ||||
|                 } | ||||
|             ) | ||||
|  | @ -490,7 +490,7 @@ class TransformersModel(Model): | |||
|         if tools_to_call_from is not None: | ||||
|             prompt_tensor = self.tokenizer.apply_chat_template( | ||||
|                 messages, | ||||
|                 tools=completion_kwargs.pop("tools", []), | ||||
|                 tools=[get_tool_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 return_tensors="pt", | ||||
|                 return_dict=True, | ||||
|                 add_generation_prompt=True, | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ import textwrap | |||
| from contextlib import contextmanager | ||||
| from functools import lru_cache, wraps | ||||
| from pathlib import Path | ||||
| from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||
| from typing import Callable, Dict, List, Optional, Union | ||||
| 
 | ||||
| from huggingface_hub import ( | ||||
|     create_repo, | ||||
|  | @ -38,9 +38,9 @@ from huggingface_hub import ( | |||
| from huggingface_hub.utils import is_torch_available | ||||
| from packaging import version | ||||
| 
 | ||||
| from ._transformers_utils import ( | ||||
| from ._function_type_hints_utils import ( | ||||
|     TypeHintParsingException, | ||||
|     _parse_type_hint, | ||||
|     _convert_type_hints_to_json_schema, | ||||
|     get_imports, | ||||
|     get_json_schema, | ||||
| ) | ||||
|  | @ -64,22 +64,6 @@ def validate_after_init(cls): | |||
|     return cls | ||||
| 
 | ||||
| 
 | ||||
| def _convert_type_hints_to_json_schema(func: Callable) -> Dict: | ||||
|     type_hints = get_type_hints(func) | ||||
|     signature = inspect.signature(func) | ||||
|     properties = {} | ||||
|     for param_name, param_type in type_hints.items(): | ||||
|         if param_name != "return": | ||||
|             properties[param_name] = _parse_type_hint(param_type) | ||||
|             if signature.parameters[param_name].default != inspect.Parameter.empty: | ||||
|                 properties[param_name]["nullable"] = True | ||||
|     for param_name in signature.parameters.keys(): | ||||
|         if signature.parameters[param_name].default != inspect.Parameter.empty: | ||||
|             if param_name not in properties:  # this can happen if the param has no type hint but a default value | ||||
|                 properties[param_name] = {"nullable": True} | ||||
|     return properties | ||||
| 
 | ||||
| 
 | ||||
| AUTHORIZED_TYPES = [ | ||||
|     "string", | ||||
|     "boolean", | ||||
|  | @ -87,8 +71,10 @@ AUTHORIZED_TYPES = [ | |||
|     "number", | ||||
|     "image", | ||||
|     "audio", | ||||
|     "any", | ||||
|     "array", | ||||
|     "object", | ||||
|     "any", | ||||
|     "null", | ||||
| ] | ||||
| 
 | ||||
| CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} | ||||
|  | @ -168,12 +154,15 @@ class Tool: | |||
|                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||
|                 ) | ||||
| 
 | ||||
|             json_schema = _convert_type_hints_to_json_schema( | ||||
|                 self.forward | ||||
|             )  # This function will raise an error on missing docstrings, contrary to get_json_schema | ||||
|             json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[ | ||||
|                 "properties" | ||||
|             ]  # This function will not raise an error on missing docstrings, contrary to get_json_schema | ||||
|             for key, value in self.inputs.items(): | ||||
|                 assert key in json_schema, ( | ||||
|                     f"Input '{key}' should be present in function signature, found only {json_schema.keys()}" | ||||
|                 ) | ||||
|                 if "nullable" in value: | ||||
|                     assert key in json_schema and "nullable" in json_schema[key], ( | ||||
|                     assert "nullable" in json_schema[key], ( | ||||
|                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||
|                     ) | ||||
|                 if key in json_schema and "nullable" in json_schema[key]: | ||||
|  | @ -887,16 +876,6 @@ class ToolCollection: | |||
|             yield cls(tools) | ||||
| 
 | ||||
| 
 | ||||
| def get_tool_json_schema(tool_function): | ||||
|     tool_json_schema = get_json_schema(tool_function)["function"] | ||||
|     tool_parameters = tool_json_schema["parameters"] | ||||
|     inputs_schema = tool_parameters["properties"] | ||||
|     for input_name in inputs_schema: | ||||
|         if "required" not in tool_parameters or input_name not in tool_parameters["required"]: | ||||
|             inputs_schema[input_name]["nullable"] = True | ||||
|     return tool_json_schema | ||||
| 
 | ||||
| 
 | ||||
| def tool(tool_function: Callable) -> Tool: | ||||
|     """ | ||||
|     Converts a function into an instance of a Tool subclass. | ||||
|  | @ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool: | |||
|         tool_function: Your function. Should have type hints for each input and a type hint for the output. | ||||
|         Should also have a docstring description including an 'Args:' part where each argument is described. | ||||
|     """ | ||||
|     tool_json_schema = get_tool_json_schema(tool_function) | ||||
|     tool_json_schema = get_json_schema(tool_function)["function"] | ||||
|     if "return" not in tool_json_schema: | ||||
|         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,54 @@ | |||
| # coding=utf-8 | ||||
| # Copyright 2024 HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import unittest | ||||
| from typing import Optional, Tuple | ||||
| 
 | ||||
| from smolagents._function_type_hints_utils import get_json_schema | ||||
| 
 | ||||
| 
 | ||||
| class AgentTextTests(unittest.TestCase): | ||||
|     def test_return_none(self): | ||||
|         def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None: | ||||
|             """ | ||||
|             Test function | ||||
|             Args: | ||||
|                 x: The first input | ||||
|                 y: The second input | ||||
|             """ | ||||
|             pass | ||||
| 
 | ||||
|         schema = get_json_schema(fn) | ||||
|         expected_schema = { | ||||
|             "name": "fn", | ||||
|             "description": "Test function", | ||||
|             "parameters": { | ||||
|                 "type": "object", | ||||
|                 "properties": { | ||||
|                     "x": {"type": "integer", "description": "The first input"}, | ||||
|                     "y": { | ||||
|                         "type": "array", | ||||
|                         "description": "The second input", | ||||
|                         "nullable": True, | ||||
|                         "prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}], | ||||
|                     }, | ||||
|                 }, | ||||
|                 "required": ["x"], | ||||
|             }, | ||||
|             "return": {"type": "null"}, | ||||
|         } | ||||
|         self.assertEqual( | ||||
|             schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"] | ||||
|         ) | ||||
|         self.assertEqual(schema["function"], expected_schema) | ||||
|  | @ -34,7 +34,9 @@ class ModelTests(unittest.TestCase): | |||
|             """ | ||||
|             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
| 
 | ||||
|         assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] | ||||
|         assert ( | ||||
|             "nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] | ||||
|         ) | ||||
| 
 | ||||
|     def test_chatmessage_has_model_dumps_json(self): | ||||
|         message = ChatMessage("user", "Hello!") | ||||
|  |  | |||
|  | @ -15,7 +15,7 @@ | |||
| import unittest | ||||
| from pathlib import Path | ||||
| from textwrap import dedent | ||||
| from typing import Dict, Optional, Union | ||||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||||
| from unittest.mock import MagicMock, patch | ||||
| 
 | ||||
| import mcp | ||||
|  | @ -381,13 +381,41 @@ class ToolTests(unittest.TestCase): | |||
|             Get weather in the next days at given location. | ||||
| 
 | ||||
|             Args: | ||||
|                 location: the location | ||||
|                 celsius: is the temperature given in celsius | ||||
|                 location: The location to get the weather for. | ||||
|                 celsius: is the temperature given in celsius? | ||||
|             """ | ||||
|             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||
| 
 | ||||
|         assert get_weather.inputs["celsius"]["nullable"] | ||||
| 
 | ||||
|     def test_tool_supports_any_none(self): | ||||
|         @tool | ||||
|         def get_weather(location: Any) -> None: | ||||
|             """ | ||||
|             Get weather in the next days at given location. | ||||
| 
 | ||||
|             Args: | ||||
|                 location: The location to get the weather for. | ||||
|             """ | ||||
|             return | ||||
| 
 | ||||
|         assert get_weather.inputs["location"]["type"] == "any" | ||||
| 
 | ||||
|     def test_tool_supports_array(self): | ||||
|         @tool | ||||
|         def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]: | ||||
|             """ | ||||
|             Get weather in the next days at given locations. | ||||
| 
 | ||||
|             Args: | ||||
|                 locations: The locations to get the weather for. | ||||
|                 months: The months to get the weather for | ||||
|             """ | ||||
|             return | ||||
| 
 | ||||
|         assert get_weather.inputs["locations"]["type"] == "array" | ||||
|         assert get_weather.inputs["months"]["type"] == "array" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def mock_server_parameters(): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue