Fix get safe module (#405)

* expose get_safe_module lazy loading issue

* fix get_safe_module lazy imports

* add six as explicit tests dependency

* log dangerous attributes

* refactor get_safe_module test
This commit is contained in:
Antoine Jeannot 2025-01-31 08:25:22 +01:00 committed by GitHub
parent 181a500c5d
commit 10407813e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 2 deletions

View File

@ -18,6 +18,7 @@ import ast
import builtins import builtins
import difflib import difflib
import inspect import inspect
import logging
import math import math
import re import re
from collections.abc import Mapping from collections.abc import Mapping
@ -31,6 +32,9 @@ import pandas as pd
from .utils import BASE_BUILTIN_MODULES, truncate_content from .utils import BASE_BUILTIN_MODULES, truncate_content
logger = logging.getLogger(__name__)
class InterpreterError(ValueError): class InterpreterError(ValueError):
""" """
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
@ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
for pattern in dangerous_patterns for pattern in dangerous_patterns
): ):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
continue continue
attr_value = getattr(raw_module, attr_name) try:
attr_value = getattr(raw_module, attr_name)
except ImportError as e:
# lazy / dynamic loading module -> INFO log and skip
logger.info(
f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
)
continue
# Recursively process nested modules, passing visited set # Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType): if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import types
import unittest import unittest
from textwrap import dedent from textwrap import dedent
@ -24,6 +25,7 @@ from smolagents.local_python_executor import (
InterpreterError, InterpreterError,
evaluate_python_code, evaluate_python_code,
fix_final_answer_code, fix_final_answer_code,
get_safe_module,
) )
@ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result):
state = {} state = {}
result, _ = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
assert result == expected_result assert result == expected_result
def test_get_safe_module_handle_lazy_imports():
class FakeModule(types.ModuleType):
def __init__(self, name):
super().__init__(name)
self.non_lazy_attribute = "ok"
def __getattr__(self, name):
if name == "lazy_attribute":
raise ImportError("lazy import failure")
return super().__getattr__(name)
def __dir__(self):
return super().__dir__() + ["lazy_attribute"]
fake_module = FakeModule("fake_module")
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
assert not hasattr(safe_module, "lazy_attribute")
assert getattr(safe_module, "non_lazy_attribute") == "ok"