diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 412c1e8..b379dee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,3 +87,8 @@ jobs: run: | uv run pytest ./tests/test_utils.py if: ${{ success() || failure() }} + + - name: Function type hints utils tests + run: | + uv run pytest ./tests/test_function_type_hints_utils.py + if: ${{ success() || failure() }} diff --git a/src/smolagents/_transformers_utils.py b/src/smolagents/_function_type_hints_utils.py similarity index 97% rename from src/smolagents/_transformers_utils.py rename to src/smolagents/_function_type_hints_utils.py index fcbcf26..b076941 100644 --- a/src/smolagents/_transformers_utils.py +++ b/src/smolagents/_function_type_hints_utils.py @@ -276,20 +276,26 @@ def _parse_google_format_docstring( return description, args_dict, returns -def _convert_type_hints_to_json_schema(func: Callable) -> Dict: +def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict: type_hints = get_type_hints(func) signature = inspect.signature(func) - required = [] - for param_name, param in signature.parameters.items(): - if param.annotation == inspect.Parameter.empty: - raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") - if param.default == inspect.Parameter.empty: - required.append(param_name) properties = {} for param_name, param_type in type_hints.items(): properties[param_name] = _parse_type_hint(param_type) + required = [] + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints: + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") + if param_name not in properties: + properties[param_name] = {} + + if param.default == inspect.Parameter.empty: + required.append(param_name) + else: + properties[param_name]["nullable"] = True + schema = {"type": "object", "properties": properties} if required: schema["required"] = required @@ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = { float: {"type": "number"}, str: {"type": "string"}, bool: {"type": "boolean"}, - Any: {}, + Any: {"type": "any"}, + types.NoneType: {"type": "null"}, } diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 876f36a..44f76df 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -16,7 +16,7 @@ # limitations under the License. import re from dataclasses import dataclass -from typing import Dict, Optional +from typing import Any, Dict, Optional from .local_python_executor import ( BASE_BUILTIN_MODULES, @@ -85,7 +85,7 @@ class FinalAnswerTool(Tool): inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} output_type = "any" - def forward(self, answer): + def forward(self, answer: Any) -> Any: return answer diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 9240aed..3e6e7c3 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -136,7 +136,7 @@ tool_role_conversions = { } -def get_json_schema(tool: Tool) -> Dict: +def get_tool_json_schema(tool: Tool) -> Dict: properties = deepcopy(tool.inputs) required = [] for key, value in properties.items(): @@ -240,7 +240,7 @@ class Model: if tools_to_call_from: completion_kwargs.update( { - "tools": [get_json_schema(tool) for tool in tools_to_call_from], + "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from], "tool_choice": "required", } ) @@ -490,7 +490,7 @@ class TransformersModel(Model): if tools_to_call_from is not None: prompt_tensor = self.tokenizer.apply_chat_template( messages, - tools=completion_kwargs.pop("tools", []), + tools=[get_tool_json_schema(tool) for tool in tools_to_call_from], return_tensors="pt", return_dict=True, add_generation_prompt=True, diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 80301cc..e7725e8 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -26,7 +26,7 @@ import textwrap from contextlib import contextmanager from functools import lru_cache, wraps from pathlib import Path -from typing import Callable, Dict, List, Optional, Union, get_type_hints +from typing import Callable, Dict, List, Optional, Union from huggingface_hub import ( create_repo, @@ -38,9 +38,9 @@ from huggingface_hub import ( from huggingface_hub.utils import is_torch_available from packaging import version -from ._transformers_utils import ( +from ._function_type_hints_utils import ( TypeHintParsingException, - _parse_type_hint, + _convert_type_hints_to_json_schema, get_imports, get_json_schema, ) @@ -64,22 +64,6 @@ def validate_after_init(cls): return cls -def _convert_type_hints_to_json_schema(func: Callable) -> Dict: - type_hints = get_type_hints(func) - signature = inspect.signature(func) - properties = {} - for param_name, param_type in type_hints.items(): - if param_name != "return": - properties[param_name] = _parse_type_hint(param_type) - if signature.parameters[param_name].default != inspect.Parameter.empty: - properties[param_name]["nullable"] = True - for param_name in signature.parameters.keys(): - if signature.parameters[param_name].default != inspect.Parameter.empty: - if param_name not in properties: # this can happen if the param has no type hint but a default value - properties[param_name] = {"nullable": True} - return properties - - AUTHORIZED_TYPES = [ "string", "boolean", @@ -87,8 +71,10 @@ AUTHORIZED_TYPES = [ "number", "image", "audio", - "any", + "array", "object", + "any", + "null", ] CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} @@ -168,12 +154,15 @@ class Tool: "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." ) - json_schema = _convert_type_hints_to_json_schema( - self.forward - ) # This function will raise an error on missing docstrings, contrary to get_json_schema + json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[ + "properties" + ] # This function will not raise an error on missing docstrings, contrary to get_json_schema for key, value in self.inputs.items(): + assert key in json_schema, ( + f"Input '{key}' should be present in function signature, found only {json_schema.keys()}" + ) if "nullable" in value: - assert key in json_schema and "nullable" in json_schema[key], ( + assert "nullable" in json_schema[key], ( f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." ) if key in json_schema and "nullable" in json_schema[key]: @@ -887,16 +876,6 @@ class ToolCollection: yield cls(tools) -def get_tool_json_schema(tool_function): - tool_json_schema = get_json_schema(tool_function)["function"] - tool_parameters = tool_json_schema["parameters"] - inputs_schema = tool_parameters["properties"] - for input_name in inputs_schema: - if "required" not in tool_parameters or input_name not in tool_parameters["required"]: - inputs_schema[input_name]["nullable"] = True - return tool_json_schema - - def tool(tool_function: Callable) -> Tool: """ Converts a function into an instance of a Tool subclass. @@ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool: tool_function: Your function. Should have type hints for each input and a type hint for the output. Should also have a docstring description including an 'Args:' part where each argument is described. """ - tool_json_schema = get_tool_json_schema(tool_function) + tool_json_schema = get_json_schema(tool_function)["function"] if "return" not in tool_json_schema: raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") diff --git a/tests/test_function_type_hints_utils.py b/tests/test_function_type_hints_utils.py new file mode 100644 index 0000000..9e58985 --- /dev/null +++ b/tests/test_function_type_hints_utils.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from typing import Optional, Tuple + +from smolagents._function_type_hints_utils import get_json_schema + + +class AgentTextTests(unittest.TestCase): + def test_return_none(self): + def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None: + """ + Test function + Args: + x: The first input + y: The second input + """ + pass + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The first input"}, + "y": { + "type": "array", + "description": "The second input", + "nullable": True, + "prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}], + }, + }, + "required": ["x"], + }, + "return": {"type": "null"}, + } + self.assertEqual( + schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"] + ) + self.assertEqual(schema["function"], expected_schema) diff --git a/tests/test_models.py b/tests/test_models.py index 4f897f3..bad87fd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -34,7 +34,9 @@ class ModelTests(unittest.TestCase): """ return "The weather is UNGODLY with torrential rains and temperatures below -10°C" - assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] + assert ( + "nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] + ) def test_chatmessage_has_model_dumps_json(self): message = ChatMessage("user", "Hello!") diff --git a/tests/test_tools.py b/tests/test_tools.py index 0caefd2..e8d5a50 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -15,7 +15,7 @@ import unittest from pathlib import Path from textwrap import dedent -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import MagicMock, patch import mcp @@ -381,13 +381,41 @@ class ToolTests(unittest.TestCase): Get weather in the next days at given location. Args: - location: the location - celsius: is the temperature given in celsius + location: The location to get the weather for. + celsius: is the temperature given in celsius? """ return "The weather is UNGODLY with torrential rains and temperatures below -10°C" assert get_weather.inputs["celsius"]["nullable"] + def test_tool_supports_any_none(self): + @tool + def get_weather(location: Any) -> None: + """ + Get weather in the next days at given location. + + Args: + location: The location to get the weather for. + """ + return + + assert get_weather.inputs["location"]["type"] == "any" + + def test_tool_supports_array(self): + @tool + def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]: + """ + Get weather in the next days at given locations. + + Args: + locations: The locations to get the weather for. + months: The months to get the weather for + """ + return + + assert get_weather.inputs["locations"]["type"] == "array" + assert get_weather.inputs["months"]["type"] == "array" + @pytest.fixture def mock_server_parameters():