Fix evaluate_condition for non-bool result (#638)
This commit is contained in:
parent
d02093dc24
commit
94371331bb
|
@ -714,49 +714,39 @@ def evaluate_condition(
|
|||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str],
|
||||
) -> bool | object:
|
||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
|
||||
comparators = [
|
||||
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
|
||||
]
|
||||
ops = [type(op) for op in condition.ops]
|
||||
|
||||
result = True
|
||||
current_left = left
|
||||
|
||||
for op, comparator in zip(ops, comparators):
|
||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
|
||||
for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
|
||||
op = type(op)
|
||||
right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
|
||||
if op == ast.Eq:
|
||||
current_result = current_left == comparator
|
||||
current_result = left == right
|
||||
elif op == ast.NotEq:
|
||||
current_result = current_left != comparator
|
||||
current_result = left != right
|
||||
elif op == ast.Lt:
|
||||
current_result = current_left < comparator
|
||||
current_result = left < right
|
||||
elif op == ast.LtE:
|
||||
current_result = current_left <= comparator
|
||||
current_result = left <= right
|
||||
elif op == ast.Gt:
|
||||
current_result = current_left > comparator
|
||||
current_result = left > right
|
||||
elif op == ast.GtE:
|
||||
current_result = current_left >= comparator
|
||||
current_result = left >= right
|
||||
elif op == ast.Is:
|
||||
current_result = current_left is comparator
|
||||
current_result = left is right
|
||||
elif op == ast.IsNot:
|
||||
current_result = current_left is not comparator
|
||||
current_result = left is not right
|
||||
elif op == ast.In:
|
||||
current_result = current_left in comparator
|
||||
current_result = left in right
|
||||
elif op == ast.NotIn:
|
||||
current_result = current_left not in comparator
|
||||
current_result = left not in right
|
||||
else:
|
||||
raise InterpreterError(f"Operator not supported: {op}")
|
||||
raise InterpreterError(f"Unsupported comparison operator: {op}")
|
||||
|
||||
if not isinstance(current_result, bool):
|
||||
return current_result
|
||||
|
||||
result = result & current_result
|
||||
current_left = comparator
|
||||
|
||||
if isinstance(result, bool) and not result:
|
||||
break
|
||||
|
||||
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
||||
if current_result is False:
|
||||
return False
|
||||
result = current_result if i == 0 else (result and current_result)
|
||||
left = right
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_if(
|
||||
|
|
|
@ -19,6 +19,7 @@ import unittest
|
|||
from textwrap import dedent
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||
|
@ -1188,11 +1189,11 @@ def test_evaluate_delete(code, state, expectation):
|
|||
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
||||
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
||||
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
||||
# Composite conditions:
|
||||
# Chained conditions:
|
||||
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
||||
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
||||
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
|
||||
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
|
||||
("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
|
||||
("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
|
||||
],
|
||||
)
|
||||
def test_evaluate_condition(condition, state, expected_result):
|
||||
|
@ -1201,6 +1202,91 @@ def test_evaluate_condition(condition, state, expected_result):
|
|||
assert result == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition, state, expected_result",
|
||||
[
|
||||
("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
|
||||
("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
|
||||
("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
|
||||
("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
|
||||
("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
|
||||
("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
|
||||
(
|
||||
"a == b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
||||
pd.DataFrame({"x": [True, True], "y": [True, False]}),
|
||||
),
|
||||
(
|
||||
"a != b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
|
||||
pd.DataFrame({"x": [False, False], "y": [False, True]}),
|
||||
),
|
||||
(
|
||||
"a < b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||
pd.DataFrame({"x": [True, False], "y": [False, False]}),
|
||||
),
|
||||
(
|
||||
"a <= b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||
pd.DataFrame({"x": [True, True], "y": [False, False]}),
|
||||
),
|
||||
(
|
||||
"a > b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||
pd.DataFrame({"x": [False, False], "y": [True, True]}),
|
||||
),
|
||||
(
|
||||
"a >= b",
|
||||
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
|
||||
pd.DataFrame({"x": [False, True], "y": [True, True]}),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate_condition_with_pandas(condition, state, expected_result):
|
||||
condition_ast = ast.parse(condition, mode="eval").body
|
||||
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||
if isinstance(result, pd.Series):
|
||||
pd.testing.assert_series_equal(result, expected_result)
|
||||
else:
|
||||
pd.testing.assert_frame_equal(result, expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition, state, expected_exception",
|
||||
[
|
||||
# Chained conditions:
|
||||
(
|
||||
"a == b == c",
|
||||
{
|
||||
"a": pd.Series([1, 2, 3]),
|
||||
"b": pd.Series([2, 2, 2]),
|
||||
"c": pd.Series([3, 3, 3]),
|
||||
},
|
||||
ValueError(
|
||||
"The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
||||
),
|
||||
),
|
||||
(
|
||||
"a == b == c",
|
||||
{
|
||||
"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
|
||||
"b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
|
||||
"c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
|
||||
},
|
||||
ValueError(
|
||||
"The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate_condition_with_pandas_exceptions(condition, state, expected_exception):
|
||||
condition_ast = ast.parse(condition, mode="eval").body
|
||||
with pytest.raises(type(expected_exception)) as exception_info:
|
||||
_ = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||
assert str(expected_exception) in str(exception_info.value)
|
||||
|
||||
|
||||
def test_get_safe_module_handle_lazy_imports():
|
||||
class FakeModule(types.ModuleType):
|
||||
def __init__(self, name):
|
||||
|
@ -1222,28 +1308,28 @@ def test_get_safe_module_handle_lazy_imports():
|
|||
|
||||
|
||||
def test_non_standard_comparisons():
|
||||
code = """
|
||||
class NonStdEqualsResult:
|
||||
code = dedent("""\
|
||||
class NonStdEqualsResult:
|
||||
def __init__(self, left:object, right:object):
|
||||
self._left = left
|
||||
self._right = right
|
||||
def __str__(self) -> str:
|
||||
return f'{self._left}=={self._right}'
|
||||
return f'{self._left} == {self._right}'
|
||||
|
||||
class NonStdComparisonClass:
|
||||
class NonStdComparisonClass:
|
||||
def __init__(self, value: str ):
|
||||
self._value = value
|
||||
def __str__(self):
|
||||
return self._value
|
||||
def __eq__(self, other):
|
||||
return NonStdEqualsResult(self, other)
|
||||
a = NonStdComparisonClass("a")
|
||||
b = NonStdComparisonClass("b")
|
||||
result = a == b
|
||||
"""
|
||||
a = NonStdComparisonClass("a")
|
||||
b = NonStdComparisonClass("b")
|
||||
result = a == b
|
||||
""")
|
||||
result, _ = evaluate_python_code(code, state={})
|
||||
assert not isinstance(result, bool)
|
||||
assert str(result) == "a==b"
|
||||
assert str(result) == "a == b"
|
||||
|
||||
|
||||
class TestPrintContainer:
|
||||
|
|
Loading…
Reference in New Issue