diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 93c64fb..e9c3d9d 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -884,6 +884,7 @@ class CodeAgent(MultiStepAgent): system_prompt: Optional[str] = None, grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None, + allow_all_imports: bool = False, planning_interval: Optional[int] = None, use_e2b_executor: bool = False, **kwargs, @@ -899,6 +900,14 @@ class CodeAgent(MultiStepAgent): **kwargs, ) + if ( allow_all_imports and + ( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)): + raise Exception( + f"You passed both allow_all_imports and additional_authorized_imports. Please choose one." + ) + + if allow_all_imports: additional_authorized_imports=['*'] + self.additional_authorized_imports = ( additional_authorized_imports if additional_authorized_imports else [] ) @@ -916,13 +925,16 @@ class CodeAgent(MultiStepAgent): self.python_executor = LocalPythonInterpreter( self.additional_authorized_imports, all_tools ) - self.authorized_imports = list( - set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) - ) - if "{{authorized_imports}}" not in self.system_prompt: - raise AgentError( - "Tag '{{authorized_imports}}' should be provided in the prompt." + if allow_all_imports: + self.authorized_imports = 'all imports without restriction' + else: + self.authorized_imports = list( + set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) ) + if "{{authorized_imports}}" not in self.system_prompt: + raise AgentError( + "Tag '{{authorized_imports}}' should be provided in the prompt." + ) self.system_prompt = self.system_prompt.replace( "{{authorized_imports}}", str(self.authorized_imports) ) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index f476d30..a70b537 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -789,11 +789,15 @@ def evaluate_with(with_node, state, static_tools, custom_tools): def import_modules(expression, state, authorized_imports): def check_module_authorized(module_name): - module_path = module_name.split(".") - module_subpaths = [ - ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) - ] - return any(subpath in authorized_imports for subpath in module_subpaths) + if '*' in authorized_imports: + return True + else: + module_path = module_name.split(".") + module_subpaths = [ + ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) + ] + return any(subpath in authorized_imports for subpath in module_subpaths) + if isinstance(expression, ast.Import): for alias in expression.names: