Fix get safe module (#405)
* expose get_safe_module lazy loading issue * fix get_safe_module lazy imports * add six as explicit tests dependency * log dangerous attributes * refactor get_safe_module test
This commit is contained in:
		
							parent
							
								
									181a500c5d
								
							
						
					
					
						commit
						10407813e8
					
				|  | @ -18,6 +18,7 @@ import ast | ||||||
| import builtins | import builtins | ||||||
| import difflib | import difflib | ||||||
| import inspect | import inspect | ||||||
|  | import logging | ||||||
| import math | import math | ||||||
| import re | import re | ||||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||||
|  | @ -31,6 +32,9 @@ import pandas as pd | ||||||
| from .utils import BASE_BUILTIN_MODULES, truncate_content | from .utils import BASE_BUILTIN_MODULES, truncate_content | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class InterpreterError(ValueError): | class InterpreterError(ValueError): | ||||||
|     """ |     """ | ||||||
|     An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported |     An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported | ||||||
|  | @ -960,10 +964,17 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited= | ||||||
|             pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports |             pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports | ||||||
|             for pattern in dangerous_patterns |             for pattern in dangerous_patterns | ||||||
|         ): |         ): | ||||||
|  |             logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}") | ||||||
|             continue |             continue | ||||||
| 
 | 
 | ||||||
|         attr_value = getattr(raw_module, attr_name) |         try: | ||||||
| 
 |             attr_value = getattr(raw_module, attr_name) | ||||||
|  |         except ImportError as e: | ||||||
|  |             # lazy / dynamic loading module -> INFO log and skip | ||||||
|  |             logger.info( | ||||||
|  |                 f"Skipping import error while copying {raw_module.__name__}.{attr_name}: {type(e).__name__} - {e}" | ||||||
|  |             ) | ||||||
|  |             continue | ||||||
|         # Recursively process nested modules, passing visited set |         # Recursively process nested modules, passing visited set | ||||||
|         if isinstance(attr_value, ModuleType): |         if isinstance(attr_value, ModuleType): | ||||||
|             attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) |             attr_value = get_safe_module(attr_value, dangerous_patterns, authorized_imports, visited=visited) | ||||||
|  |  | ||||||
|  | @ -13,6 +13,7 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
|  | import types | ||||||
| import unittest | import unittest | ||||||
| from textwrap import dedent | from textwrap import dedent | ||||||
| 
 | 
 | ||||||
|  | @ -24,6 +25,7 @@ from smolagents.local_python_executor import ( | ||||||
|     InterpreterError, |     InterpreterError, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
|     fix_final_answer_code, |     fix_final_answer_code, | ||||||
|  |     get_safe_module, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -1069,3 +1071,23 @@ def test_evaluate_augassign_custom(operator, expected_result): | ||||||
|     state = {} |     state = {} | ||||||
|     result, _ = evaluate_python_code(code, {}, state=state) |     result, _ = evaluate_python_code(code, {}, state=state) | ||||||
|     assert result == expected_result |     assert result == expected_result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_safe_module_handle_lazy_imports(): | ||||||
|  |     class FakeModule(types.ModuleType): | ||||||
|  |         def __init__(self, name): | ||||||
|  |             super().__init__(name) | ||||||
|  |             self.non_lazy_attribute = "ok" | ||||||
|  | 
 | ||||||
|  |         def __getattr__(self, name): | ||||||
|  |             if name == "lazy_attribute": | ||||||
|  |                 raise ImportError("lazy import failure") | ||||||
|  |             return super().__getattr__(name) | ||||||
|  | 
 | ||||||
|  |         def __dir__(self): | ||||||
|  |             return super().__dir__() + ["lazy_attribute"] | ||||||
|  | 
 | ||||||
|  |     fake_module = FakeModule("fake_module") | ||||||
|  |     safe_module = get_safe_module(fake_module, dangerous_patterns=[], authorized_imports=set()) | ||||||
|  |     assert not hasattr(safe_module, "lazy_attribute") | ||||||
|  |     assert getattr(safe_module, "non_lazy_attribute") == "ok" | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue