Fix blocking of os in authorized imports (#386)
This commit is contained in:
		
							parent
							
								
									a4f89b68b2
								
							
						
					
					
						commit
						ad5f84b101
					
				|  | @ -934,36 +934,39 @@ def evaluate_with( | ||||||
|             context.__exit__(None, None, None) |             context.__exit__(None, None, None) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_safe_module(unsafe_module, dangerous_patterns, visited=None): | def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=None): | ||||||
|     """Creates a safe copy of a module or returns the original if it's a function""" |     """Creates a safe copy of a module or returns the original if it's a function""" | ||||||
|     # If it's a function or non-module object, return it directly |     # If it's a function or non-module object, return it directly | ||||||
|     if not isinstance(unsafe_module, ModuleType): |     if not isinstance(raw_module, ModuleType): | ||||||
|         return unsafe_module |         return raw_module | ||||||
| 
 | 
 | ||||||
|     # Handle circular references: Initialize visited set for the first call |     # Handle circular references: Initialize visited set for the first call | ||||||
|     if visited is None: |     if visited is None: | ||||||
|         visited = set() |         visited = set() | ||||||
| 
 | 
 | ||||||
|     module_id = id(unsafe_module) |     module_id = id(raw_module) | ||||||
|     if module_id in visited: |     if module_id in visited: | ||||||
|         return unsafe_module  # Return original for circular refs |         return raw_module  # Return original for circular refs | ||||||
| 
 | 
 | ||||||
|     visited.add(module_id) |     visited.add(module_id) | ||||||
| 
 | 
 | ||||||
|     # Create new module for actual modules |     # Create new module for actual modules | ||||||
|     safe_module = ModuleType(unsafe_module.__name__) |     safe_module = ModuleType(raw_module.__name__) | ||||||
| 
 | 
 | ||||||
|     # Copy all attributes by reference, recursively checking modules |     # Copy all attributes by reference, recursively checking modules | ||||||
|     for attr_name in dir(unsafe_module): |     for attr_name in dir(raw_module): | ||||||
|         # Skip dangerous patterns at any level |         # Skip dangerous patterns at any level | ||||||
|         if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns): |         if any( | ||||||
|  |             pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports | ||||||
|  |             for pattern in dangerous_patterns | ||||||
|  |         ): | ||||||
|             continue |             continue | ||||||
| 
 | 
 | ||||||
|         attr_value = getattr(unsafe_module, attr_name) |         attr_value = getattr(raw_module, attr_name) | ||||||
| 
 | 
 | ||||||
|         # Recursively process nested modules, passing visited set |         # Recursively process nested modules, passing visited set | ||||||
|         if isinstance(attr_value, ModuleType): |         if isinstance(attr_value, ModuleType): | ||||||
|             attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited) |             attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) | ||||||
| 
 | 
 | ||||||
|         setattr(safe_module, attr_name, attr_value) |         setattr(safe_module, attr_name, attr_value) | ||||||
| 
 | 
 | ||||||
|  | @ -996,7 +999,7 @@ def import_modules(expression, state, authorized_imports): | ||||||
|             return True |             return True | ||||||
|         else: |         else: | ||||||
|             module_path = module_name.split(".") |             module_path = module_name.split(".") | ||||||
|             if any([module in dangerous_patterns for module in module_path]): |             if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]): | ||||||
|                 return False |                 return False | ||||||
|             module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] |             module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] | ||||||
|             return any(subpath in authorized_imports for subpath in module_subpaths) |             return any(subpath in authorized_imports for subpath in module_subpaths) | ||||||
|  | @ -1005,7 +1008,7 @@ def import_modules(expression, state, authorized_imports): | ||||||
|         for alias in expression.names: |         for alias in expression.names: | ||||||
|             if check_module_authorized(alias.name): |             if check_module_authorized(alias.name): | ||||||
|                 raw_module = import_module(alias.name) |                 raw_module = import_module(alias.name) | ||||||
|                 state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns) |                 state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports) | ||||||
|             else: |             else: | ||||||
|                 raise InterpreterError( |                 raise InterpreterError( | ||||||
|                     f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" |                     f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" | ||||||
|  | @ -1013,7 +1016,8 @@ def import_modules(expression, state, authorized_imports): | ||||||
|         return None |         return None | ||||||
|     elif isinstance(expression, ast.ImportFrom): |     elif isinstance(expression, ast.ImportFrom): | ||||||
|         if check_module_authorized(expression.module): |         if check_module_authorized(expression.module): | ||||||
|             module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) |             raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) | ||||||
|  |             module = get_safe_module(raw_module, dangerous_patterns, authorized_imports) | ||||||
|             if expression.names[0].name == "*":  # Handle "from module import *" |             if expression.names[0].name == "*":  # Handle "from module import *" | ||||||
|                 if hasattr(module, "__all__"):  # If module has __all__, import only those names |                 if hasattr(module, "__all__"):  # If module has __all__, import only those names | ||||||
|                     for name in module.__all__: |                     for name in module.__all__: | ||||||
|  |  | ||||||
|  | @ -950,6 +950,10 @@ texec(tcompile("1 + 1", "no filename", "exec")) | ||||||
|             dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS |             dangerous_code, static_tools={"tcompile": compile, "teval": eval, "texec": exec} | BASE_PYTHON_TOOLS | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def test_can_import_os_if_explicitly_authorized(self): | ||||||
|  |         dangerous_code = "import os; os.listdir('./')" | ||||||
|  |         evaluate_python_code(dangerous_code, authorized_imports=["os"]) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "code, expected_result", |     "code, expected_result", | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue