From a4f89b68b28035d270744f5bfd06804cbcc3a97e Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Tue, 28 Jan 2025 10:58:53 +0100 Subject: [PATCH] Add tool saving test (#389) * Add tool saving test * Format --- src/smolagents/e2b_executor.py | 4 ++-- src/smolagents/tool_validation.py | 4 +--- tests/test_tools.py | 4 ++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py index 5e007b5..3d3eee1 100644 --- a/src/smolagents/e2b_executor.py +++ b/src/smolagents/e2b_executor.py @@ -14,9 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re import base64 import pickle +import re import textwrap from io import BytesIO from typing import Any, List, Tuple @@ -47,7 +47,7 @@ class E2BExecutor: self.custom_tools = {} self.final_answer = False - self.final_answer_pattern = re.compile(r'^final_answer\((.*)\)$') + self.final_answer_pattern = re.compile(r"^final_answer\((.*)\)$") self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") # TODO: validate installing agents package or not # print("Installing agents package on remote executor...") diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index d8e6daa..1f4d457 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -25,9 +25,7 @@ class MethodChecker(ast.NodeVisitor): self.class_attributes = class_attributes self.errors = [] self.check_imports = check_imports - self.typing_names = { - 'Any' - } + self.typing_names = {"Any"} def visit_arguments(self, node): """Collect function arguments""" diff --git a/tests/test_tools.py b/tests/test_tools.py index e8d5a50..4948a1e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest from pathlib import Path from textwrap import dedent @@ -399,7 +400,10 @@ class ToolTests(unittest.TestCase): """ return + with tempfile.TemporaryDirectory() as tmp_dir: + get_weather.save(tmp_dir) assert get_weather.inputs["location"]["type"] == "any" + assert get_weather.output_type == "null" def test_tool_supports_array(self): @tool