Agents deserve freedom. Freedom is the path to success! additional_authorized_imports=['*'] (#129)

* Add an option to authorize all imports
This commit is contained in:
joaopauloschuler 2025-01-13 12:27:42 -03:00 committed by GitHub
parent c0496dc6bc
commit a0b4350409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 11 deletions

View File

@ -884,6 +884,7 @@ class CodeAgent(MultiStepAgent):
system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
allow_all_imports: bool = False,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
**kwargs,
@ -899,6 +900,14 @@ class CodeAgent(MultiStepAgent):
**kwargs,
)
if ( allow_all_imports and
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
raise Exception(
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
)
if allow_all_imports: additional_authorized_imports=['*']
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
@ -916,13 +925,16 @@ class CodeAgent(MultiStepAgent):
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports, all_tools
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
if allow_all_imports:
self.authorized_imports = 'all imports without restriction'
else:
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in self.system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
self.system_prompt = self.system_prompt.replace(
"{{authorized_imports}}", str(self.authorized_imports)
)

View File

@ -789,11 +789,15 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name):
module_path = module_name.split(".")
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
return any(subpath in authorized_imports for subpath in module_subpaths)
if '*' in authorized_imports:
return True
else:
module_path = module_name.split(".")
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names: