From 94371331bbf44ff6b3799beabca738f478625488 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 14 Feb 2025 11:21:04 +0100 Subject: [PATCH] Fix evaluate_condition for non-bool result (#638) --- src/smolagents/local_python_executor.py | 50 ++++----- tests/test_local_python_executor.py | 130 ++++++++++++++++++++---- 2 files changed, 128 insertions(+), 52 deletions(-) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index a5b12f3..d8fe84e 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -714,49 +714,39 @@ def evaluate_condition( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> bool | object: - left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) - comparators = [ - evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators - ] - ops = [type(op) for op in condition.ops] - result = True - current_left = left - - for op, comparator in zip(ops, comparators): + left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) + for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)): + op = type(op) + right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports) if op == ast.Eq: - current_result = current_left == comparator + current_result = left == right elif op == ast.NotEq: - current_result = current_left != comparator + current_result = left != right elif op == ast.Lt: - current_result = current_left < comparator + current_result = left < right elif op == ast.LtE: - current_result = current_left <= comparator + current_result = left <= right elif op == ast.Gt: - current_result = current_left > comparator + current_result = left > right elif op == ast.GtE: - current_result = current_left >= comparator + current_result = left >= right elif op == ast.Is: - current_result = current_left is comparator + current_result = left is right elif op == ast.IsNot: - current_result = current_left is not comparator + current_result = left is not right elif op == ast.In: - current_result = current_left in comparator + current_result = left in right elif op == ast.NotIn: - current_result = current_left not in comparator + current_result = left not in right else: - raise InterpreterError(f"Operator not supported: {op}") + raise InterpreterError(f"Unsupported comparison operator: {op}") - if not isinstance(current_result, bool): - return current_result - - result = result & current_result - current_left = comparator - - if isinstance(result, bool) and not result: - break - - return result if isinstance(result, (bool, pd.Series)) else result.all() + if current_result is False: + return False + result = current_result if i == 0 else (result and current_result) + left = right + return result def evaluate_if( diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index 2abef69..f7ecc91 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -19,6 +19,7 @@ import unittest from textwrap import dedent import numpy as np +import pandas as pd import pytest from smolagents.default_tools import BASE_PYTHON_TOOLS @@ -1188,11 +1189,11 @@ def test_evaluate_delete(code, state, expectation): ("a in b", {"a": 4, "b": [1, 2, 3]}, False), ("a not in b", {"a": 1, "b": [1, 2, 3]}, False), ("a not in b", {"a": 4, "b": [1, 2, 3]}, True), - # Composite conditions: + # Chained conditions: ("a == b == c", {"a": 1, "b": 1, "c": 1}, True), ("a == b == c", {"a": 1, "b": 2, "c": 1}, False), - ("a == b < c", {"a": 1, "b": 1, "c": 1}, False), - ("a == b < c", {"a": 1, "b": 1, "c": 2}, True), + ("a == b < c", {"a": 2, "b": 2, "c": 2}, False), + ("a == b < c", {"a": 0, "b": 0, "c": 1}, True), ], ) def test_evaluate_condition(condition, state, expected_result): @@ -1201,6 +1202,91 @@ def test_evaluate_condition(condition, state, expected_result): assert result == expected_result +@pytest.mark.parametrize( + "condition, state, expected_result", + [ + ("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])), + ("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])), + ("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])), + ("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])), + ("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])), + ("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])), + ( + "a == b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})}, + pd.DataFrame({"x": [True, True], "y": [True, False]}), + ), + ( + "a != b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})}, + pd.DataFrame({"x": [False, False], "y": [False, True]}), + ), + ( + "a < b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})}, + pd.DataFrame({"x": [True, False], "y": [False, False]}), + ), + ( + "a <= b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})}, + pd.DataFrame({"x": [True, True], "y": [False, False]}), + ), + ( + "a > b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})}, + pd.DataFrame({"x": [False, False], "y": [True, True]}), + ), + ( + "a >= b", + {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})}, + pd.DataFrame({"x": [False, True], "y": [True, True]}), + ), + ], +) +def test_evaluate_condition_with_pandas(condition, state, expected_result): + condition_ast = ast.parse(condition, mode="eval").body + result = evaluate_condition(condition_ast, state, {}, {}, []) + if isinstance(result, pd.Series): + pd.testing.assert_series_equal(result, expected_result) + else: + pd.testing.assert_frame_equal(result, expected_result) + + +@pytest.mark.parametrize( + "condition, state, expected_exception", + [ + # Chained conditions: + ( + "a == b == c", + { + "a": pd.Series([1, 2, 3]), + "b": pd.Series([2, 2, 2]), + "c": pd.Series([3, 3, 3]), + }, + ValueError( + "The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()." + ), + ), + ( + "a == b == c", + { + "a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), + "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}), + "c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}), + }, + ValueError( + "The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()." + ), + ), + ], +) +def test_evaluate_condition_with_pandas_exceptions(condition, state, expected_exception): + condition_ast = ast.parse(condition, mode="eval").body + with pytest.raises(type(expected_exception)) as exception_info: + _ = evaluate_condition(condition_ast, state, {}, {}, []) + assert str(expected_exception) in str(exception_info.value) + + def test_get_safe_module_handle_lazy_imports(): class FakeModule(types.ModuleType): def __init__(self, name): @@ -1222,28 +1308,28 @@ def test_get_safe_module_handle_lazy_imports(): def test_non_standard_comparisons(): - code = """ -class NonStdEqualsResult: - def __init__(self, left:object, right:object): - self._left = left - self._right = right - def __str__(self) -> str: - return f'{self._left}=={self._right}' + code = dedent("""\ + class NonStdEqualsResult: + def __init__(self, left:object, right:object): + self._left = left + self._right = right + def __str__(self) -> str: + return f'{self._left} == {self._right}' -class NonStdComparisonClass: - def __init__(self, value: str ): - self._value = value - def __str__(self): - return self._value - def __eq__(self, other): - return NonStdEqualsResult(self, other) -a = NonStdComparisonClass("a") -b = NonStdComparisonClass("b") -result = a == b - """ + class NonStdComparisonClass: + def __init__(self, value: str ): + self._value = value + def __str__(self): + return self._value + def __eq__(self, other): + return NonStdEqualsResult(self, other) + a = NonStdComparisonClass("a") + b = NonStdComparisonClass("b") + result = a == b + """) result, _ = evaluate_python_code(code, state={}) assert not isinstance(result, bool) - assert str(result) == "a==b" + assert str(result) == "a == b" class TestPrintContainer: