Agents deserve freedom. Freedom is the path to success! additional_authorized_imports=['*'] (#129)

* Add an option to authorize all imports
This commit is contained in:
joaopauloschuler 2025-01-13 12:27:42 -03:00 committed by GitHub
parent c0496dc6bc
commit a0b4350409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 11 deletions

View File

@ -884,6 +884,7 @@ class CodeAgent(MultiStepAgent):
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
allow_all_imports: bool = False,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
use_e2b_executor: bool = False, use_e2b_executor: bool = False,
**kwargs, **kwargs,
@ -899,6 +900,14 @@ class CodeAgent(MultiStepAgent):
**kwargs, **kwargs,
) )
if ( allow_all_imports and
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
raise Exception(
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
)
if allow_all_imports: additional_authorized_imports=['*']
self.additional_authorized_imports = ( self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else [] additional_authorized_imports if additional_authorized_imports else []
) )
@ -916,6 +925,9 @@ class CodeAgent(MultiStepAgent):
self.python_executor = LocalPythonInterpreter( self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports, all_tools self.additional_authorized_imports, all_tools
) )
if allow_all_imports:
self.authorized_imports = 'all imports without restriction'
else:
self.authorized_imports = list( self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
) )

View File

@ -789,12 +789,16 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
def import_modules(expression, state, authorized_imports): def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name): def check_module_authorized(module_name):
if '*' in authorized_imports:
return True
else:
module_path = module_name.split(".") module_path = module_name.split(".")
module_subpaths = [ module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1) ".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
] ]
return any(subpath in authorized_imports for subpath in module_subpaths) return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import): if isinstance(expression, ast.Import):
for alias in expression.names: for alias in expression.names:
if check_module_authorized(alias.name): if check_module_authorized(alias.name):