Make dangerous_patterns a module variable (#505)

This commit is contained in:
CalOmnie 2025-02-18 11:03:22 +01:00 committed by GitHub
parent 2982e88409
commit 1df9dca33a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 51 deletions

View File

@ -114,6 +114,26 @@ BASE_PYTHON_TOOLS = {
"complex": complex,
}
DANGEROUS_PATTERNS = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
class PrintContainer:
def __init__(self):
@ -954,7 +974,7 @@ def evaluate_with(
context.__exit__(None, None, None)
def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None):
def get_safe_module(raw_module, authorized_imports, visited=None):
"""Creates a safe copy of a module or returns the original if it's a function"""
# If it's a function or non-module object, return it directly
if not isinstance(raw_module, ModuleType):
@ -978,8 +998,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
# Skip dangerous patterns at any level
if any(
pattern in raw_module.__name__.split(".") + [attr_name]
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
for pattern in dangerous_patterns
and not check_module_authorized(pattern, authorized_imports)
for pattern in DANGEROUS_PATTERNS
):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
continue
@ -994,19 +1014,19 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
continue
# Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited)
attr_value = get_safe_module(attr_value, authorized_imports, visited=visited)
setattr(safe_module, attr_name, attr_value)
return safe_module
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
def check_module_authorized(module_name, authorized_imports):
if "*" in authorized_imports:
return True
else:
module_path = module_name.split(".")
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
if any([module in DANGEROUS_PATTERNS and module not in authorized_imports for module in module_path]):
return False
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
@ -1014,40 +1034,20 @@ def check_module_authorized(module_name, authorized_imports, dangerous_patterns)
def import_modules(expression, state, authorized_imports):
dangerous_patterns = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
if check_module_authorized(alias.name, authorized_imports):
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
)
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
if check_module_authorized(expression.module, authorized_imports):
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
module = get_safe_module(raw_module, authorized_imports)
if expression.names[0].name == "*": # Handle "from module import *"
if hasattr(module, "__all__"): # If module has __all__, import only those names
for name in module.__all__:

View File

@ -1302,7 +1302,7 @@ def test_get_safe_module_handle_lazy_imports():
return super().__dir__() + ["lazy_attribute"]
fake_module = FakeModule("fake_module")
safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set())
safe_module = get_safe_module(fake_module, authorized_imports=set())
assert not hasattr(safe_module, "lazy_attribute")
assert getattr(safe_module, "non_lazy_attribute") == "ok"
@ -1377,23 +1377,4 @@ class TestPrintContainer:
],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
dangerous_patterns = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected
assert check_module_authorized(module, authorized_imports) == expected