Fix get safe module (#405)
* expose get_safe_module lazy loading issue * fix get_safe_module lazy imports * add six as explicit tests dependency * log dangerous attributes * refactor get_safe_module test
This commit is contained in:
parent
181a500c5d
commit
10407813e8
|
@ -18,6 +18,7 @@ import ast
|
||||||
import builtins
|
import builtins
|
||||||
import difflib
|
import difflib
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
@ -31,6 +32,9 @@ import pandas as pd
|
||||||
from .utils import BASE_BUILTIN_MODULES, truncate_content
|
from .utils import BASE_BUILTIN_MODULES, truncate_content
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
"""
|
"""
|
||||||
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
|
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
|
||||||
|
@ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
|
||||||
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
|
||||||
for pattern in dangerous_patterns
|
for pattern in dangerous_patterns
|
||||||
):
|
):
|
||||||
|
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
attr_value = getattr(raw_module, attr_name)
|
attr_value = getattr(raw_module, attr_name)
|
||||||
|
except ImportError as e:
|
||||||
|
# lazy / dynamic loading module -> INFO log and skip
|
||||||
|
logger.info(
|
||||||
|
f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
# Recursively process nested modules, passing visited set
|
# Recursively process nested modules, passing visited set
|
||||||
if isinstance(attr_value, ModuleType):
|
if isinstance(attr_value, ModuleType):
|
||||||
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
|
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import types
|
||||||
import unittest
|
import unittest
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
|
@ -24,6 +25,7 @@ from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
fix_final_answer_code,
|
fix_final_answer_code,
|
||||||
|
get_safe_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result):
|
||||||
state = {}
|
state = {}
|
||||||
result, _ = evaluate_python_code(code, {}, state=state)
|
result, _ = evaluate_python_code(code, {}, state=state)
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_safe_module_handle_lazy_imports():
|
||||||
|
class FakeModule(types.ModuleType):
|
||||||
|
def __init__(self, name):
|
||||||
|
super().__init__(name)
|
||||||
|
self.non_lazy_attribute = "ok"
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name == "lazy_attribute":
|
||||||
|
raise ImportError("lazy import failure")
|
||||||
|
return super().__getattr__(name)
|
||||||
|
|
||||||
|
def __dir__(self):
|
||||||
|
return super().__dir__() + ["lazy_attribute"]
|
||||||
|
|
||||||
|
fake_module = FakeModule("fake_module")
|
||||||
|
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
|
||||||
|
assert not hasattr(safe_module, "lazy_attribute")
|
||||||
|
assert getattr(safe_module, "non_lazy_attribute") == "ok"
|
||||||
|
|
Loading…
Reference in New Issue