From cfe599c54a81412ea334e1f9d1f17189772428ef Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 13 Feb 2025 12:03:52 +0100 Subject: [PATCH] Test evaluate_condition (#634) --- tests/test_local_python_executor.py | 41 +++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index 9e96685..2abef69 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -26,6 +26,7 @@ from smolagents.local_python_executor import ( InterpreterError, PrintContainer, check_module_authorized, + evaluate_condition, evaluate_delete, evaluate_python_code, fix_final_answer_code, @@ -1160,6 +1161,46 @@ def test_evaluate_delete(code, state, expectation): assert state == expectation +@pytest.mark.parametrize( + "condition, state, expected_result", + [ + ("a == b", {"a": 1, "b": 1}, True), + ("a == b", {"a": 1, "b": 2}, False), + ("a != b", {"a": 1, "b": 1}, False), + ("a != b", {"a": 1, "b": 2}, True), + ("a < b", {"a": 1, "b": 1}, False), + ("a < b", {"a": 1, "b": 2}, True), + ("a < b", {"a": 2, "b": 1}, False), + ("a <= b", {"a": 1, "b": 1}, True), + ("a <= b", {"a": 1, "b": 2}, True), + ("a <= b", {"a": 2, "b": 1}, False), + ("a > b", {"a": 1, "b": 1}, False), + ("a > b", {"a": 1, "b": 2}, False), + ("a > b", {"a": 2, "b": 1}, True), + ("a >= b", {"a": 1, "b": 1}, True), + ("a >= b", {"a": 1, "b": 2}, False), + ("a >= b", {"a": 2, "b": 1}, True), + ("a is b", {"a": 1, "b": 1}, True), + ("a is b", {"a": 1, "b": 2}, False), + ("a is not b", {"a": 1, "b": 1}, False), + ("a is not b", {"a": 1, "b": 2}, True), + ("a in b", {"a": 1, "b": [1, 2, 3]}, True), + ("a in b", {"a": 4, "b": [1, 2, 3]}, False), + ("a not in b", {"a": 1, "b": [1, 2, 3]}, False), + ("a not in b", {"a": 4, "b": [1, 2, 3]}, True), + # Composite conditions: + ("a == b == c", {"a": 1, "b": 1, "c": 1}, True), + ("a == b == c", {"a": 1, "b": 2, "c": 1}, False), + ("a == b < c", {"a": 1, "b": 1, "c": 1}, False), + ("a == b < c", {"a": 1, "b": 1, "c": 2}, True), + ], +) +def test_evaluate_condition(condition, state, expected_result): + condition_ast = ast.parse(condition, mode="eval").body + result = evaluate_condition(condition_ast, state, {}, {}, []) + assert result == expected_result + + def test_get_safe_module_handle_lazy_imports(): class FakeModule(types.ModuleType): def __init__(self, name):