Agents deserve freedom. Freedom is the path to success! additional_authorized_imports=['*'] (#129)
* Add an option to authorize all imports
This commit is contained in:
parent
c0496dc6bc
commit
a0b4350409
|
@ -884,6 +884,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
|
allow_all_imports: bool = False,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
use_e2b_executor: bool = False,
|
use_e2b_executor: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -899,6 +900,14 @@ class CodeAgent(MultiStepAgent):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ( allow_all_imports and
|
||||||
|
( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
|
||||||
|
raise Exception(
|
||||||
|
f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
|
||||||
|
)
|
||||||
|
|
||||||
|
if allow_all_imports: additional_authorized_imports=['*']
|
||||||
|
|
||||||
self.additional_authorized_imports = (
|
self.additional_authorized_imports = (
|
||||||
additional_authorized_imports if additional_authorized_imports else []
|
additional_authorized_imports if additional_authorized_imports else []
|
||||||
)
|
)
|
||||||
|
@ -916,6 +925,9 @@ class CodeAgent(MultiStepAgent):
|
||||||
self.python_executor = LocalPythonInterpreter(
|
self.python_executor = LocalPythonInterpreter(
|
||||||
self.additional_authorized_imports, all_tools
|
self.additional_authorized_imports, all_tools
|
||||||
)
|
)
|
||||||
|
if allow_all_imports:
|
||||||
|
self.authorized_imports = 'all imports without restriction'
|
||||||
|
else:
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||||
)
|
)
|
||||||
|
|
|
@ -789,12 +789,16 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||||
|
|
||||||
def import_modules(expression, state, authorized_imports):
|
def import_modules(expression, state, authorized_imports):
|
||||||
def check_module_authorized(module_name):
|
def check_module_authorized(module_name):
|
||||||
|
if '*' in authorized_imports:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
module_subpaths = [
|
module_subpaths = [
|
||||||
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
||||||
]
|
]
|
||||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
|
||||||
|
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if check_module_authorized(alias.name):
|
if check_module_authorized(alias.name):
|
||||||
|
|
Loading…
Reference in New Issue