diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 6f3e1eb..7288640 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -18,6 +18,7 @@ import ast import builtins import difflib import inspect +import logging import math import re from collections.abc import Mapping @@ -31,6 +32,9 @@ import pandas as pd from .utils import BASE_BUILTIN_MODULES, truncate_content +logger = logging.getLogger(__name__) + + class InterpreterError(ValueError): """ An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported @@ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports for pattern in dangerous_patterns ): + logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") continue - attr_value = getattr(raw_module, attr_name) - + try: + attr_value = getattr(raw_module, attr_name) + except ImportError as e: + # lazy / dynamic loading module -> INFO log and skip + logger.info( + f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}" + ) + continue # Recursively process nested modules, passing visited set if isinstance(attr_value, ModuleType): attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index a44fbea..ad8b99d 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types import unittest from textwrap import dedent @@ -24,6 +25,7 @@ from smolagents.local_python_executor import ( InterpreterError, evaluate_python_code, fix_final_answer_code, + get_safe_module, ) @@ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result): state = {} result, _ = evaluate_python_code(code, {}, state=state) assert result == expected_result + + +def test_get_safe_module_handle_lazy_imports(): + class FakeModule(types.ModuleType): + def __init__(self, name): + super().__init__(name) + self.non_lazy_attribute = "ok" + + def __getattr__(self, name): + if name == "lazy_attribute": + raise ImportError("lazy import failure") + return super().__getattr__(name) + + def __dir__(self): + return super().__dir__() + ["lazy_attribute"] + + fake_module = FakeModule("fake_module") + safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set()) + assert not hasattr(safe_module, "lazy_attribute") + assert getattr(safe_module, "non_lazy_attribute") == "ok"