Test evaluate_condition (#634)
This commit is contained in:
parent
f3ee6052db
commit
cfe599c54a
|
@ -26,6 +26,7 @@ from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
PrintContainer,
|
PrintContainer,
|
||||||
check_module_authorized,
|
check_module_authorized,
|
||||||
|
evaluate_condition,
|
||||||
evaluate_delete,
|
evaluate_delete,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
fix_final_answer_code,
|
fix_final_answer_code,
|
||||||
|
@ -1160,6 +1161,46 @@ def test_evaluate_delete(code, state, expectation):
|
||||||
assert state == expectation
|
assert state == expectation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"condition, state, expected_result",
|
||||||
|
[
|
||||||
|
("a == b", {"a": 1, "b": 1}, True),
|
||||||
|
("a == b", {"a": 1, "b": 2}, False),
|
||||||
|
("a != b", {"a": 1, "b": 1}, False),
|
||||||
|
("a != b", {"a": 1, "b": 2}, True),
|
||||||
|
("a < b", {"a": 1, "b": 1}, False),
|
||||||
|
("a < b", {"a": 1, "b": 2}, True),
|
||||||
|
("a < b", {"a": 2, "b": 1}, False),
|
||||||
|
("a <= b", {"a": 1, "b": 1}, True),
|
||||||
|
("a <= b", {"a": 1, "b": 2}, True),
|
||||||
|
("a <= b", {"a": 2, "b": 1}, False),
|
||||||
|
("a > b", {"a": 1, "b": 1}, False),
|
||||||
|
("a > b", {"a": 1, "b": 2}, False),
|
||||||
|
("a > b", {"a": 2, "b": 1}, True),
|
||||||
|
("a >= b", {"a": 1, "b": 1}, True),
|
||||||
|
("a >= b", {"a": 1, "b": 2}, False),
|
||||||
|
("a >= b", {"a": 2, "b": 1}, True),
|
||||||
|
("a is b", {"a": 1, "b": 1}, True),
|
||||||
|
("a is b", {"a": 1, "b": 2}, False),
|
||||||
|
("a is not b", {"a": 1, "b": 1}, False),
|
||||||
|
("a is not b", {"a": 1, "b": 2}, True),
|
||||||
|
("a in b", {"a": 1, "b": [1, 2, 3]}, True),
|
||||||
|
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
||||||
|
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
||||||
|
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
||||||
|
# Composite conditions:
|
||||||
|
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
||||||
|
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
||||||
|
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
|
||||||
|
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_evaluate_condition(condition, state, expected_result):
|
||||||
|
condition_ast = ast.parse(condition, mode="eval").body
|
||||||
|
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
|
||||||
def test_get_safe_module_handle_lazy_imports():
|
def test_get_safe_module_handle_lazy_imports():
|
||||||
class FakeModule(types.ModuleType):
|
class FakeModule(types.ModuleType):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
|
|
Loading…
Reference in New Issue