minor fix for console in AgentLogger (#303)
* minor fix for console in AgentLogger
This commit is contained in:
		
							parent
							
								
									2c43546d3c
								
							
						
					
					
						commit
						ec45d6766a
					
				|  | @ -17,11 +17,10 @@ | ||||||
| import time | import time | ||||||
| from collections import deque | from collections import deque | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from enum import IntEnum |  | ||||||
| from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union | from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union | ||||||
| 
 | 
 | ||||||
| from rich import box | from rich import box | ||||||
| from rich.console import Console, Group | from rich.console import Group | ||||||
| from rich.panel import Panel | from rich.panel import Panel | ||||||
| from rich.rule import Rule | from rich.rule import Rule | ||||||
| from rich.syntax import Syntax | from rich.syntax import Syntax | ||||||
|  | @ -62,9 +61,10 @@ from .utils import ( | ||||||
|     AgentError, |     AgentError, | ||||||
|     AgentExecutionError, |     AgentExecutionError, | ||||||
|     AgentGenerationError, |     AgentGenerationError, | ||||||
|  |     AgentLogger, | ||||||
|     AgentMaxStepsError, |     AgentMaxStepsError, | ||||||
|     AgentParsingError, |     AgentParsingError, | ||||||
|     console, |     LogLevel, | ||||||
|     parse_code_blobs, |     parse_code_blobs, | ||||||
|     parse_json_tool_call, |     parse_json_tool_call, | ||||||
|     truncate_content, |     truncate_content, | ||||||
|  | @ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions( | ||||||
| YELLOW_HEX = "#d4b702" | YELLOW_HEX = "#d4b702" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class LogLevel(IntEnum): |  | ||||||
|     ERROR = 0  # Only errors |  | ||||||
|     INFO = 1  # Normal output (default) |  | ||||||
|     DEBUG = 2  # Detailed output |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class AgentLogger: |  | ||||||
|     def __init__(self, level: LogLevel = LogLevel.INFO): |  | ||||||
|         self.level = level |  | ||||||
|         self.console = Console() |  | ||||||
| 
 |  | ||||||
|     def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): |  | ||||||
|         if level <= self.level: |  | ||||||
|             console.print(*args, **kwargs) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MultiStepAgent: | class MultiStepAgent: | ||||||
|     """ |     """ | ||||||
|     Agent class that solves the given task step by step, using the ReAct framework: |     Agent class that solves the given task step by step, using the ReAct framework: | ||||||
|  | @ -353,7 +337,8 @@ class MultiStepAgent: | ||||||
|             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output |             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output | ||||||
|         except Exception: |         except Exception: | ||||||
|             raise AgentParsingError( |             raise AgentParsingError( | ||||||
|                 f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" |                 f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!", | ||||||
|  |                 self.logger, | ||||||
|             ) |             ) | ||||||
|         return rationale.strip(), action.strip() |         return rationale.strip(), action.strip() | ||||||
| 
 | 
 | ||||||
|  | @ -391,7 +376,7 @@ class MultiStepAgent: | ||||||
|         available_tools = {**self.tools, **self.managed_agents} |         available_tools = {**self.tools, **self.managed_agents} | ||||||
|         if tool_name not in available_tools: |         if tool_name not in available_tools: | ||||||
|             error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." |             error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." | ||||||
|             raise AgentExecutionError(error_msg) |             raise AgentExecutionError(error_msg, self.logger) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             if isinstance(arguments, str): |             if isinstance(arguments, str): | ||||||
|  | @ -409,7 +394,7 @@ class MultiStepAgent: | ||||||
|                     observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) |                     observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) | ||||||
|             else: |             else: | ||||||
|                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." |                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg, self.logger, self.logger) | ||||||
|             return observation |             return observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             if tool_name in self.tools: |             if tool_name in self.tools: | ||||||
|  | @ -418,13 +403,13 @@ class MultiStepAgent: | ||||||
|                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" |                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" | ||||||
|                     f"As a reminder, this tool's description is the following:\n{tool_description}" |                     f"As a reminder, this tool's description is the following:\n{tool_description}" | ||||||
|                 ) |                 ) | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg, self.logger) | ||||||
|             elif tool_name in self.managed_agents: |             elif tool_name in self.managed_agents: | ||||||
|                 error_msg = ( |                 error_msg = ( | ||||||
|                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" |                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" | ||||||
|                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" |                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" | ||||||
|                 ) |                 ) | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg, self.logger) | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: ActionStep) -> Union[None, Any]: |     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||||
|         """To be implemented in children classes. Should return either None if the step is not final.""" |         """To be implemented in children classes. Should return either None if the step is not final.""" | ||||||
|  | @ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin | ||||||
| 
 | 
 | ||||||
|         if final_answer is None and self.step_number == self.max_steps: |         if final_answer is None and self.step_number == self.max_steps: | ||||||
|             error_message = "Reached max steps." |             error_message = "Reached max steps." | ||||||
|             final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) |             self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) | ||||||
|  | @ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|             tool_arguments = tool_call.function.arguments |             tool_arguments = tool_call.function.arguments | ||||||
| 
 | 
 | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") |             raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] |         log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] | ||||||
