Fix subpackage import vulnerability (#238)
* Fix subpackage import vulnerability
This commit is contained in:
		
							parent
							
								
									d5c2ef48e7
								
							
						
					
					
						commit
						c255c1ff84
					
				|  | @ -120,21 +120,15 @@ class GradioUI: | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         if file is None: |         if file is None: | ||||||
|             return gr.Textbox( |             return gr.Textbox("No file uploaded", visible=True), file_uploads_log | ||||||
|                 "No file uploaded", visible=True |  | ||||||
|             ), file_uploads_log |  | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             mime_type, _ = mimetypes.guess_type(file.name) |             mime_type, _ = mimetypes.guess_type(file.name) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             return gr.Textbox( |             return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log | ||||||
|                 f"Error: {e}", visible=True |  | ||||||
|             ), file_uploads_log |  | ||||||
| 
 | 
 | ||||||
|         if mime_type not in allowed_file_types: |         if mime_type not in allowed_file_types: | ||||||
|             return gr.Textbox( |             return gr.Textbox("File type disallowed", visible=True), file_uploads_log | ||||||
|                 "File type disallowed", visible=True |  | ||||||
|             ), file_uploads_log |  | ||||||
| 
 | 
 | ||||||
|         # Sanitize file name |         # Sanitize file name | ||||||
|         original_name = os.path.basename(file.name) |         original_name = os.path.basename(file.name) | ||||||
|  |  | ||||||
|  | @ -21,6 +21,7 @@ import math | ||||||
| import re | import re | ||||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
|  | from types import ModuleType | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | from typing import Any, Callable, Dict, List, Optional, Tuple | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
|  | @ -1046,12 +1047,75 @@ def evaluate_with( | ||||||
|             context.__exit__(None, None, None) |             context.__exit__(None, None, None) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_safe_module(unsafe_module, dangerous_patterns, visited=None): | ||||||
|  |     """Creates a safe copy of a module or returns the original if it's a function""" | ||||||
|  |     # If it's a function or non-module object, return it directly | ||||||
|  |     if not isinstance(unsafe_module, ModuleType): | ||||||
|  |         return unsafe_module | ||||||
|  | 
 | ||||||
|  |     # Handle circular references: Initialize visited set for the first call | ||||||
|  |     if visited is None: | ||||||
|  |         visited = set() | ||||||
|  | 
 | ||||||
|  |     module_id = id(unsafe_module) | ||||||
|  |     if module_id in visited: | ||||||
|  |         return unsafe_module  # Return original for circular refs | ||||||
|  | 
 | ||||||
|  |     visited.add(module_id) | ||||||
|  | 
 | ||||||
|  |     # Create new module for actual modules | ||||||
|  |     safe_module = ModuleType(unsafe_module.__name__) | ||||||
|  | 
 | ||||||
|  |     # Copy all attributes by reference, recursively checking modules | ||||||
|  |     for attr_name in dir(unsafe_module): | ||||||
|  |         # Skip dangerous patterns at any level | ||||||
|  |         if any( | ||||||
|  |             pattern in f"{unsafe_module.__name__}.{attr_name}" | ||||||
|  |             for pattern in dangerous_patterns | ||||||
|  |         ): | ||||||
|  |             continue | ||||||
|  | 
 | ||||||
|  |         attr_value = getattr(unsafe_module, attr_name) | ||||||
|  | 
 | ||||||
|  |         # Recursively process nested modules, passing visited set | ||||||
|  |         if isinstance(attr_value, ModuleType): | ||||||
|  |             attr_value = get_safe_module( | ||||||
|  |                 attr_value, dangerous_patterns, visited=visited | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         setattr(safe_module, attr_name, attr_value) | ||||||
|  | 
 | ||||||
|  |     return safe_module | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def import_modules(expression, state, authorized_imports): | def import_modules(expression, state, authorized_imports): | ||||||
|  |     dangerous_patterns = ( | ||||||
|  |         "_os", | ||||||
|  |         "os", | ||||||
|  |         "subprocess", | ||||||
|  |         "_subprocess", | ||||||
|  |         "pty", | ||||||
|  |         "system", | ||||||
|  |         "popen", | ||||||
|  |         "spawn", | ||||||
|  |         "shutil", | ||||||
|  |         "glob", | ||||||
|  |         "pathlib", | ||||||
|  |         "io", | ||||||
|  |         "socket", | ||||||
|  |         "compile", | ||||||
|  |         "eval", | ||||||
|  |         "exec", | ||||||
|  |         "multiprocessing", | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|     def check_module_authorized(module_name): |     def check_module_authorized(module_name): | ||||||
|         if "*" in authorized_imports: |         if "*" in authorized_imports: | ||||||
|             return True |             return True | ||||||
|         else: |         else: | ||||||
|             module_path = module_name.split(".") |             module_path = module_name.split(".") | ||||||
|  |             if any([module in dangerous_patterns for module in module_path]): | ||||||
|  |                 return False | ||||||
|             module_subpaths = [ |             module_subpaths = [ | ||||||
|                 ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) |                 ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) | ||||||
|             ] |             ] | ||||||
|  | @ -1060,8 +1124,10 @@ def import_modules(expression, state, authorized_imports): | ||||||
|     if isinstance(expression, ast.Import): |     if isinstance(expression, ast.Import): | ||||||
|         for alias in expression.names: |         for alias in expression.names: | ||||||
|             if check_module_authorized(alias.name): |             if check_module_authorized(alias.name): | ||||||
|                 module = import_module(alias.name) |                 raw_module = import_module(alias.name) | ||||||
|                 state[alias.asname or alias.name] = module |                 state[alias.asname or alias.name] = get_safe_module( | ||||||
|  |                     raw_module, dangerous_patterns | ||||||
|  |                 ) | ||||||
|             else: |             else: | ||||||
|                 raise InterpreterError( |                 raise InterpreterError( | ||||||
|                     f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" |                     f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" | ||||||
|  | @ -1069,11 +1135,13 @@ def import_modules(expression, state, authorized_imports): | ||||||
|         return None |         return None | ||||||
|     elif isinstance(expression, ast.ImportFrom): |     elif isinstance(expression, ast.ImportFrom): | ||||||
|         if check_module_authorized(expression.module): |         if check_module_authorized(expression.module): | ||||||
|             module = __import__( |             raw_module = __import__( | ||||||
|                 expression.module, fromlist=[alias.name for alias in expression.names] |                 expression.module, fromlist=[alias.name for alias in expression.names] | ||||||
|             ) |             ) | ||||||
|             for alias in expression.names: |             for alias in expression.names: | ||||||
|                 state[alias.asname or alias.name] = getattr(module, alias.name) |                 state[alias.asname or alias.name] = get_safe_module( | ||||||
|  |                     getattr(raw_module, alias.name), dangerous_patterns | ||||||
|  |                 ) | ||||||
|         else: |         else: | ||||||
|             raise InterpreterError(f"Import from {expression.module} is not allowed.") |             raise InterpreterError(f"Import from {expression.module} is not allowed.") | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  | @ -920,3 +920,19 @@ shift_intervals | ||||||
|     Expected: {expected} |     Expected: {expected} | ||||||
|     Got:      {result} |     Got:      {result} | ||||||
|     """ |     """ | ||||||
|  | 
 | ||||||
|  |     def test_dangerous_subpackage_access_blocked(self): | ||||||
|  |         # Direct imports with dangerous patterns should fail | ||||||
|  |         code = "import random._os" | ||||||
|  |         with pytest.raises(InterpreterError): | ||||||
|  |             evaluate_python_code(code) | ||||||
|  | 
 | ||||||
|  |         # Import of whitelisted modules should succeed but dangerous submodules should not exist | ||||||
|  |         code = "import random;random._os.system('echo bad command passed')" | ||||||
|  |         with pytest.raises(AttributeError) as e: | ||||||
|  |             evaluate_python_code(code) | ||||||
|  |         assert "module 'random' has no attribute '_os'" in str(e) | ||||||
|  | 
 | ||||||
|  |         code = "import doctest;doctest.inspect.os.system('echo bad command passed')" | ||||||
|  |         with pytest.raises(AttributeError): | ||||||
|  |             evaluate_python_code(code, authorized_imports=["doctest"]) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue