Fix subpackage import vulnerability (#238)
* Fix subpackage import vulnerability
This commit is contained in:
		
							parent
							
								
									d5c2ef48e7
								
							
						
					
					
						commit
						c255c1ff84
					
				|  | @ -120,21 +120,15 @@ class GradioUI: | |||
|         """ | ||||
| 
 | ||||
|         if file is None: | ||||
|             return gr.Textbox( | ||||
|                 "No file uploaded", visible=True | ||||
|             ), file_uploads_log | ||||
|             return gr.Textbox("No file uploaded", visible=True), file_uploads_log | ||||
| 
 | ||||
|         try: | ||||
|             mime_type, _ = mimetypes.guess_type(file.name) | ||||
|         except Exception as e: | ||||
|             return gr.Textbox( | ||||
|                 f"Error: {e}", visible=True | ||||
|             ), file_uploads_log | ||||
|             return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log | ||||
| 
 | ||||
|         if mime_type not in allowed_file_types: | ||||
|             return gr.Textbox( | ||||
|                 "File type disallowed", visible=True | ||||
|             ), file_uploads_log | ||||
|             return gr.Textbox("File type disallowed", visible=True), file_uploads_log | ||||
| 
 | ||||
|         # Sanitize file name | ||||
|         original_name = os.path.basename(file.name) | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ import math | |||
| import re | ||||
| from collections.abc import Mapping | ||||
| from importlib import import_module | ||||
| from types import ModuleType | ||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | ||||
| 
 | ||||
| import numpy as np | ||||
|  | @ -1046,12 +1047,75 @@ def evaluate_with( | |||
|             context.__exit__(None, None, None) | ||||
| 
 | ||||
| 
 | ||||
| def get_safe_module(unsafe_module, dangerous_patterns, visited=None): | ||||
|     """Creates a safe copy of a module or returns the original if it's a function""" | ||||
|     # If it's a function or non-module object, return it directly | ||||
|     if not isinstance(unsafe_module, ModuleType): | ||||
|         return unsafe_module | ||||
| 
 | ||||
|     # Handle circular references: Initialize visited set for the first call | ||||
|     if visited is None: | ||||
|         visited = set() | ||||
| 
 | ||||
|     module_id = id(unsafe_module) | ||||
|     if module_id in visited: | ||||
|         return unsafe_module  # Return original for circular refs | ||||
| 
 | ||||
|     visited.add(module_id) | ||||
| 
 | ||||
|     # Create new module for actual modules | ||||
|     safe_module = ModuleType(unsafe_module.__name__) | ||||
| 
 | ||||
|     # Copy all attributes by reference, recursively checking modules | ||||
|     for attr_name in dir(unsafe_module): | ||||
|         # Skip dangerous patterns at any level | ||||
|         if any( | ||||
|             pattern in f"{unsafe_module.__name__}.{attr_name}" | ||||
|             for pattern in dangerous_patterns | ||||
|         ): | ||||
|             continue | ||||
| 
 | ||||
|         attr_value = getattr(unsafe_module, attr_name) | ||||
| 
 | ||||
|         # Recursively process nested modules, passing visited set | ||||
|         if isinstance(attr_value, ModuleType): | ||||
|             attr_value = get_safe_module( | ||||
|                 attr_value, dangerous_patterns, visited=visited | ||||
|             ) | ||||
| 
 | ||||
|         setattr(safe_module, attr_name, attr_value) | ||||
| 
 | ||||
|     return safe_module | ||||
| 
 | ||||
| 
 | ||||
| def import_modules(expression, state, authorized_imports): | ||||
|     dangerous_patterns = ( | ||||
|         "_os", | ||||
|         "os", | ||||
|         "subprocess", | ||||
|         "_subprocess", | ||||
|         "pty", | ||||
|         "system", | ||||
|         "popen", | ||||
|         "spawn", | ||||
|         "shutil", | ||||
|         "glob", | ||||
|         "pathlib", | ||||
|         "io", | ||||
|         "socket", | ||||
|         "compile", | ||||
|         "eval", | ||||
|         "exec", | ||||
|         "multiprocessing", | ||||
|     ) | ||||
| 
 | ||||
|     def check_module_authorized(module_name): | ||||
|         if "*" in authorized_imports: | ||||
|             return True | ||||
|         else: | ||||
|             module_path = module_name.split(".") | ||||
|             if any([module in dangerous_patterns for module in module_path]): | ||||
|                 return False | ||||
|             module_subpaths = [ | ||||
|                 ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) | ||||
|             ] | ||||
|  | @ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports): | |||
|     if isinstance(expression, ast.Import): | ||||
|         for alias in expression.names: | ||||
|             if check_module_authorized(alias.name): | ||||
|                 module = import_module(alias.name) | ||||
|                 state[alias.asname or alias.name] = module | ||||
|                 raw_module = import_module(alias.name) | ||||
|                 state[alias.asname or alias.name] = get_safe_module( | ||||
|                     raw_module, dangerous_patterns | ||||
|                 ) | ||||
|             else: | ||||
|                 raise InterpreterError( | ||||
|                     f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" | ||||
|  | @ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports): | |||
|         return None | ||||
|     elif isinstance(expression, ast.ImportFrom): | ||||
|         if check_module_authorized(expression.module): | ||||
|             module = __import__( | ||||
|             raw_module = __import__( | ||||
|                 expression.module, fromlist=[alias.name for alias in expression.names] | ||||
|             ) | ||||
|             for alias in expression.names: | ||||
|                 state[alias.asname or alias.name] = getattr(module, alias.name) | ||||
|                 state[alias.asname or alias.name] = get_safe_module( | ||||
|                     getattr(raw_module, alias.name), dangerous_patterns | ||||
|                 ) | ||||
|         else: | ||||
|             raise InterpreterError(f"Import from {expression.module} is not allowed.") | ||||
|         return None | ||||
|  |  | |||
|  | @ -920,3 +920,19 @@ shift_intervals | |||
|     Expected: {expected} | ||||
|     Got:      {result} | ||||
|     """ | ||||
| 
 | ||||
|     def test_dangerous_subpackage_access_blocked(self): | ||||
|         # Direct imports with dangerous patterns should fail | ||||
|         code = "import random._os" | ||||
|         with pytest.raises(InterpreterError): | ||||
|             evaluate_python_code(code) | ||||
| 
 | ||||
|         # Import of whitelisted modules should succeed but dangerous submodules should not exist | ||||
|         code = "import random;random._os.system('echo bad command passed')" | ||||
|         with pytest.raises(AttributeError) as e: | ||||
|             evaluate_python_code(code) | ||||
|         assert "module 'random' has no attribute '_os'" in str(e) | ||||
| 
 | ||||
|         code = "import doctest;doctest.inspect.os.system('echo bad command passed')" | ||||
|         with pytest.raises(AttributeError): | ||||
|             evaluate_python_code(code, authorized_imports=["doctest"]) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue