Improve python executor's error logging (#275)

* Improve python executor's error logging
This commit is contained in:
Aymeric Roucher 2025-01-20 15:57:16 +01:00 committed by GitHub
parent 3c18d4d588
commit 7a91123729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 86 additions and 37 deletions

View File

@ -972,16 +972,8 @@ class CodeAgent(MultiStepAgent):
] ]
observation += "Execution logs:\n" + execution_logs observation += "Execution logs:\n" + execution_logs
except Exception as e: except Exception as e:
if isinstance(e, SyntaxError): error_msg = str(e)
error_msg = ( if "Import of " in error_msg and " is not allowed" in error_msg:
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
else:
error_msg = str(e)
if "Import of " in str(e) and " is not allowed" in str(e):
self.logger.log( self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO, level=LogLevel.INFO,

View File

@ -554,7 +554,7 @@ def evaluate_call(
func = ERRORS[func_name] func = ERRORS[func_name]
else: else:
raise InterpreterError( raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
) )
elif isinstance(call.func, ast.Subscript): elif isinstance(call.func, ast.Subscript):
@ -1245,7 +1245,16 @@ def evaluate_python_code(
updated by this function to contain all variables as they are evaluated. updated by this function to contain all variables as they are evaluated.
The print outputs will be stored in the state under the key 'print_outputs'. The print outputs will be stored in the state under the key 'print_outputs'.
""" """
expression = ast.parse(code) try:
expression = ast.parse(code)
except SyntaxError as e:
raise InterpreterError(
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
if state is None: if state is None:
state = {} state = {}
if static_tools is None: if static_tools is None:
@ -1273,10 +1282,13 @@ def evaluate_python_code(
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = True is_final_answer = True
return e.value, is_final_answer return e.value, is_final_answer
except InterpreterError as e: except Exception as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) exception_type = type(e).__name__
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" error_msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
raise InterpreterError(msg) error_msg = (
f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {exception_type}:{str(e)}"
)
raise InterpreterError(error_msg)
class LocalPythonInterpreter: class LocalPythonInterpreter:

View File

@ -168,7 +168,9 @@ class Tool:
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
) )
json_schema = _convert_type_hints_to_json_schema(self.forward) json_schema = _convert_type_hints_to_json_schema(
self.forward
) # This function will raise an error on missing docstrings, contrary to get_json_schema
for key, value in self.inputs.items(): for key, value in self.inputs.items():
if "nullable" in value: if "nullable" in value:
assert key in json_schema and "nullable" in json_schema[key], ( assert key in json_schema and "nullable" in json_schema[key], (
@ -885,6 +887,16 @@ class ToolCollection:
yield cls(tools) yield cls(tools)
def get_tool_json_schema(tool_function):
    """Build the JSON schema of a tool function, flagging optional inputs as nullable.

    The schema is taken from ``get_json_schema(tool_function)["function"]``.
    Every entry of ``parameters["properties"]`` whose name is not listed under
    ``parameters["required"]`` (or all entries, when there is no ``required``
    key) gets ``"nullable": True`` set on it.

    Args:
        tool_function: The callable to describe.

    Returns:
        The "function" part of the generated schema, with the nullable flags added.
    """
    schema = get_json_schema(tool_function)["function"]
    params = schema["parameters"]
    # A missing "required" key means no input is mandatory.
    required_names = params.get("required", ())
    for input_name, input_spec in params["properties"].items():
        if input_name not in required_names:
            input_spec["nullable"] = True
    return schema
def tool(tool_function: Callable) -> Tool: def tool(tool_function: Callable) -> Tool:
""" """
Converts a function into an instance of a Tool subclass. Converts a function into an instance of a Tool subclass.
@ -893,12 +905,19 @@ def tool(tool_function: Callable) -> Tool:
tool_function: Your function. Should have type hints for each input and a type hint for the output. tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described. Should also have a docstring description including an 'Args:' part where each argument is described.
""" """
parameters = get_json_schema(tool_function)["function"] tool_json_schema = get_tool_json_schema(tool_function)
if "return" not in parameters: if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class SimpleTool(Tool): class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function): def __init__(
self,
name: str,
description: str,
inputs: Dict[str, Dict[str, str]],
output_type: str,
function: Callable,
):
self.name = name self.name = name
self.description = description self.description = description
self.inputs = inputs self.inputs = inputs
@ -907,10 +926,10 @@ def tool(tool_function: Callable) -> Tool:
self.is_initialized = True self.is_initialized = True
simple_tool = SimpleTool( simple_tool = SimpleTool(
parameters["name"], name=tool_json_schema["name"],
parameters["description"], description=tool_json_schema["description"],
parameters["parameters"]["properties"], inputs=tool_json_schema["parameters"]["properties"],
parameters["return"]["type"], output_type=tool_json_schema["return"]["type"],
function=tool_function, function=tool_function,
) )
original_signature = inspect.signature(tool_function) original_signature = inspect.signature(tool_function)

View File

@ -332,7 +332,7 @@ class AgentTests(unittest.TestCase):
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, AgentText)
assert output == "got an error" assert output == "got an error"
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs) assert "Code execution failed at line 'print = 2' due to: InterpreterError" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self): def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error) agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
@ -426,7 +426,7 @@ class AgentTests(unittest.TestCase):
with console.capture() as capture: with console.capture() as capture:
agent.run("Count to 3") agent.run("Count to 3")
str_output = capture.get() str_output = capture.get()
assert "import under `additional_authorized_imports`" in str_output assert "Consider passing said import under" in str_output.replace("\n", "")
def test_multiagents(self): def test_multiagents(self):
class FakeModelMultiagentsManagerAgent: class FakeModelMultiagentsManagerAgent:

View File

@ -630,12 +630,9 @@ counts += 1"""
assert "Cannot add non-list value 1 to a list." in str(e) assert "Cannot add non-list value 1 to a list." in str(e)
def test_error_highlights_correct_line_of_code(self): def test_error_highlights_correct_line_of_code(self):
code = """# Ok this is a very long code code = """a = 1
# It has many commented lines
a = 1
b = 2 b = 2
# Here is another piece
counts = [1, 2, 3] counts = [1, 2, 3]
counts += 1 counts += 1
b += 1""" b += 1"""
@ -643,12 +640,22 @@ b += 1"""
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "Code execution failed at line 'counts += 1" in str(e) assert "Code execution failed at line 'counts += 1" in str(e)
def test_error_type_returned_in_function_call(self):
    """An exception raised inside a user-defined function is reported with its type.

    The interpreter is expected to wrap the failure in an InterpreterError whose
    message mentions both the original exception type and its message.
    """
    # Spell the program out line by line so the indentation of the function
    # body is explicit: without it the snippet is a syntax error and the test
    # would no longer exercise the "exception raised inside a call" path.
    code = "\n".join(
        [
            "def error_function():",
            '    raise ValueError("error")',
            "",
            "error_function()",
        ]
    )
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code)
    # Inspect the raised exception's own message rather than the string form
    # of pytest's ExceptionInfo wrapper.
    message = str(e.value)
    assert "error" in message
    assert "ValueError" in message
def test_assert(self): def test_assert(self):
code = """ code = """
assert 1 == 1 assert 1 == 1
assert 1 == 2 assert 1 == 2
""" """
with pytest.raises(AssertionError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "1 == 2" in str(e) and "1 == 1" not in str(e) assert "1 == 2" in str(e) and "1 == 1" not in str(e)
@ -845,6 +852,13 @@ shift_intervals
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}) result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"} assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
def test_syntax_error_points_error(self):
    """Unparsable code surfaces as an InterpreterError naming SyntaxError, with a caret marker."""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code("a = ;")
    for expected_fragment in ("SyntaxError", " ^"):
        assert expected_fragment in str(e)
def test_fix_final_answer_code(self): def test_fix_final_answer_code(self):
test_cases = [ test_cases = [
( (
@ -890,18 +904,16 @@ shift_intervals
# Import of whitelisted modules should succeed but dangerous submodules should not exist # Import of whitelisted modules should succeed but dangerous submodules should not exist
code = "import random;random._os.system('echo bad command passed')" code = "import random;random._os.system('echo bad command passed')"
with pytest.raises(AttributeError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code) evaluate_python_code(code)
assert "module 'random' has no attribute '_os'" in str(e) assert "AttributeError:module 'random' has no attribute '_os'" in str(e)
code = "import doctest;doctest.inspect.os.system('echo bad command passed')" code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
with pytest.raises(AttributeError): with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["doctest"]) evaluate_python_code(code, authorized_imports=["doctest"])
def test_close_matches_subscript(self): def test_close_matches_subscript(self):
code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]' code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
evaluate_python_code(code) evaluate_python_code(code)
assert "Maybe you meant one of these indexes instead" in str( assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "")
e
) and "['Bhutan']" in str(e).replace("\\", "")

View File

@ -374,6 +374,20 @@ class ToolTests(unittest.TestCase):
GetWeatherTool3() GetWeatherTool3()
assert "Nullable" in str(e) assert "Nullable" in str(e)
def test_tool_default_parameters_is_nullable(self):
    """Inputs with a default value are flagged nullable by the `tool` decorator; required ones are not."""

    @tool
    def get_weather(location: str, celsius: bool = False) -> str:
        """
        Get weather in the next days at given location.

        Args:
            location: the location
            celsius: is the temperature given in celsius
        """
        return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

    # `celsius` has a default, so it must be marked nullable ...
    assert get_weather.inputs["celsius"]["nullable"]
    # ... while `location` is required and must not pick up the flag.
    assert not get_weather.inputs["location"].get("nullable", False)
@pytest.fixture @pytest.fixture
def mock_server_parameters(): def mock_server_parameters():