Test evaluate_condition (#634)
This commit is contained in:
parent
f3ee6052db
commit
cfe599c54a
|
@ -26,6 +26,7 @@ from smolagents.local_python_executor import (
|
|||
InterpreterError,
|
||||
PrintContainer,
|
||||
check_module_authorized,
|
||||
evaluate_condition,
|
||||
evaluate_delete,
|
||||
evaluate_python_code,
|
||||
fix_final_answer_code,
|
||||
|
@ -1160,6 +1161,46 @@ def test_evaluate_delete(code, state, expectation):
|
|||
assert state == expectation
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition, state, expected_result",
|
||||
[
|
||||
("a == b", {"a": 1, "b": 1}, True),
|
||||
("a == b", {"a": 1, "b": 2}, False),
|
||||
("a != b", {"a": 1, "b": 1}, False),
|
||||
("a != b", {"a": 1, "b": 2}, True),
|
||||
("a < b", {"a": 1, "b": 1}, False),
|
||||
("a < b", {"a": 1, "b": 2}, True),
|
||||
("a < b", {"a": 2, "b": 1}, False),
|
||||
("a <= b", {"a": 1, "b": 1}, True),
|
||||
("a <= b", {"a": 1, "b": 2}, True),
|
||||
("a <= b", {"a": 2, "b": 1}, False),
|
||||
("a > b", {"a": 1, "b": 1}, False),
|
||||
("a > b", {"a": 1, "b": 2}, False),
|
||||
("a > b", {"a": 2, "b": 1}, True),
|
||||
("a >= b", {"a": 1, "b": 1}, True),
|
||||
("a >= b", {"a": 1, "b": 2}, False),
|
||||
("a >= b", {"a": 2, "b": 1}, True),
|
||||
("a is b", {"a": 1, "b": 1}, True),
|
||||
("a is b", {"a": 1, "b": 2}, False),
|
||||
("a is not b", {"a": 1, "b": 1}, False),
|
||||
("a is not b", {"a": 1, "b": 2}, True),
|
||||
("a in b", {"a": 1, "b": [1, 2, 3]}, True),
|
||||
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
|
||||
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
|
||||
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
|
||||
# Composite conditions:
|
||||
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
|
||||
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
|
||||
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
|
||||
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
|
||||
],
|
||||
)
|
||||
def test_evaluate_condition(condition, state, expected_result):
|
||||
condition_ast = ast.parse(condition, mode="eval").body
|
||||
result = evaluate_condition(condition_ast, state, {}, {}, [])
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_get_safe_module_handle_lazy_imports():
|
||||
class FakeModule(types.ModuleType):
|
||||
def __init__(self, name):
|
||||
|
|
Loading…
Reference in New Issue