Move check_module_authorized out of import_modules for use in get_safe_module (#507)
Co-authored-by: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									8c6f90cc11
								
							
						
					
					
						commit
						9318c8cae3
					
				|  | @ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= | ||||||
|     for attr_name in dir(raw_module): |     for attr_name in dir(raw_module): | ||||||
|         # Skip dangerous patterns at any level |         # Skip dangerous patterns at any level | ||||||
|         if any( |         if any( | ||||||
|             pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports |             pattern in raw_module.__name__.split(".") + [attr_name] | ||||||
|  |             and not check_module_authorized(pattern, authorized_imports, dangerous_patterns) | ||||||
|             for pattern in dangerous_patterns |             for pattern in dangerous_patterns | ||||||
|         ): |         ): | ||||||
|             logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") |             logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") | ||||||
|  | @ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= | ||||||
|     return safe_module |     return safe_module | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
    """Return whether importing ``module_name`` is permitted.

    Args:
        module_name: Dotted module path, e.g. ``"os.path"``. ``None`` (what
            ``ast.ImportFrom.module`` holds for ``from . import x``) is
            treated as not authorized unless everything is authorized.
        authorized_imports: Collection of authorized module names; the
            wildcard ``"*"`` authorizes every module.
        dangerous_patterns: Collection of path components that must be
            authorized by name before a module containing them is allowed.

    Returns:
        ``True`` if ``"*"`` is authorized, or if no component of the path is
        an unauthorized dangerous pattern and the module or one of its parent
        packages is listed in ``authorized_imports``; ``False`` otherwise.
    """
    if "*" in authorized_imports:
        return True
    if module_name is None:
        # No absolute module path to check against the allow-list.
        return False

    module_path = module_name.split(".")
    # Every dangerous component must itself be authorized explicitly:
    # authorizing "Module" does not authorize "Module.os".
    if any(part in dangerous_patterns and part not in authorized_imports for part in module_path):
        return False

    # Authorizing a package authorizes its submodules:
    # ["A", "B", "C"] -> "A", "A.B", "A.B.C"
    return any(".".join(module_path[:depth]) in authorized_imports for depth in range(1, len(module_path) + 1))
|  | 
 | ||||||
|  | 
 | ||||||
| def import_modules(expression, state, authorized_imports): | def import_modules(expression, state, authorized_imports): | ||||||
|     dangerous_patterns = ( |     dangerous_patterns = ( | ||||||
|         "_os", |         "_os", | ||||||
|  | @ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports): | ||||||
|         "multiprocessing", |         "multiprocessing", | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     def check_module_authorized(module_name): |  | ||||||
|         if "*" in authorized_imports: |  | ||||||
|             return True |  | ||||||
|         else: |  | ||||||
|             module_path = module_name.split(".") |  | ||||||
|             if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]): |  | ||||||
|                 return False |  | ||||||
|             module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] |  | ||||||
|             return any(subpath in authorized_imports for subpath in module_subpaths) |  | ||||||
| 
 |  | ||||||
|     if isinstance(expression, ast.Import): |     if isinstance(expression, ast.Import): | ||||||
|         for alias in expression.names: |         for alias in expression.names: | ||||||
|             if check_module_authorized(alias.name): |             if check_module_authorized(alias.name, authorized_imports, dangerous_patterns): | ||||||
|                 raw_module = import_module(alias.name) |                 raw_module = import_module(alias.name) | ||||||
|                 state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports) |                 state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports) | ||||||
|             else: |             else: | ||||||
|  | @ -1049,7 +1052,7 @@ def import_modules(expression, state, authorized_imports): | ||||||
|                 ) |                 ) | ||||||
|         return None |         return None | ||||||
|     elif isinstance(expression, ast.ImportFrom): |     elif isinstance(expression, ast.ImportFrom): | ||||||
|         if check_module_authorized(expression.module): |         if check_module_authorized(expression.module, authorized_imports, dangerous_patterns): | ||||||
|             raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) |             raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) | ||||||
|             module = get_safe_module(raw_module, dangerous_patterns, authorized_imports) |             module = get_safe_module(raw_module, dangerous_patterns, authorized_imports) | ||||||
|             if expression.names[0].name == "*":  # Handle "from module import *" |             if expression.names[0].name == "*":  # Handle "from module import *" | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||||
| from smolagents.local_python_executor import ( | from smolagents.local_python_executor import ( | ||||||
|     InterpreterError, |     InterpreterError, | ||||||
|     PrintContainer, |     PrintContainer, | ||||||
|  |     check_module_authorized, | ||||||
|     evaluate_delete, |     evaluate_delete, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
|     fix_final_answer_code, |     fix_final_answer_code, | ||||||
|  | @ -975,6 +976,10 @@ texec(tcompile("1 + 1", "no filename", "exec")) | ||||||
|         dangerous_code = "import os; os.listdir('./')" |         dangerous_code = "import os; os.listdir('./')" | ||||||
|         evaluate_python_code(dangerous_code, authorized_imports=["os"]) |         evaluate_python_code(dangerous_code, authorized_imports=["os"]) | ||||||
| 
 | 
 | ||||||
|  |     def test_can_import_os_if_all_imports_authorized(self): | ||||||
|  |         dangerous_code = "import os; os.listdir('./')" | ||||||
|  |         evaluate_python_code(dangerous_code, authorized_imports=["*"]) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "code, expected_result", |     "code, expected_result", | ||||||
|  | @ -1205,3 +1210,39 @@ class TestPrintContainer: | ||||||
|         pc = PrintContainer() |         pc = PrintContainer() | ||||||
|         pc.append("Hello") |         pc.append("Hello") | ||||||
|         assert len(pc) == 5 |         assert len(pc) == 5 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
@pytest.mark.parametrize(
    "module,authorized_imports,expected",
    [
        # Wildcard authorizes everything, dangerous or not.
        ("os", ["*"], True),
        ("AnyModule", ["*"], True),
        # Exact-name authorization.
        ("os", ["os"], True),
        ("AnyModule", ["AnyModule"], True),
        # A dangerous component must be authorized by name, not via its parent.
        ("Module.os", ["Module"], False),
        ("Module.os", ["Module", "os"], True),
        # Authorizing a package covers its submodules, but not the reverse.
        ("os.path", ["os"], True),
        ("os", ["os.path"], False),
    ],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
    """Check ``check_module_authorized`` against the executor's dangerous patterns."""
    dangerous_patterns = (
        "_os",
        "os",
        "subprocess",
        "_subprocess",
        "pty",
        "system",
        "popen",
        "spawn",
        "shutil",
        "sys",
        "pathlib",
        "io",
        "socket",
        "compile",
        "eval",
        "exec",
        "multiprocessing",
    )
    is_authorized = check_module_authorized(module, authorized_imports, dangerous_patterns)
    assert is_authorized == expected
|  |  | ||||||
		Loading…
	
		Reference in New Issue