Support any and none tool types (#280)

* Support any and none tool types
This commit is contained in:
Aymeric Roucher 2025-01-22 12:47:05 +01:00 committed by GitHub
parent 83ecd572fc
commit 43904f32c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 127 additions and 52 deletions

View File

@ -87,3 +87,8 @@ jobs:
run: |
uv run pytest ./tests/test_utils.py
if: ${{ success() || failure() }}
- name: Function type hints utils tests
run: |
uv run pytest ./tests/test_function_type_hints_utils.py
if: ${{ success() || failure() }}

View File

@ -276,20 +276,26 @@ def _parse_google_format_docstring(
return description, args_dict, returns
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
required = []
for param_name, param in signature.parameters.items():
if param.annotation == inspect.Parameter.empty:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param.default == inspect.Parameter.empty:
required.append(param_name)
properties = {}
for param_name, param_type in type_hints.items():
properties[param_name] = _parse_type_hint(param_type)
required = []
for param_name, param in signature.parameters.items():
if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param_name not in properties:
properties[param_name] = {}
if param.default == inspect.Parameter.empty:
required.append(param_name)
else:
properties[param_name]["nullable"] = True
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
@ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = {
float: {"type": "number"},
str: {"type": "string"},
bool: {"type": "boolean"},
Any: {},
Any: {"type": "any"},
types.NoneType: {"type": "null"},
}

View File

@ -16,7 +16,7 @@
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Dict, Optional
from .local_python_executor import (
BASE_BUILTIN_MODULES,
@ -85,7 +85,7 @@ class FinalAnswerTool(Tool):
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
output_type = "any"
def forward(self, answer):
def forward(self, answer: Any) -> Any:
return answer

View File

@ -136,7 +136,7 @@ tool_role_conversions = {
}
def get_json_schema(tool: Tool) -> Dict:
def get_tool_json_schema(tool: Tool) -> Dict:
properties = deepcopy(tool.inputs)
required = []
for key, value in properties.items():
@ -240,7 +240,7 @@ class Model:
if tools_to_call_from:
completion_kwargs.update(
{
"tools": [get_json_schema(tool) for tool in tools_to_call_from],
"tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
"tool_choice": "required",
}
)
@ -490,7 +490,7 @@ class TransformersModel(Model):
if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template(
messages,
tools=completion_kwargs.pop("tools", []),
tools=[get_tool_json_schema(tool) for tool in tools_to_call_from],
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,

View File

@ -26,7 +26,7 @@ import textwrap
from contextlib import contextmanager
from functools import lru_cache, wraps
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union, get_type_hints
from typing import Callable, Dict, List, Optional, Union
from huggingface_hub import (
create_repo,
@ -38,9 +38,9 @@ from huggingface_hub import (
from huggingface_hub.utils import is_torch_available
from packaging import version
from ._transformers_utils import (
from ._function_type_hints_utils import (
TypeHintParsingException,
_parse_type_hint,
_convert_type_hints_to_json_schema,
get_imports,
get_json_schema,
)
@ -64,22 +64,6 @@ def validate_after_init(cls):
return cls
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
properties = {}
for param_name, param_type in type_hints.items():
if param_name != "return":
properties[param_name] = _parse_type_hint(param_type)
if signature.parameters[param_name].default != inspect.Parameter.empty:
properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty:
if param_name not in properties: # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True}
return properties
AUTHORIZED_TYPES = [
"string",
"boolean",
@ -87,8 +71,10 @@ AUTHORIZED_TYPES = [
"number",
"image",
"audio",
"any",
"array",
"object",
"any",
"null",
]
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
@ -168,12 +154,15 @@ class Tool:
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
json_schema = _convert_type_hints_to_json_schema(
self.forward
) # This function will raise an error on missing docstrings, contrary to get_json_schema
json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
"properties"
] # This function will not raise an error on missing docstrings, contrary to get_json_schema
for key, value in self.inputs.items():
assert key in json_schema, (
f"Input '{key}' should be present in function signature, found only {json_schema.keys()}"
)
if "nullable" in value:
assert key in json_schema and "nullable" in json_schema[key], (
assert "nullable" in json_schema[key], (
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
)
if key in json_schema and "nullable" in json_schema[key]:
@ -887,16 +876,6 @@ class ToolCollection:
yield cls(tools)
def get_tool_json_schema(tool_function):
    """
    Return the "function" part of the JSON schema of `tool_function`, with every
    input that is not listed as required flagged `"nullable": True`.
    """
    schema = get_json_schema(tool_function)["function"]
    parameters = schema["parameters"]
    # No "required" entry means that no input is required at all.
    required_names = parameters.get("required", ())
    for name, spec in parameters["properties"].items():
        if name not in required_names:
            spec["nullable"] = True
    return schema
def tool(tool_function: Callable) -> Tool:
"""
Converts a function into an instance of a Tool subclass.
@ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
tool_json_schema = get_tool_json_schema(tool_function)
tool_json_schema = get_json_schema(tool_function)["function"]
if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")

View File

@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Optional, Tuple
from smolagents._function_type_hints_utils import get_json_schema
class AgentTextTests(unittest.TestCase):
    """Checks the JSON schema generated by `get_json_schema` for a plain function."""

    def test_return_none(self):
        """A `-> None` return hint yields the "null" type; an optional tuple input yields a nullable array."""

        def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
            """
            Test function
            Args:
                x: The first input
                y: The second input
            """
            pass

        expected_y = {
            "type": "array",
            "description": "The second input",
            "nullable": True,
            "prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}],
        }
        expected_function = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The first input"},
                    "y": expected_y,
                },
                "required": ["x"],
            },
            "return": {"type": "null"},
        }

        generated = get_json_schema(fn)["function"]
        # Check the optional tuple argument on its own first, so that a failure points at it directly.
        self.assertEqual(generated["parameters"]["properties"]["y"], expected_y)
        self.assertEqual(generated, expected_function)

View File

@ -34,7 +34,9 @@ class ModelTests(unittest.TestCase):
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
assert (
"nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
)
def test_chatmessage_has_model_dumps_json(self):
message = ChatMessage("user", "Hello!")

View File

@ -15,7 +15,7 @@
import unittest
from pathlib import Path
from textwrap import dedent
from typing import Dict, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import MagicMock, patch
import mcp
@ -381,13 +381,41 @@ class ToolTests(unittest.TestCase):
Get weather in the next days at given location.
Args:
location: the location
celsius: is the temperature given in celsius
location: The location to get the weather for.
celsius: is the temperature given in celsius?
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert get_weather.inputs["celsius"]["nullable"]
def test_tool_supports_any_none(self):
    """A tool built from a function hinted `Any -> None` exposes its input with type "any"."""

    @tool
    def get_weather(location: Any) -> None:
        """
        Get weather in the next days at given location.
        Args:
            location: The location to get the weather for.
        """
        return

    location_spec = get_weather.inputs["location"]
    assert location_spec["type"] == "any"
def test_tool_supports_array(self):
    """List and optional-tuple inputs of a tool are both exposed with type "array"."""

    @tool
    def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]:
        """
        Get weather in the next days at given locations.
        Args:
            locations: The locations to get the weather for.
            months: The months to get the weather for
        """
        return

    for input_name in ("locations", "months"):
        assert get_weather.inputs[input_name]["type"] == "array"
@pytest.fixture
def mock_server_parameters():