From fec65e154a41d97d0d73613443514f86e220b1c3 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 10 Jan 2025 23:46:22 +0100 Subject: [PATCH] More flexible verbosity level (#150) --- docs/source/en/examples/rag.md | 2 +- examples/benchmark.ipynb | 27 ++------ examples/rag.py | 2 +- src/smolagents/agents.py | 119 +++++++++++++++++++++------------ src/smolagents/monitoring.py | 7 +- 5 files changed, 86 insertions(+), 71 deletions(-) diff --git a/docs/source/en/examples/rag.md b/docs/source/en/examples/rag.md index acbdf14..46ae7b7 100644 --- a/docs/source/en/examples/rag.md +++ b/docs/source/en/examples/rag.md @@ -137,7 +137,7 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m from smolagents import HfApiModel, CodeAgent agent = CodeAgent( - tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True + tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2 ) ``` diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index 02d7b7b..7a7b776 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -172,7 +172,7 @@ "[132 rows x 4 columns]" ] }, - "execution_count": 21, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -195,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -398,23 +398,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 27836.90it/s]\n", - " 16%|█▌ | 
21/132 [02:18<07:35, 4.11s/it]" - ] - } - ], + "outputs": [], "source": [ "open_model_ids = [\n", " \"meta-llama/Llama-3.3-70B-Instruct\",\n", @@ -423,6 +407,7 @@ " \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n", " \"meta-llama/Llama-3.2-3B-Instruct\",\n", " \"meta-llama/Llama-3.1-8B-Instruct\",\n", + " \"mistralai/Mistral-Nemo-Instruct-2407\",\n", " # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n", " # \"meta-llama/Llama-3.1-70B-Instruct\",\n", "]\n", @@ -1010,7 +995,7 @@ ], "metadata": { "kernelspec": { - "display_name": "test", + "display_name": "compare-agents", "language": "python", "name": "python3" }, diff --git a/examples/rag.py b/examples/rag.py index bd40854..4096d57 100644 --- a/examples/rag.py +++ b/examples/rag.py @@ -60,7 +60,7 @@ from smolagents import HfApiModel, CodeAgent retriever_tool = RetrieverTool(docs_processed) agent = CodeAgent( - tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True + tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2 ) agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?") diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index a81b80d..66bbdef 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -18,12 +18,14 @@ import time from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from enum import IntEnum from rich import box from rich.console import Group from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax from rich.text import Text +from rich.console import Console from .default_tools import FinalAnswerTool, TOOL_MAPPING from .e2b_executor import E2BExecutor @@ -164,6 +166,22 @@ def format_prompt_with_managed_agents_descriptions( YELLOW_HEX = "#d4b702" +class LogLevel(IntEnum): + ERROR = 0 # Only errors + INFO = 1 # Normal output (default) + DEBUG = 
2 # Detailed output + + +class AgentLogger: + def __init__(self, level: LogLevel = LogLevel.INFO): + self.level = level + self.console = Console() + + def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): + if level <= self.level: + self.console.print(*args, **kwargs) + + class MultiStepAgent: """ Agent class that solves the given task step by step, using the ReAct framework: @@ -179,7 +197,7 @@ class MultiStepAgent: max_steps: int = 6, tool_parser: Optional[Callable] = None, add_base_tools: bool = False, - verbose: bool = False, + verbosity_level: int = 1, grammar: Optional[Dict[str, str]] = None, managed_agents: Optional[List] = None, step_callbacks: Optional[List[Callable]] = None, @@ -205,7 +223,6 @@ class MultiStepAgent: self.managed_agents = {} if managed_agents is not None: - print("NOTNONE") self.managed_agents = {agent.name: agent for agent in managed_agents} self.tools = {tool.name: tool for tool in tools} @@ -222,8 +239,8 @@ class MultiStepAgent: self.input_messages = None self.logs = [] self.task = None - self.verbose = verbose - self.monitor = Monitor(self.model) + self.logger = AgentLogger(level=verbosity_level) + self.monitor = Monitor(self.model, self.logger) self.step_callbacks = step_callbacks if step_callbacks is not None else [] self.step_callbacks.append(self.monitor.update_metrics) @@ -485,14 +502,15 @@ You have been provided with these additional arguments, that you can access usin else: self.logs.append(system_prompt_step) - console.print( + self.logger.log( Panel( f"\n[bold]{self.task.strip()}\n", title="[bold]New run", subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}", border_style=YELLOW_HEX, subtitle_align="left", - ) + ), + level=LogLevel.INFO, ) self.logs.append(TaskStep(task=self.task)) @@ -531,12 +549,13 @@ You have been provided with these additional arguments, that you can access usin is_first_step=(self.step_number == 0), step=self.step_number, ) - console.print( + 
self.logger.log( Rule( f"[bold]Step {self.step_number}", characters="━", style=YELLOW_HEX, - ) + ), + level=LogLevel.INFO, ) # Run one step! @@ -557,7 +576,7 @@ You have been provided with these additional arguments, that you can access usin final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) self.logs.append(final_step_log) final_answer = self.provide_final_answer(task) - console.print(Text(f"Final answer: {final_answer}")) + self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) final_step_log.action_output = final_answer final_step_log.end_time = time.time() final_step_log.duration = step_log.end_time - step_start_time @@ -586,12 +605,13 @@ You have been provided with these additional arguments, that you can access usin is_first_step=(self.step_number == 0), step=self.step_number, ) - console.print( + self.logger.log( Rule( f"[bold]Step {self.step_number}", characters="━", style=YELLOW_HEX, - ) + ), + level=LogLevel.INFO, ) # Run one step! @@ -613,7 +633,7 @@ You have been provided with these additional arguments, that you can access usin final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) self.logs.append(final_step_log) final_answer = self.provide_final_answer(task) - console.print(Text(f"Final answer: {final_answer}")) + self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) final_step_log.action_output = final_answer final_step_log.duration = 0 for callback in self.step_callbacks: @@ -679,8 +699,10 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) ) - console.print( - Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction) + self.logger.log( + Rule("[bold]Initial plan", style="orange"), + Text(final_plan_redaction), + level=LogLevel.INFO, ) else: # update plan agent_memory = self.write_inner_memory_from_logs( @@ -735,8 +757,10 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, 
facts=final_facts_redaction) ) - console.print( - Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction) + self.logger.log( + Rule("[bold]Updated plan", style="orange"), + Text(final_plan_redaction), + level=LogLevel.INFO, ) @@ -795,8 +819,11 @@ class ToolCallingAgent(MultiStepAgent): ) # Execute - console.print( - Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")) + self.logger.log( + Panel( + Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}") + ), + level=LogLevel.INFO, ) if tool_name == "final_answer": if isinstance(tool_arguments, dict): @@ -810,13 +837,15 @@ class ToolCallingAgent(MultiStepAgent): isinstance(answer, str) and answer in self.state.keys() ): # if the answer is a state variable, return the value final_answer = self.state[answer] - console.print( - f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'." + self.logger.log( + f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.", + level=LogLevel.INFO, ) else: final_answer = answer - console.print( - Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}") + self.logger.log( + Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"), + level=LogLevel.INFO, ) log_entry.action_output = final_answer @@ -837,7 +866,7 @@ class ToolCallingAgent(MultiStepAgent): updated_information = f"Stored '{observation_name}' in memory." 
else: updated_information = str(observation).strip() - console.print(f"Observations: {updated_information}") + self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO) log_entry.observations = updated_information return None @@ -922,22 +951,22 @@ class CodeAgent(MultiStepAgent): except Exception as e: raise AgentGenerationError(f"Error in generating model output:\n{e}") - if self.verbose: - console.print( - Group( - Rule( - "[italic]Output message of the LLM:", - align="left", - style="orange", - ), - Syntax( - llm_output, - lexer="markdown", - theme="github-dark", - word_wrap=True, - ), - ) - ) + self.logger.log( + Group( + Rule( + "[italic]Output message of the LLM:", + align="left", + style="orange", + ), + Syntax( + llm_output, + lexer="markdown", + theme="github-dark", + word_wrap=True, + ), + ), + level=LogLevel.DEBUG, + ) # Parse try: @@ -955,7 +984,7 @@ class CodeAgent(MultiStepAgent): ) # Execute - console.print( + self.logger.log( Panel( Syntax( code_action, @@ -966,7 +995,8 @@ class CodeAgent(MultiStepAgent): title="[bold]Executing this code:", title_align="left", box=box.HORIZONTALS, - ) + ), + level=LogLevel.INFO, ) observation = "" is_final_answer = False @@ -993,8 +1023,9 @@ class CodeAgent(MultiStepAgent): else: error_msg = str(e) if "Import of " in str(e) and " is not allowed" in str(e): - console.print( - "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent." 
+ self.logger.log( + "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", + level=LogLevel.INFO, ) raise AgentExecutionError(error_msg) @@ -1008,7 +1039,7 @@ class CodeAgent(MultiStepAgent): style=(f"bold {YELLOW_HEX}" if is_final_answer else ""), ), ] - console.print(Group(*execution_outputs_console)) + self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) log_entry.action_output = output return output if is_final_answer else None diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index daa53cd..b6ba78f 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -16,13 +16,12 @@ # limitations under the License. from rich.text import Text -from .utils import console - class Monitor: - def __init__(self, tracked_model): + def __init__(self, tracked_model, logger): self.step_durations = [] self.tracked_model = tracked_model + self.logger = logger if ( getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found" @@ -53,7 +52,7 @@ class Monitor: self.total_output_token_count += self.tracked_model.last_output_token_count console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}" console_outputs += "]" - console.print(Text(console_outputs, style="dim")) + self.logger.log(Text(console_outputs, style="dim"), level=1) __all__ = ["Monitor"]