From 7a9112372911cfc98d875de890f791d8c3bec593 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:57:16 +0100 Subject: [PATCH] Improve python executor's error logging (#275) * Improve python executor's error logging --- src/smolagents/agents.py | 12 ++------- src/smolagents/local_python_executor.py | 24 ++++++++++++----- src/smolagents/tools.py | 35 +++++++++++++++++++------ tests/test_agents.py | 4 +-- tests/test_python_interpreter.py | 34 ++++++++++++++++-------- tests/test_tools.py | 14 ++++++++++ 6 files changed, 86 insertions(+), 37 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 22af248..7b10279 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -972,16 +972,8 @@ class CodeAgent(MultiStepAgent): ] observation += "Execution logs:\n" + execution_logs except Exception as e: - if isinstance(e, SyntaxError): - error_msg = ( - f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n" - f"{e.text}" - f"{' ' * (e.offset or 0)}^\n" - f"Error: {str(e)}" - ) - else: - error_msg = str(e) - if "Import of " in str(e) and " is not allowed" in str(e): + error_msg = str(e) + if "Import of " in error_msg and " is not allowed" in error_msg: self.logger.log( "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", level=LogLevel.INFO, diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index e46d87a..1b47d28 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -554,7 +554,7 @@ def evaluate_call( func = ERRORS[func_name] else: raise InterpreterError( - f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." + f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})." ) elif isinstance(call.func, ast.Subscript): @@ -1245,7 +1245,16 @@ def evaluate_python_code( updated by this function to contain all variables as they are evaluated. The print outputs will be stored in the state under the key 'print_outputs'. """ - expression = ast.parse(code) + try: + expression = ast.parse(code) + except SyntaxError as e: + raise InterpreterError( + f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n" + f"{e.text}" + f"{' ' * (e.offset or 0)}^\n" + f"Error: {str(e)}" + ) + if state is None: state = {} if static_tools is None: @@ -1273,10 +1282,13 @@ def evaluate_python_code( state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) is_final_answer = True return e.value, is_final_answer - except InterpreterError as e: - msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) - msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" - raise InterpreterError(msg) + except Exception as e: + exception_type = type(e).__name__ + error_msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) + error_msg = ( + f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {exception_type}:{str(e)}" + ) + raise InterpreterError(error_msg) class LocalPythonInterpreter: diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index b9f4141..80301cc 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -168,7 +168,9 @@ class Tool: "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." ) - json_schema = _convert_type_hints_to_json_schema(self.forward) + json_schema = _convert_type_hints_to_json_schema( + self.forward + ) # This function will raise an error on missing docstrings, contrary to get_json_schema for key, value in self.inputs.items(): if "nullable" in value: assert key in json_schema and "nullable" in json_schema[key], ( @@ -885,6 +887,16 @@ class ToolCollection: yield cls(tools) +def get_tool_json_schema(tool_function): + tool_json_schema = get_json_schema(tool_function)["function"] + tool_parameters = tool_json_schema["parameters"] + inputs_schema = tool_parameters["properties"] + for input_name in inputs_schema: + if "required" not in tool_parameters or input_name not in tool_parameters["required"]: + inputs_schema[input_name]["nullable"] = True + return tool_json_schema + + def tool(tool_function: Callable) -> Tool: """ Converts a function into an instance of a Tool subclass. @@ -893,12 +905,19 @@ def tool(tool_function: Callable) -> Tool: tool_function: Your function. Should have type hints for each input and a type hint for the output. Should also have a docstring description including an 'Args:' part where each argument is described. """ - parameters = get_json_schema(tool_function)["function"] - if "return" not in parameters: + tool_json_schema = get_tool_json_schema(tool_function) + if "return" not in tool_json_schema: raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") class SimpleTool(Tool): - def __init__(self, name, description, inputs, output_type, function): + def __init__( + self, + name: str, + description: str, + inputs: Dict[str, Dict[str, str]], + output_type: str, + function: Callable, + ): self.name = name self.description = description self.inputs = inputs @@ -907,10 +926,10 @@ def tool(tool_function: Callable) -> Tool: self.is_initialized = True simple_tool = SimpleTool( - parameters["name"], - parameters["description"], - parameters["parameters"]["properties"], - parameters["return"]["type"], + name=tool_json_schema["name"], + description=tool_json_schema["description"], + inputs=tool_json_schema["parameters"]["properties"], + output_type=tool_json_schema["return"]["type"], function=tool_function, ) original_signature = inspect.signature(tool_function) diff --git a/tests/test_agents.py b/tests/test_agents.py index 2a56a0f..7bc2704 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -332,7 +332,7 @@ class AgentTests(unittest.TestCase): output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, AgentText) assert output == "got an error" - assert "Code execution failed at line 'print = 2' because of" in str(agent.logs) + assert "Code execution failed at line 'print = 2' due to: InterpreterError" in str(agent.logs) def test_code_agent_syntax_error_show_offending_lines(self): agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error) @@ -426,7 +426,7 @@ class AgentTests(unittest.TestCase): with console.capture() as capture: agent.run("Count to 3") str_output = capture.get() - assert "import under `additional_authorized_imports`" in str_output + assert "Consider passing said import under" in str_output.replace("\n", "") def test_multiagents(self): class FakeModelMultiagentsManagerAgent: diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 540720f..9fcd218 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -630,12 +630,9 @@ counts += 1""" assert "Cannot add non-list value 1 to a list." in str(e) def test_error_highlights_correct_line_of_code(self): - code = """# Ok this is a very long code -# It has many commented lines -a = 1 + code = """a = 1 b = 2 -# Here is another piece counts = [1, 2, 3] counts += 1 b += 1""" @@ -643,12 +640,22 @@ b += 1""" evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert "Code execution failed at line 'counts += 1" in str(e) + def test_error_type_returned_in_function_call(self): + code = """def error_function(): + raise ValueError("error") + +error_function()""" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code) + assert "error" in str(e) + assert "ValueError" in str(e) + def test_assert(self): code = """ assert 1 == 1 assert 1 == 2 """ - with pytest.raises(AssertionError) as e: + with pytest.raises(InterpreterError) as e: evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert "1 == 2" in str(e) and "1 == 1" not in str(e) @@ -845,6 +852,13 @@ shift_intervals result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}) assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"} + def test_syntax_error_points_error(self): + code = "a = ;" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code) + assert "SyntaxError" in str(e) + assert " ^" in str(e) + def test_fix_final_answer_code(self): test_cases = [ ( @@ -890,18 +904,16 @@ shift_intervals # Import of whitelisted modules should succeed but dangerous submodules should not exist code = "import random;random._os.system('echo bad command passed')" - with pytest.raises(AttributeError) as e: + with pytest.raises(InterpreterError) as e: evaluate_python_code(code) - assert "module 'random' has no attribute '_os'" in str(e) + assert "AttributeError:module 'random' has no attribute '_os'" in str(e) code = "import doctest;doctest.inspect.os.system('echo bad command passed')" - with pytest.raises(AttributeError): + with pytest.raises(InterpreterError): evaluate_python_code(code, authorized_imports=["doctest"]) def test_close_matches_subscript(self): code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]' with pytest.raises(Exception) as e: evaluate_python_code(code) - assert "Maybe you meant one of these indexes instead" in str( - e - ) and "['Bhutan']" in str(e).replace("\\", "") + assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "") diff --git a/tests/test_tools.py b/tests/test_tools.py index 67bd2f6..0caefd2 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -374,6 +374,20 @@ class ToolTests(unittest.TestCase): GetWeatherTool3() assert "Nullable" in str(e) + def test_tool_default_parameters_is_nullable(self): + @tool + def get_weather(location: str, celsius: bool = False) -> str: + """ + Get weather in the next days at given location. + + Args: + location: the location + celsius: is the temperature given in celsius + """ + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + assert get_weather.inputs["celsius"]["nullable"] + @pytest.fixture def mock_server_parameters():