| 
 | 
 | ||||||
|  | @ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] |         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] | ||||||
|         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) |         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) | ||||||
|         if "{{authorized_imports}}" not in system_prompt: |         if "{{authorized_imports}}" not in system_prompt: | ||||||
|             raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.") |             raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.") | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             tools=tools, |             tools=tools, | ||||||
|             model=model, |             model=model, | ||||||
|  | @ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|             ).content |             ).content | ||||||
|             log_entry.llm_output = llm_output |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating model output:\n{e}") |             raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) | ||||||
| 
 | 
 | ||||||
|         self.logger.log( |         self.logger.log( | ||||||
|             Group( |             Group( | ||||||
|  | @ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|             code_action = fix_final_answer_code(parse_code_blobs(llm_output)) |             code_action = fix_final_answer_code(parse_code_blobs(llm_output)) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." |             error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg, self.logger) | ||||||
| 
 | 
 | ||||||
|         log_entry.tool_calls = [ |         log_entry.tool_calls = [ | ||||||
|             ToolCall( |             ToolCall( | ||||||
|  | @ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|                     "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", |                     "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", | ||||||
|                     level=LogLevel.INFO, |                     level=LogLevel.INFO, | ||||||
|                 ) |                 ) | ||||||
|             raise AgentExecutionError(error_msg) |             raise AgentExecutionError(error_msg, self.logger) | ||||||
| 
 | 
 | ||||||
|         truncated_output = truncate_content(str(output)) |         truncated_output = truncate_content(str(output)) | ||||||
|         observation += "Last output from code snippet:\n" + truncated_output |         observation += "Last output from code snippet:\n" + truncated_output | ||||||
|  |  | ||||||
|  | @ -21,6 +21,7 @@ import inspect | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
| import types | import types | ||||||
|  | from enum import IntEnum | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from typing import Dict, Tuple, Union | from typing import Dict, Tuple, Union | ||||||
| 
 | 
 | ||||||
|  | @ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [ | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class LogLevel(IntEnum): | ||||||
|  |     ERROR = 0  # Only errors | ||||||
|  |     INFO = 1  # Normal output (default) | ||||||
|  |     DEBUG = 2  # Detailed output | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AgentLogger: | ||||||
|  |     def __init__(self, level: LogLevel = LogLevel.INFO): | ||||||
|  |         self.level = level | ||||||
|  |         self.console = Console() | ||||||
|  | 
 | ||||||
|  |     def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): | ||||||
|  |         if level <= self.level: | ||||||
|  |             self.console.print(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class AgentError(Exception): | class AgentError(Exception): | ||||||
|     """Base class for other agent-related exceptions""" |     """Base class for other agent-related exceptions""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, message): |     def __init__(self, message, logger: AgentLogger): | ||||||
|         super().__init__(message) |         super().__init__(message) | ||||||
|         self.message = message |         self.message = message | ||||||
|         console.print(f"[bold red]{message}[/bold red]") |         logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentParsingError(AgentError): | class AgentParsingError(AgentError): | ||||||
|  |  | ||||||
|  | @ -421,9 +421,7 @@ class AgentTests(unittest.TestCase): | ||||||
|     def test_code_agent_missing_import_triggers_advice_in_error_log(self): |     def test_code_agent_missing_import_triggers_advice_in_error_log(self): | ||||||
|         agent = CodeAgent(tools=[], model=fake_code_model_import) |         agent = CodeAgent(tools=[], model=fake_code_model_import) | ||||||
| 
 | 
 | ||||||
|         from smolagents.agents import console |         with agent.logger.console.capture() as capture: | ||||||
| 
 |  | ||||||
|         with console.capture() as capture: |  | ||||||
|             agent.run("Count to 3") |             agent.run("Count to 3") | ||||||
|         str_output = capture.get() |         str_output = capture.get() | ||||||
|         assert "Consider passing said import under" in str_output.replace("\n", "") |         assert "Consider passing said import under" in str_output.replace("\n", "") | ||||||
|  |  | ||||||
|  | @ -27,6 +27,7 @@ from smolagents.models import ( | ||||||
|     ChatMessageToolCall, |     ChatMessageToolCall, | ||||||
|     ChatMessageToolCallDefinition, |     ChatMessageToolCallDefinition, | ||||||
| ) | ) | ||||||
|  | from smolagents.utils import AgentLogger, LogLevel | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class FakeLLMModel: | class FakeLLMModel: | ||||||
|  | @ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase): | ||||||
|         self.assertEqual(final_message.content["mime_type"], "image/png") |         self.assertEqual(final_message.content["mime_type"], "image/png") | ||||||
| 
 | 
 | ||||||
|     def test_streaming_with_agent_error(self): |     def test_streaming_with_agent_error(self): | ||||||
|  |         logger = AgentLogger(level=LogLevel.INFO) | ||||||
|  | 
 | ||||||
|         def dummy_model(prompt, **kwargs): |         def dummy_model(prompt, **kwargs): | ||||||
|             raise AgentError("Simulated agent error") |             raise AgentError("Simulated agent error", logger) | ||||||
| 
 | 
 | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue