def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
    """Decide whether importing ``module_name`` is permitted.

    Args:
        module_name: Dotted module path such as ``"os.path"``. May be ``None``
            or empty (``ast.ImportFrom.module`` is ``None`` for relative
            imports like ``from . import x``); such names are never
            authorized unless the wildcard is present.
        authorized_imports: Collection of module names the caller allows.
            The special entry ``"*"`` authorizes every module.
        dangerous_patterns: Collection of path components that must be
            authorized by name before any module containing them is allowed.

    Returns:
        ``True`` when the import is allowed, ``False`` otherwise.
    """
    # The wildcard short-circuits every other rule, dangerous modules included.
    if "*" in authorized_imports:
        return True

    # Without a module name there is nothing to match against; reject instead
    # of failing on ``None.split``.
    if not module_name:
        return False

    module_path = module_name.split(".")

    # A dangerous component anywhere in the path must itself be listed in
    # authorized_imports; authorizing only the parent package is not enough.
    if any(part in dangerous_patterns and part not in authorized_imports for part in module_path):
        return False

    # ["A", "B", "C"] -> "A", "A.B", "A.B.C": authorizing a package also
    # authorizes its submodules, but not the other way round.
    module_subpaths = (".".join(module_path[:i]) for i in range(1, len(module_path) + 1))
    return any(subpath in authorized_imports for subpath in module_subpaths)
@pytest.mark.parametrize(
    "module,authorized_imports,expected",
    [
        # The wildcard authorizes every module, dangerous or not.
        ("os", ["*"], True),
        ("AnyModule", ["*"], True),
        # A module listed by name is authorized.
        ("os", ["os"], True),
        ("AnyModule", ["AnyModule"], True),
        # A dangerous component needs its own entry, even under an authorized parent.
        ("Module.os", ["Module"], False),
        ("Module.os", ["Module", "os"], True),
        # Authorizing a package covers its submodules, but not the reverse.
        ("os.path", ["os"], True),
        ("os", ["os.path"], False),
    ],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
    """Check the verdict of ``check_module_authorized`` for each parametrized case."""
    dangerous_patterns = (
        "_os",
        "os",
        "subprocess",
        "_subprocess",
        "pty",
        "system",
        "popen",
        "spawn",
        "shutil",
        "sys",
        "pathlib",
        "io",
        "socket",
        "compile",
        "eval",
        "exec",
        "multiprocessing",
    )
    verdict = check_module_authorized(module, authorized_imports, dangerous_patterns)
    assert verdict == expected