diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 86fa160..4f7bb95 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -713,7 +713,7 @@ def evaluate_condition( static_tools: Dict[str, Callable], custom_tools: Dict[str, Callable], authorized_imports: List[str], -) -> bool: +) -> bool | object: left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) comparators = [ evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators @@ -747,6 +747,9 @@ def evaluate_condition( else: raise InterpreterError(f"Operator not supported: {op}") + if not isinstance(current_result, bool): + return current_result + result = result & current_result current_left = comparator diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index ca1f8b0..670962e 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -1181,6 +1181,31 @@ def test_get_safe_module_handle_lazy_imports(): assert getattr(safe_module, "non_lazy_attribute") == "ok" +def test_non_standard_comparisons(): + code = """ +class NonStdEqualsResult: + def __init__(self, left:object, right:object): + self._left = left + self._right = right + def __str__(self) -> str: + return f'{self._left}=={self._right}' + +class NonStdComparisonClass: + def __init__(self, value: str ): + self._value = value + def __str__(self): + return self._value + def __eq__(self, other): + return NonStdEqualsResult(self, other) +a = NonStdComparisonClass("a") +b = NonStdComparisonClass("b") +result = a == b + """ + result, _ = evaluate_python_code(code, state={}) + assert not isinstance(result, bool) + assert str(result) == "a==b" + + class TestPrintContainer: def test_initial_value(self): pc = PrintContainer()