call.func parameter (#194)
This commit is contained in:
		
							parent
							
								
									450934ce79
								
							
						
					
					
						commit
						a22c221fa7
					
				|  | @ -591,7 +591,11 @@ def evaluate_call( | |||
|     custom_tools: Dict[str, Callable], | ||||
|     authorized_imports: List[str], | ||||
| ) -> Any: | ||||
|     if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): | ||||
|     if not ( | ||||
|         isinstance(call.func, ast.Attribute) | ||||
|         or isinstance(call.func, ast.Name) | ||||
|         or isinstance(call.func, ast.Subscript) | ||||
|     ): | ||||
|         raise InterpreterError(f"This is not a correct function: {call.func}).") | ||||
|     if isinstance(call.func, ast.Attribute): | ||||
|         obj = evaluate_ast( | ||||
|  | @ -617,6 +621,23 @@ def evaluate_call( | |||
|                 f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." | ||||
|             ) | ||||
| 
 | ||||
|     elif isinstance(call.func, ast.Subscript): | ||||
|         value = evaluate_ast( | ||||
|             call.func.value, state, static_tools, custom_tools, authorized_imports | ||||
|         ) | ||||
|         index = evaluate_ast( | ||||
|             call.func.slice, state, static_tools, custom_tools, authorized_imports | ||||
|         ) | ||||
|         if isinstance(value, (list, tuple)): | ||||
|             func = value[index] | ||||
|         else: | ||||
|             raise InterpreterError( | ||||
|                 f"Cannot subscript object of type {type(value).__name__}" | ||||
|             ) | ||||
| 
 | ||||
|         if not callable(func): | ||||
|             raise InterpreterError(f"This is not a correct function: {call.func}).") | ||||
|         func_name = None | ||||
|     args = [] | ||||
|     for arg in call.args: | ||||
|         if isinstance(arg, ast.Starred): | ||||
|  | @ -726,6 +747,8 @@ def evaluate_name( | |||
|         return state[name.id] | ||||
|     elif name.id in static_tools: | ||||
|         return static_tools[name.id] | ||||
|     elif name.id in custom_tools: | ||||
|         return custom_tools[name.id] | ||||
|     elif name.id in ERRORS: | ||||
|         return ERRORS[name.id] | ||||
|     close_matches = difflib.get_close_matches(name.id, list(state.keys())) | ||||
|  |  | |||
|  | @ -60,6 +60,14 @@ class PythonInterpreterTester(unittest.TestCase): | |||
|             in str(e) | ||||
|         ) | ||||
| 
 | ||||
|     def test_subscript_call(self): | ||||
|         code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)""" | ||||
|         state = {} | ||||
|         result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) | ||||
|         assert result == 64 | ||||
|         assert state["result_foo"] == 8 | ||||
|         assert state["result_boo"] == 64 | ||||
| 
 | ||||
|     def test_evaluate_call(self): | ||||
|         code = "y = add_two(x)" | ||||
|         state = {"x": 3} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue