Fix get safe module (#405)
* expose get_safe_module lazy loading issue * fix get_safe_module lazy imports * add six as explicit tests dependency * log dangerous attributes * refactor get_safe_module test
This commit is contained in:
		
							parent
							
								
									181a500c5d
								
							
						
					
					
						commit
						10407813e8
					
				|  | @ -18,6 +18,7 @@ import ast | |||
| import builtins | ||||
| import difflib | ||||
| import inspect | ||||
| import logging | ||||
| import math | ||||
| import re | ||||
| from collections.abc import Mapping | ||||
|  | @ -31,6 +32,9 @@ import pandas as pd | |||
| from .utils import BASE_BUILTIN_MODULES, truncate_content | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class InterpreterError(ValueError): | ||||
|     """ | ||||
|     An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported | ||||
|  | @ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= | |||
|             pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports | ||||
|             for pattern in dangerous_patterns | ||||
|         ): | ||||
|             logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") | ||||
|             continue | ||||
| 
 | ||||
|         attr_value = getattr(raw_module, attr_name) | ||||
| 
 | ||||
|         try: | ||||
|             attr_value = getattr(raw_module, attr_name) | ||||
|         except ImportError as e: | ||||
|             # lazy / dynamic loading module -> INFO log and skip | ||||
|             logger.info( | ||||
|                 f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}" | ||||
|             ) | ||||
|             continue | ||||
|         # Recursively process nested modules, passing visited set | ||||
|         if isinstance(attr_value, ModuleType): | ||||
|             attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import types | ||||
| import unittest | ||||
| from textwrap import dedent | ||||
| 
 | ||||
|  | @ -24,6 +25,7 @@ from smolagents.local_python_executor import ( | |||
|     InterpreterError, | ||||
|     evaluate_python_code, | ||||
|     fix_final_answer_code, | ||||
|     get_safe_module, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result): | |||
|     state = {} | ||||
|     result, _ = evaluate_python_code(code, {}, state=state) | ||||
|     assert result == expected_result | ||||
| 
 | ||||
| 
 | ||||
| def test_get_safe_module_handle_lazy_imports(): | ||||
|     class FakeModule(types.ModuleType): | ||||
|         def __init__(self, name): | ||||
|             super().__init__(name) | ||||
|             self.non_lazy_attribute = "ok" | ||||
| 
 | ||||
|         def __getattr__(self, name): | ||||
|             if name == "lazy_attribute": | ||||
|                 raise ImportError("lazy import failure") | ||||
|             return super().__getattr__(name) | ||||
| 
 | ||||
|         def __dir__(self): | ||||
|             return super().__dir__() + ["lazy_attribute"] | ||||
| 
 | ||||
|     fake_module = FakeModule("fake_module") | ||||
|     safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set()) | ||||
|     assert not hasattr(safe_module, "lazy_attribute") | ||||
|     assert getattr(safe_module, "non_lazy_attribute") == "ok" | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue