Support any and none tool types (#280)
* Support any and none tool types
This commit is contained in:
		
							parent
							
								
									83ecd572fc
								
							
						
					
					
						commit
						43904f32c7
					
				|  | @ -87,3 +87,8 @@ jobs: | ||||||
|         run: | |         run: | | ||||||
|           uv run pytest ./tests/test_utils.py |           uv run pytest ./tests/test_utils.py | ||||||
|         if: ${{ success() || failure() }} |         if: ${{ success() || failure() }} | ||||||
|  | 
 | ||||||
|  |       - name: Function type hints utils tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest ./tests/test_function_type_hints_utils.py | ||||||
|  |         if: ${{ success() || failure() }} | ||||||
|  |  | ||||||
|  | @ -276,20 +276,26 @@ def _parse_google_format_docstring( | ||||||
|     return description, args_dict, returns |     return description, args_dict, returns | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _convert_type_hints_to_json_schema(func: Callable) -> Dict: | def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict: | ||||||
|     type_hints = get_type_hints(func) |     type_hints = get_type_hints(func) | ||||||
|     signature = inspect.signature(func) |     signature = inspect.signature(func) | ||||||
|     required = [] |  | ||||||
|     for param_name, param in signature.parameters.items(): |  | ||||||
|         if param.annotation == inspect.Parameter.empty: |  | ||||||
|             raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") |  | ||||||
|         if param.default == inspect.Parameter.empty: |  | ||||||
|             required.append(param_name) |  | ||||||
| 
 | 
 | ||||||
|     properties = {} |     properties = {} | ||||||
|     for param_name, param_type in type_hints.items(): |     for param_name, param_type in type_hints.items(): | ||||||
|         properties[param_name] = _parse_type_hint(param_type) |         properties[param_name] = _parse_type_hint(param_type) | ||||||
| 
 | 
 | ||||||
|  |     required = [] | ||||||
|  |     for param_name, param in signature.parameters.items(): | ||||||
|  |         if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints: | ||||||
|  |             raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") | ||||||
|  |         if param_name not in properties: | ||||||
|  |             properties[param_name] = {} | ||||||
|  | 
 | ||||||
|  |         if param.default == inspect.Parameter.empty: | ||||||
|  |             required.append(param_name) | ||||||
|  |         else: | ||||||
|  |             properties[param_name]["nullable"] = True | ||||||
|  | 
 | ||||||
|     schema = {"type": "object", "properties": properties} |     schema = {"type": "object", "properties": properties} | ||||||
|     if required: |     if required: | ||||||
|         schema["required"] = required |         schema["required"] = required | ||||||
|  | @ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = { | ||||||
|     float: {"type": "number"}, |     float: {"type": "number"}, | ||||||
|     str: {"type": "string"}, |     str: {"type": "string"}, | ||||||
|     bool: {"type": "boolean"}, |     bool: {"type": "boolean"}, | ||||||
|     Any: {}, |     Any: {"type": "any"}, | ||||||
|  |     types.NoneType: {"type": "null"}, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -16,7 +16,7 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import re | import re | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Dict, Optional | from typing import Any, Dict, Optional | ||||||
| 
 | 
 | ||||||
| from .local_python_executor import ( | from .local_python_executor import ( | ||||||
|     BASE_BUILTIN_MODULES, |     BASE_BUILTIN_MODULES, | ||||||
|  | @ -85,7 +85,7 @@ class FinalAnswerTool(Tool): | ||||||
|     inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} |     inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} | ||||||
|     output_type = "any" |     output_type = "any" | ||||||
| 
 | 
 | ||||||
|     def forward(self, answer): |     def forward(self, answer: Any) -> Any: | ||||||
|         return answer |         return answer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -136,7 +136,7 @@ tool_role_conversions = { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_json_schema(tool: Tool) -> Dict: | def get_tool_json_schema(tool: Tool) -> Dict: | ||||||
|     properties = deepcopy(tool.inputs) |     properties = deepcopy(tool.inputs) | ||||||
|     required = [] |     required = [] | ||||||
|     for key, value in properties.items(): |     for key, value in properties.items(): | ||||||
|  | @ -240,7 +240,7 @@ class Model: | ||||||
|         if tools_to_call_from: |         if tools_to_call_from: | ||||||
|             completion_kwargs.update( |             completion_kwargs.update( | ||||||
|                 { |                 { | ||||||
|                     "tools": [get_json_schema(tool) for tool in tools_to_call_from], |                     "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from], | ||||||
|                     "tool_choice": "required", |                     "tool_choice": "required", | ||||||
|                 } |                 } | ||||||
|             ) |             ) | ||||||
|  | @ -490,7 +490,7 @@ class TransformersModel(Model): | ||||||
|         if tools_to_call_from is not None: |         if tools_to_call_from is not None: | ||||||
|             prompt_tensor = self.tokenizer.apply_chat_template( |             prompt_tensor = self.tokenizer.apply_chat_template( | ||||||
|                 messages, |                 messages, | ||||||
|                 tools=completion_kwargs.pop("tools", []), |                 tools=[get_tool_json_schema(tool) for tool in tools_to_call_from], | ||||||
|                 return_tensors="pt", |                 return_tensors="pt", | ||||||
|                 return_dict=True, |                 return_dict=True, | ||||||
|                 add_generation_prompt=True, |                 add_generation_prompt=True, | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ import textwrap | ||||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||||
| from functools import lru_cache, wraps | from functools import lru_cache, wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable, Dict, List, Optional, Union, get_type_hints | from typing import Callable, Dict, List, Optional, Union | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import ( | from huggingface_hub import ( | ||||||
|     create_repo, |     create_repo, | ||||||
|  | @ -38,9 +38,9 @@ from huggingface_hub import ( | ||||||
| from huggingface_hub.utils import is_torch_available | from huggingface_hub.utils import is_torch_available | ||||||
| from packaging import version | from packaging import version | ||||||
| 
 | 
 | ||||||
| from ._transformers_utils import ( | from ._function_type_hints_utils import ( | ||||||
|     TypeHintParsingException, |     TypeHintParsingException, | ||||||
|     _parse_type_hint, |     _convert_type_hints_to_json_schema, | ||||||
|     get_imports, |     get_imports, | ||||||
|     get_json_schema, |     get_json_schema, | ||||||
| ) | ) | ||||||
|  | @ -64,22 +64,6 @@ def validate_after_init(cls): | ||||||
|     return cls |     return cls | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _convert_type_hints_to_json_schema(func: Callable) -> Dict: |  | ||||||
|     type_hints = get_type_hints(func) |  | ||||||
|     signature = inspect.signature(func) |  | ||||||
|     properties = {} |  | ||||||
|     for param_name, param_type in type_hints.items(): |  | ||||||
|         if param_name != "return": |  | ||||||
|             properties[param_name] = _parse_type_hint(param_type) |  | ||||||
|             if signature.parameters[param_name].default != inspect.Parameter.empty: |  | ||||||
|                 properties[param_name]["nullable"] = True |  | ||||||
|     for param_name in signature.parameters.keys(): |  | ||||||
|         if signature.parameters[param_name].default != inspect.Parameter.empty: |  | ||||||
|             if param_name not in properties:  # this can happen if the param has no type hint but a default value |  | ||||||
|                 properties[param_name] = {"nullable": True} |  | ||||||
|     return properties |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| AUTHORIZED_TYPES = [ | AUTHORIZED_TYPES = [ | ||||||
|     "string", |     "string", | ||||||
|     "boolean", |     "boolean", | ||||||
|  | @ -87,8 +71,10 @@ AUTHORIZED_TYPES = [ | ||||||
|     "number", |     "number", | ||||||
|     "image", |     "image", | ||||||
|     "audio", |     "audio", | ||||||
|     "any", |     "array", | ||||||
|     "object", |     "object", | ||||||
|  |     "any", | ||||||
|  |     "null", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} | CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} | ||||||
|  | @ -168,12 +154,15 @@ class Tool: | ||||||
|                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." |                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             json_schema = _convert_type_hints_to_json_schema( |             json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[ | ||||||
|                 self.forward |                 "properties" | ||||||
|             )  # This function will raise an error on missing docstrings, contrary to get_json_schema |             ]  # This function will not raise an error on missing docstrings, contrary to get_json_schema | ||||||
|             for key, value in self.inputs.items(): |             for key, value in self.inputs.items(): | ||||||
|  |                 assert key in json_schema, ( | ||||||
|  |                     f"Input '{key}' should be present in function signature, found only {json_schema.keys()}" | ||||||
|  |                 ) | ||||||
|                 if "nullable" in value: |                 if "nullable" in value: | ||||||
|                     assert key in json_schema and "nullable" in json_schema[key], ( |                     assert "nullable" in json_schema[key], ( | ||||||
|                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." |                         f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||||
|                     ) |                     ) | ||||||
|                 if key in json_schema and "nullable" in json_schema[key]: |                 if key in json_schema and "nullable" in json_schema[key]: | ||||||
|  | @ -887,16 +876,6 @@ class ToolCollection: | ||||||
|             yield cls(tools) |             yield cls(tools) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_tool_json_schema(tool_function): |  | ||||||
|     tool_json_schema = get_json_schema(tool_function)["function"] |  | ||||||
|     tool_parameters = tool_json_schema["parameters"] |  | ||||||
|     inputs_schema = tool_parameters["properties"] |  | ||||||
|     for input_name in inputs_schema: |  | ||||||
|         if "required" not in tool_parameters or input_name not in tool_parameters["required"]: |  | ||||||
|             inputs_schema[input_name]["nullable"] = True |  | ||||||
|     return tool_json_schema |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def tool(tool_function: Callable) -> Tool: | def tool(tool_function: Callable) -> Tool: | ||||||
|     """ |     """ | ||||||
|     Converts a function into an instance of a Tool subclass. |     Converts a function into an instance of a Tool subclass. | ||||||
|  | @ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|         tool_function: Your function. Should have type hints for each input and a type hint for the output. |         tool_function: Your function. Should have type hints for each input and a type hint for the output. | ||||||
|         Should also have a docstring description including an 'Args:' part where each argument is described. |         Should also have a docstring description including an 'Args:' part where each argument is described. | ||||||
|     """ |     """ | ||||||
|     tool_json_schema = get_tool_json_schema(tool_function) |     tool_json_schema = get_json_schema(tool_function)["function"] | ||||||
|     if "return" not in tool_json_schema: |     if "return" not in tool_json_schema: | ||||||
|         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") |         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,54 @@ | ||||||
|  | # coding=utf-8 | ||||||
|  | # Copyright 2024 HuggingFace Inc. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | import unittest | ||||||
|  | from typing import Optional, Tuple | ||||||
|  | 
 | ||||||
|  | from smolagents._function_type_hints_utils import get_json_schema | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AgentTextTests(unittest.TestCase): | ||||||
|  |     def test_return_none(self): | ||||||
|  |         def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None: | ||||||
|  |             """ | ||||||
|  |             Test function | ||||||
|  |             Args: | ||||||
|  |                 x: The first input | ||||||
|  |                 y: The second input | ||||||
|  |             """ | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |         schema = get_json_schema(fn) | ||||||
|  |         expected_schema = { | ||||||
|  |             "name": "fn", | ||||||
|  |             "description": "Test function", | ||||||
|  |             "parameters": { | ||||||
|  |                 "type": "object", | ||||||
|  |                 "properties": { | ||||||
|  |                     "x": {"type": "integer", "description": "The first input"}, | ||||||
|  |                     "y": { | ||||||
|  |                         "type": "array", | ||||||
|  |                         "description": "The second input", | ||||||
|  |                         "nullable": True, | ||||||
|  |                         "prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}], | ||||||
|  |                     }, | ||||||
|  |                 }, | ||||||
|  |                 "required": ["x"], | ||||||
|  |             }, | ||||||
|  |             "return": {"type": "null"}, | ||||||
|  |         } | ||||||
|  |         self.assertEqual( | ||||||
|  |             schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"] | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(schema["function"], expected_schema) | ||||||
|  | @ -34,7 +34,9 @@ class ModelTests(unittest.TestCase): | ||||||
|             """ |             """ | ||||||
|             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" |             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
| 
 | 
 | ||||||
|         assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] |         assert ( | ||||||
|  |             "nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def test_chatmessage_has_model_dumps_json(self): |     def test_chatmessage_has_model_dumps_json(self): | ||||||
|         message = ChatMessage("user", "Hello!") |         message = ChatMessage("user", "Hello!") | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from textwrap import dedent | from textwrap import dedent | ||||||
| from typing import Dict, Optional, Union | from typing import Any, Dict, List, Optional, Tuple, Union | ||||||
| from unittest.mock import MagicMock, patch | from unittest.mock import MagicMock, patch | ||||||
| 
 | 
 | ||||||
| import mcp | import mcp | ||||||
|  | @ -381,13 +381,41 @@ class ToolTests(unittest.TestCase): | ||||||
|             Get weather in the next days at given location. |             Get weather in the next days at given location. | ||||||
| 
 | 
 | ||||||
|             Args: |             Args: | ||||||
|                 location: the location |                 location: The location to get the weather for. | ||||||
|                 celsius: is the temperature given in celsius |                 celsius: is the temperature given in celsius? | ||||||
|             """ |             """ | ||||||
|             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" |             return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
| 
 | 
 | ||||||
|         assert get_weather.inputs["celsius"]["nullable"] |         assert get_weather.inputs["celsius"]["nullable"] | ||||||
| 
 | 
 | ||||||
|  |     def test_tool_supports_any_none(self): | ||||||
|  |         @tool | ||||||
|  |         def get_weather(location: Any) -> None: | ||||||
|  |             """ | ||||||
|  |             Get weather in the next days at given location. | ||||||
|  | 
 | ||||||
|  |             Args: | ||||||
|  |                 location: The location to get the weather for. | ||||||
|  |             """ | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         assert get_weather.inputs["location"]["type"] == "any" | ||||||
|  | 
 | ||||||
|  |     def test_tool_supports_array(self): | ||||||
|  |         @tool | ||||||
|  |         def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]: | ||||||
|  |             """ | ||||||
|  |             Get weather in the next days at given locations. | ||||||
|  | 
 | ||||||
|  |             Args: | ||||||
|  |                 locations: The locations to get the weather for. | ||||||
|  |                 months: The months to get the weather for | ||||||
|  |             """ | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         assert get_weather.inputs["locations"]["type"] == "array" | ||||||
|  |         assert get_weather.inputs["months"]["type"] == "array" | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def mock_server_parameters(): | def mock_server_parameters(): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue