From 41a388dac60013b9957768bd36a45cafb8aa5efe Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 13 Feb 2025 11:22:02 +0100 Subject: [PATCH] Refactor operations count state setting (#631) --- src/smolagents/local_python_executor.py | 3 +-- tests/test_local_python_executor.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index a3f5771..a5b12f3 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -1184,7 +1184,7 @@ def evaluate_ast( The list of modules that can be imported by the code. By default, only a few safe modules are allowed. If it contains "*", it will authorize any import. Use this at your own risk! """ - if state["_operations_count"] >= MAX_OPERATIONS: + if state.setdefault("_operations_count", 0) >= MAX_OPERATIONS: raise InterpreterError( f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations." ) @@ -1363,7 +1363,6 @@ def evaluate_python_code( custom_tools = custom_tools if custom_tools is not None else {} result = None state["_print_outputs"] = PrintContainer() - state["_operations_count"] = 0 def final_answer(value): raise FinalAnswerException(value) diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index 670962e..9e96685 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -1149,7 +1149,6 @@ def test_evaluate_python_code_with_evaluate_delete(code, expected_error_message) ], ) def test_evaluate_delete(code, state, expectation): - state["_operations_count"] = 0 delete_node = ast.parse(code).body[0] if isinstance(expectation, Exception): with pytest.raises(type(expectation)) as exception_info: @@ -1157,7 +1156,7 @@ def test_evaluate_delete(code, state, expectation): assert str(expectation) in str(exception_info.value) else: evaluate_delete(delete_node, state, {}, {}, []) - del state["_operations_count"] + _ = state.pop("_operations_count", None) assert state == expectation