Improve code execution logs in case of error by showing print outputs (#446)

* Improve code execution logs in case of error by still showing print outputs
* Improve action step testing
* Number steps starting at 1
This commit is contained in:
Aymeric Roucher 2025-01-31 12:34:32 +01:00 committed by GitHub
parent 5f2147a17d
commit f1a9b83443
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 82 additions and 61 deletions

View File

@ -350,7 +350,7 @@ class MultiStepAgent:
)
raise AgentExecutionError(error_msg, self.logger)
def step(self, log_entry: ActionStep) -> Union[None, Any]:
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes. Should return either None if the step is not final."""
pass
@ -427,8 +427,8 @@ You have been provided with these additional arguments, that you can access usin
images (`list[str]`): Paths to image(s).
"""
final_answer = None
self.step_number = 0
while final_answer is None and self.step_number < self.max_steps:
self.step_number = 1
while final_answer is None and self.step_number <= self.max_steps:
step_start_time = time.time()
memory_step = ActionStep(
step_number=self.step_number,
@ -461,7 +461,7 @@ You have been provided with these additional arguments, that you can access usin
self.step_number += 1
yield memory_step
if final_answer is None and self.step_number == self.max_steps:
if final_answer is None and self.step_number == self.max_steps + 1:
error_message = "Reached max steps."
final_answer = self.provide_final_answer(task, images)
final_memory_step = ActionStep(
@ -666,7 +666,7 @@ class ToolCallingAgent(MultiStepAgent):
**kwargs,
)
def step(self, log_entry: ActionStep) -> Union[None, Any]:
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Returns None if the step is not final.
@ -676,7 +676,7 @@ class ToolCallingAgent(MultiStepAgent):
self.input_messages = memory_messages
# Add new step in logs
log_entry.model_input_messages = memory_messages.copy()
memory_step.model_input_messages = memory_messages.copy()
try:
model_message: ChatMessage = self.model(
@ -684,7 +684,7 @@ class ToolCallingAgent(MultiStepAgent):
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
log_entry.model_output_message = model_message
memory_step.model_output_message = model_message
if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
tool_call = model_message.tool_calls[0]
@ -694,7 +694,7 @@ class ToolCallingAgent(MultiStepAgent):
except Exception as e:
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
# Execute
self.logger.log(
@ -724,7 +724,7 @@ class ToolCallingAgent(MultiStepAgent):
level=LogLevel.INFO,
)
log_entry.action_output = final_answer
memory_step.action_output = final_answer
return final_answer
else:
if tool_arguments is None:
@ -746,7 +746,7 @@ class ToolCallingAgent(MultiStepAgent):
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
level=LogLevel.INFO,
)
log_entry.observations = updated_information
memory_step.observations = updated_information
return None
@ -831,7 +831,7 @@ class CodeAgent(MultiStepAgent):
)
return self.system_prompt
def step(self, log_entry: ActionStep) -> Union[None, Any]:
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Returns None if the step is not final.
@ -841,7 +841,7 @@ class CodeAgent(MultiStepAgent):
self.input_messages = memory_messages.copy()
# Add new step in logs
log_entry.model_input_messages = memory_messages.copy()
memory_step.model_input_messages = memory_messages.copy()
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
chat_message: ChatMessage = self.model(
@ -849,9 +849,9 @@ class CodeAgent(MultiStepAgent):
stop_sequences=["<end_code>", "Observation:"],
**additional_args,
)
log_entry.model_output_message = chat_message
memory_step.model_output_message = chat_message
model_output = chat_message.content
log_entry.model_output = model_output
memory_step.model_output = model_output
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
@ -868,7 +868,7 @@ class CodeAgent(MultiStepAgent):
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg, self.logger)
log_entry.tool_calls = [
memory_step.tool_calls = [
ToolCall(
name="python_interpreter",
arguments=code_action,
@ -878,7 +878,6 @@ class CodeAgent(MultiStepAgent):
# Execute
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
observation = ""
is_final_answer = False
try:
output, execution_logs, is_final_answer = self.python_executor(
@ -891,8 +890,17 @@ class CodeAgent(MultiStepAgent):
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
observation += "Execution logs:\n" + execution_logs
observation = "Execution logs:\n" + execution_logs
except Exception as e:
if "print_outputs" in self.python_executor.state:
execution_logs = self.python_executor.state["print_outputs"]
if len(execution_logs) > 0:
execution_outputs_console = [
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
memory_step.observations = "Execution logs:\n" + execution_logs
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
error_msg = str(e)
if "Import of " in error_msg and " is not allowed" in error_msg:
self.logger.log(
@ -903,7 +911,7 @@ class CodeAgent(MultiStepAgent):
truncated_output = truncate_content(str(output))
observation += "Last output from code snippet:\n" + truncated_output
log_entry.observations = observation
memory_step.observations = observation
execution_outputs_console += [
Text(
@ -912,7 +920,7 @@ class CodeAgent(MultiStepAgent):
),
]
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
log_entry.action_output = output
memory_step.action_output = output
return output if is_final_answer else None

View File

@ -1283,7 +1283,7 @@ def evaluate_python_code(
expression = ast.parse(code)
except SyntaxError as e:
raise InterpreterError(
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"Code parsing failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
@ -1316,11 +1316,10 @@ def evaluate_python_code(
return e.value, is_final_answer
except Exception as e:
exception_type = type(e).__name__
error_msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
error_msg = (
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
raise InterpreterError(
f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {exception_type}:{str(e)}"
)
raise InterpreterError(error_msg)
class LocalPythonInterpreter:

View File

@ -90,29 +90,16 @@ class ActionStep(MemoryStep):
messages.append(
Message(
role=MessageRole.ASSISTANT,
content=[{"type": "text", "text": str([tc.dict() for tc in self.tool_calls])}],
content=[
{
"type": "text",
"text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
}
],
)
)
if self.error is not None:
message_content = (
"Error:\n"
+ str(self.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
if self.tool_calls is None:
tool_response_message = Message(
role=MessageRole.ASSISTANT, content=[{"type": "text", "text": message_content}]
)
else:
tool_response_message = Message(
role=MessageRole.TOOL_RESPONSE,
content=[{"type": "text", "text": f"Call id: {self.tool_calls[0].id}\n{message_content}"}],
)
messages.append(tool_response_message)
else:
if self.observations is not None and self.tool_calls is not None:
if self.observations is not None:
messages.append(
Message(
role=MessageRole.TOOL_RESPONSE,
@ -124,6 +111,18 @@ class ActionStep(MemoryStep):
],
)
)
if self.error is not None:
error_message = (
"Error:\n"
+ str(self.error)
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
message_content = f"Call id: {self.tool_calls[0].id}\n" if self.tool_calls else ""
message_content += error_message
messages.append(
Message(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": message_content}])
)
if self.observations_images:
messages.append(
Message(

View File

@ -768,7 +768,6 @@ class OpenAIServerModel(Model):
convert_images_to_image_urls=True,
**kwargs,
)
response = self.client.chat.completions.create(**completion_kwargs)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens

View File

@ -178,6 +178,7 @@ def fake_code_model_error(messages, stop_sequences=None) -> str:
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
print("Flag!")
def error_function():
raise ValueError("error")
@ -393,6 +394,11 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'error_function()'" in str(agent.memory.steps[1].error)
assert "ValueError" in str(agent.memory.steps)
def test_code_agent_code_error_saves_previous_print_outputs(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
agent.run("What is 2 multiplied by 3.6452?")
assert "Flag!" in str(agent.memory.steps[1].observations)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
output = agent.run("What is 2 multiplied by 3.6452?")
@ -410,7 +416,7 @@ class AgentTests(unittest.TestCase):
max_steps=5,
)
answer = agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.memory.steps) == 7
assert len(agent.memory.steps) == 7 # Task step + 5 action steps + Final answer
assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
assert isinstance(answer, str)

View File

@ -1,5 +1,6 @@
import pytest
from smolagents.agents import ToolCall
from smolagents.memory import (
ActionStep,
AgentMemory,
@ -39,7 +40,9 @@ class TestMemoryStep:
def test_action_step_to_messages():
action_step = ActionStep(
model_input_messages=[Message(role=MessageRole.USER, content="Hello")],
tool_calls=None,
tool_calls=[
ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
],
start_time=0.0,
end_time=1.0,
step_number=1,
@ -47,12 +50,12 @@ def test_action_step_to_messages():
duration=1.0,
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
model_output="Hi",
observations="Observation",
observations="This is a nice observation",
observations_images=["image1.png"],
action_output="Output",
)
messages = action_step.to_messages()
assert len(messages) == 2
assert len(messages) == 4
for message in messages:
assert isinstance(message, dict)
assert "role" in message
@ -66,14 +69,21 @@ def test_action_step_to_messages():
assert isinstance(content, dict)
assert "type" in content
assert "text" in content
user_message = messages[1]
assert user_message["role"] == MessageRole.USER
assert len(user_message["content"]) == 2
text_content = user_message["content"][0]
message = messages[1]
assert message["role"] == MessageRole.ASSISTANT
assert len(message["content"]) == 1
text_content = message["content"][0]
assert isinstance(text_content, dict)
assert "type" in text_content
assert "text" in text_content
for image_content in user_message["content"][1:]:
observation_message = messages[2]
assert observation_message["role"] == MessageRole.TOOL_RESPONSE
assert "Observation:\nThis is a nice observation" in observation_message["content"][0]["text"]
image_message = messages[3]
image_content = image_message["content"][1]
assert isinstance(image_content, dict)
assert "type" in image_content
assert "image" in image_content