diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index e0b59d5..065adce 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -181,6 +181,7 @@ "import datasets\n", "import pandas as pd\n", "\n", + "\n", "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n", "pd.DataFrame(eval_ds)" ] @@ -199,26 +200,28 @@ "metadata": {}, "outputs": [], "source": [ - "import time\n", "import json\n", "import os\n", "import re\n", "import string\n", + "import time\n", "import warnings\n", - "from tqdm import tqdm\n", "from typing import List\n", "\n", + "from dotenv import load_dotenv\n", + "from tqdm import tqdm\n", + "\n", "from smolagents import (\n", - " GoogleSearchTool,\n", - " CodeAgent,\n", - " ToolCallingAgent,\n", - " HfApiModel,\n", " AgentError,\n", - " VisitWebpageTool,\n", + " CodeAgent,\n", + " GoogleSearchTool,\n", + " HfApiModel,\n", " PythonInterpreterTool,\n", + " ToolCallingAgent,\n", + " VisitWebpageTool,\n", ")\n", "from smolagents.agents import ActionStep\n", - "from dotenv import load_dotenv\n", + "\n", "\n", "load_dotenv()\n", "os.makedirs(\"output\", exist_ok=True)\n", @@ -231,9 +234,7 @@ " return str(obj)\n", "\n", "\n", - "def answer_questions(\n", - " eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n", - "):\n", + "def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n", " answered_questions = []\n", " if os.path.exists(file_name):\n", " with open(file_name, \"r\") as f:\n", @@ -365,23 +366,18 @@ " ma_elems = split_string(model_answer)\n", "\n", " if len(gt_elems) != len(ma_elems): # check length is the same\n", - " warnings.warn(\n", - " \"Answer lists have different lengths, returning False.\", UserWarning\n", - " )\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", " return False\n", "\n", " comparisons = []\n", - " for ma_elem, gt_elem in zip(\n", - " ma_elems, gt_elems\n", - " ): # compare each element as float or str\n", + " for ma_elem, gt_elem in zip(ma_elems, gt_elems): # compare each element as float or str\n", " if is_float(gt_elem):\n", " normalized_ma_elem = normalize_number_str(ma_elem)\n", " comparisons.append(normalized_ma_elem == float(gt_elem))\n", " else:\n", " # we do not remove punct since comparisons can include punct\n", " comparisons.append(\n", - " normalize_str(ma_elem, remove_punct=False)\n", - " == normalize_str(gt_elem, remove_punct=False)\n", + " normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n", " )\n", " return all(comparisons)\n", "\n", @@ -441,9 +437,7 @@ " action_type = \"vanilla\"\n", " llm = HfApiModel(model_id)\n", " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(\n", - " eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n", - " )" + " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)" ] }, { @@ -461,6 +455,7 @@ "source": [ "from smolagents import LiteLLMModel\n", "\n", + "\n", "litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n", "\n", "for model_id in litellm_model_ids:\n", @@ -492,9 +487,7 @@ " action_type = \"vanilla\"\n", " llm = LiteLLMModel(model_id)\n", " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(\n", - " eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n", - " )" + " answer_questions(eval_ds, file_name, 
llm, model_id, action_type, is_vanilla_llm=True)" ] }, { @@ -556,9 +549,11 @@ } ], "source": [ - "import pandas as pd\n", "import glob\n", "\n", + "import pandas as pd\n", + "\n", + "\n", "res = []\n", "for file_path in glob.glob(\"output/*.jsonl\"):\n", " data = []\n", @@ -595,11 +590,7 @@ "\n", "result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n", "\n", - "result_df = (\n", - " (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n", - " .round(1)\n", - " .reset_index()\n", - ")" + "result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()" ] }, { @@ -895,6 +886,7 @@ "import pandas as pd\n", "from matplotlib.legend_handler import HandlerTuple # Added import\n", "\n", + "\n", "# Assuming pivot_df is your original dataframe\n", "models = pivot_df[\"model_id\"].unique()\n", "sources = pivot_df[\"source\"].unique()\n", @@ -961,14 +953,10 @@ "handles, labels = ax.get_legend_handles_labels()\n", "unique_sources = sources\n", "legend_elements = [\n", - " (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n", - " for i in range(len(unique_sources))\n", + " (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n", "]\n", "custom_legend = ax.legend(\n", - " [\n", - " (agent_handle, vanilla_handle)\n", - " for agent_handle, vanilla_handle, _ in legend_elements\n", - " ],\n", + " [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n", " [label for _, _, label in legend_elements],\n", " handler_map={tuple: HandlerTuple(ndivide=None)},\n", " bbox_to_anchor=(1.05, 1),\n", @@ -1006,9 +994,7 @@ " # Start the matrix environment with 4 columns\n", " # l for left-aligned model and task, c for centered numbers\n", " mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n", - " mathjax_table += (\n", - " \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n", - " )\n", + " mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n", " mathjax_table += \"\\\\hline\\n\"\n", "\n", " # Sort the DataFrame by model_id and source\n", @@ -1033,9 +1019,7 @@ " model_display = \"\\\\;\"\n", "\n", " # Add the data row\n", - " mathjax_table += (\n", - " f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n", - " )\n", + " mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n", "\n", " current_model = model\n", "\n", diff --git a/examples/e2b_example.py b/examples/e2b_example.py index 843e144..a58c7b1 100644 --- a/examples/e2b_example.py +++ b/examples/e2b_example.py @@ -1,7 +1,9 @@ -from smolagents import Tool, CodeAgent, HfApiModel -from smolagents.default_tools import VisitWebpageTool from dotenv import load_dotenv +from smolagents import CodeAgent, HfApiModel, Tool +from smolagents.default_tools import VisitWebpageTool + + load_dotenv() @@ -16,10 +18,11 @@ class GetCatImageTool(Tool): self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png" def forward(self): - from PIL import Image - import requests from io import BytesIO + import requests + from PIL import Image + response = requests.get(self.url) return Image.open(BytesIO(response.content)) @@ -46,4 +49,5 @@ agent.run( # Try the agent in a Gradio UI from smolagents import GradioUI + GradioUI(agent).launch() diff --git a/examples/gradio_upload.py 
b/examples/gradio_upload.py index 061d226..7460136 100644 --- a/examples/gradio_upload.py +++ b/examples/gradio_upload.py @@ -1,4 +1,5 @@ -from smolagents import CodeAgent, HfApiModel, GradioUI +from smolagents import CodeAgent, GradioUI, HfApiModel + agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1) diff --git a/examples/inspect_runs.py b/examples/inspect_runs.py index 3e24efa..9322f0b 100644 --- a/examples/inspect_runs.py +++ b/examples/inspect_runs.py @@ -1,24 +1,22 @@ +from openinference.instrumentation.smolagents import SmolagentsInstrumentor from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from openinference.instrumentation.smolagents import SmolagentsInstrumentor - from smolagents import ( CodeAgent, DuckDuckGoSearchTool, - VisitWebpageTool, + HfApiModel, ManagedAgent, ToolCallingAgent, - HfApiModel, + VisitWebpageTool, ) + # Let's setup the instrumentation first trace_provider = TracerProvider() -trace_provider.add_span_processor( - SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")) -) +trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))) SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True) @@ -39,6 +37,4 @@ manager_agent = CodeAgent( model=model, managed_agents=[managed_agent], ) -manager_agent.run( - "If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?" -) +manager_agent.run("If the US keeps its 2024 growth rate, how many years would it take for the GDP to double?") diff --git a/examples/rag.py b/examples/rag.py index 83a201d..f5a2e2c 100644 --- a/examples/rag.py +++ b/examples/rag.py @@ -8,13 +8,10 @@ from langchain_community.retrievers import BM25Retriever knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train") -knowledge_base = knowledge_base.filter( - lambda row: row["source"].startswith("huggingface/transformers") -) +knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers")) source_docs = [ - Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) - for doc in knowledge_base + Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base ] text_splitter = RecursiveCharacterTextSplitter( @@ -51,14 +48,12 @@ class RetrieverTool(Tool): query, ) return "\nRetrieved documents:\n" + "".join( - [ - f"\n\n===== Document {str(i)} =====\n" + doc.page_content - for i, doc in enumerate(docs) - ] + [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)] ) -from smolagents import HfApiModel, CodeAgent +from smolagents import CodeAgent, HfApiModel + retriever_tool = RetrieverTool(docs_processed) agent = CodeAgent( @@ -68,9 +63,7 @@ agent = CodeAgent( verbosity_level=2, ) -agent_output = agent.run( - "For a transformers model training, which is slower, the forward or the backward pass?"
-) +agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?") print("Final output:") print(agent_output) diff --git a/examples/text_to_sql.py b/examples/text_to_sql.py index 60b84f6..c25f0ca 100644 --- a/examples/text_to_sql.py +++ b/examples/text_to_sql.py @@ -1,16 +1,17 @@ from sqlalchemy import ( - create_engine, - MetaData, - Table, Column, - String, - Integer, Float, + Integer, + MetaData, + String, + Table, + create_engine, insert, inspect, text, ) + engine = create_engine("sqlite:///:memory:") metadata_obj = MetaData() @@ -40,9 +41,7 @@ for row in rows: inspector = inspect(engine) columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")] -table_description = "Columns:\n" + "\n".join( - [f" - {name}: {col_type}" for name, col_type in columns_info] -) +table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info]) print(table_description) from smolagents import tool @@ -72,6 +71,7 @@ def sql_engine(query: str) -> str: from smolagents import CodeAgent, HfApiModel + agent = CodeAgent( tools=[sql_engine], model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"), diff --git a/examples/tool_calling_agent_from_any_llm.py b/examples/tool_calling_agent_from_any_llm.py index 05daaa5..c9004a4 100644 --- a/examples/tool_calling_agent_from_any_llm.py +++ b/examples/tool_calling_agent_from_any_llm.py @@ -1,7 +1,9 @@ -from smolagents.agents import ToolCallingAgent -from smolagents import tool, LiteLLMModel from typing import Optional +from smolagents import LiteLLMModel, tool +from smolagents.agents import ToolCallingAgent + + # Choose which LLM engine to use! # model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct") # model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct") diff --git a/examples/tool_calling_agent_mcp.py b/examples/tool_calling_agent_mcp.py index c0e613a..dfc847b 100644 --- a/examples/tool_calling_agent_mcp.py +++ b/examples/tool_calling_agent_mcp.py @@ -13,8 +13,10 @@ Usage: import os from mcp import StdioServerParameters + from smolagents import CodeAgent, HfApiModel, ToolCollection + mcp_server_params = StdioServerParameters( command="uvx", args=["--quiet", "pubmedmcp@0.1.3"], diff --git a/examples/tool_calling_agent_ollama.py b/examples/tool_calling_agent_ollama.py index c7198d6..ceafb57 100644 --- a/examples/tool_calling_agent_ollama.py +++ b/examples/tool_calling_agent_ollama.py @@ -1,7 +1,9 @@ -from smolagents.agents import ToolCallingAgent -from smolagents import tool, LiteLLMModel from typing import Optional +from smolagents import LiteLLMModel, tool +from smolagents.agents import ToolCallingAgent + + model = LiteLLMModel( model_id="ollama_chat/llama3.2", api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary diff --git a/pyproject.toml b/pyproject.toml index e3ff96d..95317ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,18 @@ dev = [ addopts = "-sv --durations=0" [tool.ruff] -lint.ignore = ["F403"] +line-length = 119 +lint.ignore = [ + "F403", # undefined-local-with-import-star + "E501", # line-too-long +] +lint.select = ["E", "F", "I", "W"] [tool.ruff.lint.per-file-ignores] "examples/*" = [ "E402", # module-import-not-at-top-of-file ] + +[tool.ruff.lint.isort] +known-first-party = ["smolagents"] +lines-after-imports = 2 diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py index f457b7e..fbce4fb 100644 --- a/src/smolagents/__init__.py +++ 
b/src/smolagents/__init__.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING from transformers.utils import _LazyModule from transformers.utils.import_utils import define_import_structure + if TYPE_CHECKING: from .agents import * from .default_tools import * diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 1b04fd0..a9ec233 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -16,18 +16,17 @@ # limitations under the License. import time from dataclasses import dataclass +from enum import IntEnum from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from enum import IntEnum from rich import box -from rich.console import Group +from rich.console import Console, Group from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax from rich.text import Text -from rich.console import Console -from .default_tools import FinalAnswerTool, TOOL_MAPPING +from .default_tools import TOOL_MAPPING, FinalAnswerTool from .e2b_executor import E2BExecutor from .local_python_executor import ( BASE_BUILTIN_MODULES, @@ -112,20 +111,11 @@ class SystemPromptStep(AgentStepLog): system_prompt: str -def get_tool_descriptions( - tools: Dict[str, Tool], tool_description_template: str -) -> str: - return "\n".join( - [ - get_tool_description_with_args(tool, tool_description_template) - for tool in tools.values() - ] - ) +def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str: + return "\n".join([get_tool_description_with_args(tool, tool_description_template) for tool in tools.values()]) -def format_prompt_with_tools( - tools: Dict[str, Tool], prompt_template: str, tool_description_template: str -) -> str: +def format_prompt_with_tools(tools: Dict[str, Tool], prompt_template: str, tool_description_template: str) -> str: tool_descriptions = get_tool_descriptions(tools, tool_description_template) prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) if "{{tool_names}}" in prompt: @@ -159,9 +149,7 @@ def format_prompt_with_managed_agents_descriptions( f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'" ) if len(managed_agents.keys()) > 0: - return prompt_template.replace( - agent_descriptions_placeholder, show_agents_descriptions(managed_agents) - ) + return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents)) else: return prompt_template.replace(agent_descriptions_placeholder, "") @@ -214,9 +202,7 @@ class MultiStepAgent: self.model = model self.system_prompt_template = system_prompt self.tool_description_template = ( - tool_description_template - if tool_description_template - else DEFAULT_TOOL_DESCRIPTION_TEMPLATE + tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE ) self.max_steps = max_steps self.tool_parser = tool_parser @@ -231,10 +217,7 @@ class MultiStepAgent: self.tools = {tool.name: tool for tool in tools} if add_base_tools: for tool_name, tool_class in TOOL_MAPPING.items(): - if ( - tool_name != "python_interpreter" - or self.__class__.__name__ == "ToolCallingAgent" - ): + if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent": self.tools[tool_name] = tool_class() self.tools["final_answer"] = FinalAnswerTool() @@ -253,15 +236,11 @@ class MultiStepAgent: self.system_prompt_template, self.tool_description_template, ) - self.system_prompt = format_prompt_with_managed_agents_descriptions( 
- self.system_prompt, self.managed_agents - ) + self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) return self.system_prompt - def write_inner_memory_from_logs( - self, summary_mode: Optional[bool] = False - ) -> List[Dict[str, str]]: + def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: """ Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages that can be used as input to the LLM. @@ -355,10 +334,7 @@ class MultiStepAgent: return memory def get_succinct_logs(self): - return [ - {key: value for key, value in log.items() if key != "agent_memory"} - for log in self.logs - ] + return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]: """ @@ -402,9 +378,7 @@ class MultiStepAgent: except Exception as e: return f"Error in generating final LLM output:\n{e}" - def execute_tool_call( - self, tool_name: str, arguments: Union[Dict[str, str], str] - ) -> Any: + def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any: """ Execute tool with the provided input and returns the result. This method replaces arguments with the actual values from the state if they refer to state variables. @@ -423,9 +397,7 @@ class MultiStepAgent: if tool_name in self.managed_agents: observation = available_tools[tool_name].__call__(arguments) else: - observation = available_tools[tool_name].__call__( - arguments, sanitize_inputs_outputs=True - ) + observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True) elif isinstance(arguments, dict): for key, value in arguments.items(): if isinstance(value, str) and value in self.state: @@ -433,18 +405,14 @@ class MultiStepAgent: if tool_name in self.managed_agents: observation = available_tools[tool_name].__call__(**arguments) else: - observation = available_tools[tool_name].__call__( - **arguments, sanitize_inputs_outputs=True - ) + observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) else: error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." 
raise AgentExecutionError(error_msg) return observation except Exception as e: if tool_name in self.tools: - tool_description = get_tool_description_with_args( - available_tools[tool_name] - ) + tool_description = get_tool_description_with_args(available_tools[tool_name]) error_msg = ( f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"As a reminder, this tool's description is the following:\n{tool_description}" @@ -544,10 +512,7 @@ You have been provided with these additional arguments, that you can access usin step_start_time = time.time() step_log = ActionStep(step=self.step_number, start_time=step_start_time) try: - if ( - self.planning_interval is not None - and self.step_number % self.planning_interval == 0 - ): + if self.planning_interval is not None and self.step_number % self.planning_interval == 0: self.planning_step( task, is_first_step=(self.step_number == 0), @@ -600,10 +565,7 @@ You have been provided with these additional arguments, that you can access usin step_start_time = time.time() step_log = ActionStep(step=self.step_number, start_time=step_start_time) try: - if ( - self.planning_interval is not None - and self.step_number % self.planning_interval == 0 - ): + if self.planning_interval is not None and self.step_number % self.planning_interval == 0: self.planning_step( task, is_first_step=(self.step_number == 0), @@ -668,9 +630,7 @@ You have been provided with these additional arguments, that you can access usin Now begin!""", } - answer_facts = self.model( - [message_prompt_facts, message_prompt_task] - ).content + answer_facts = self.model([message_prompt_facts, message_prompt_task]).content message_system_prompt_plan = { "role": MessageRole.SYSTEM, @@ -680,12 +640,8 @@ Now begin!""", "role": MessageRole.USER, "content": USER_PROMPT_PLAN.format( task=task, - tool_descriptions=get_tool_descriptions( - self.tools, self.tool_description_template - ), - managed_agents_descriptions=( - show_agents_descriptions(self.managed_agents) - ), + tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), + managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), answer_facts=answer_facts, ), } @@ -702,9 +658,7 @@ Now begin!""", ``` {answer_facts} ```""".strip() - self.logs.append( - PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) - ) + self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) self.logger.log( Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction), @@ -724,9 +678,7 @@ Now begin!""", "role": MessageRole.USER, "content": USER_PROMPT_FACTS_UPDATE, } - facts_update = self.model( - [facts_update_system_prompt] + agent_memory + [facts_update_message] - ).content + facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content # Redact updated plan plan_update_message = { @@ -737,12 +689,8 @@ Now begin!""", "role": MessageRole.USER, "content": USER_PROMPT_PLAN_UPDATE.format( task=task, - tool_descriptions=get_tool_descriptions( - self.tools, self.tool_description_template - ), - managed_agents_descriptions=( - show_agents_descriptions(self.managed_agents) - ), + tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), + managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), facts_update=facts_update, remaining_steps=(self.max_steps - step), ), @@ -753,16 +701,12 @@ Now begin!""", ).content # Log final facts and plan - 
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format( - task=task, plan_update=plan_update - ) + final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update) final_facts_redaction = f"""Here is the updated list of the facts that I know: ``` {facts_update} ```""" - self.logs.append( - PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) - ) + self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) self.logger.log( Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction), @@ -816,19 +760,13 @@ class ToolCallingAgent(MultiStepAgent): tool_arguments = tool_call.function.arguments except Exception as e: - raise AgentGenerationError( - f"Error in generating tool call with model:\n{e}" - ) + raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") - log_entry.tool_calls = [ - ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id) - ] + log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] # Execute self.logger.log( - Panel( - Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}") - ), + Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")), level=LogLevel.INFO, ) if tool_name == "final_answer": @@ -900,16 +838,10 @@ class CodeAgent(MultiStepAgent): if system_prompt is None: system_prompt = CODE_SYSTEM_PROMPT - self.additional_authorized_imports = ( - additional_authorized_imports if additional_authorized_imports else [] - ) - self.authorized_imports = list( - set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) - ) + self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] + self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) if "{{authorized_imports}}" not in system_prompt: - raise AgentError( - "Tag '{{authorized_imports}}' should be provided in the prompt." - ) + raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.") super().__init__( tools=tools, model=model, @@ -966,9 +898,7 @@ class CodeAgent(MultiStepAgent): log_entry.agent_memory = agent_memory.copy() try: - additional_args = ( - {"grammar": self.grammar} if self.grammar is not None else {} - ) + additional_args = {"grammar": self.grammar} if self.grammar is not None else {} llm_output = self.model( self.input_messages, stop_sequences=["", "Observation:"], @@ -999,9 +929,7 @@ class CodeAgent(MultiStepAgent): try: code_action = fix_final_answer_code(parse_code_blobs(llm_output)) except Exception as e: - error_msg = ( - f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." - ) + error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." 
raise AgentParsingError(error_msg) log_entry.tool_calls = [ @@ -1088,17 +1016,13 @@ class ManagedAgent: self.description = description self.additional_prompting = additional_prompting self.provide_run_summary = provide_run_summary - self.managed_agent_prompt = ( - managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT - ) + self.managed_agent_prompt = managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT def write_full_task(self, task): """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" full_task = self.managed_agent_prompt.format(name=self.name, task=task) if self.additional_prompting: - full_task = full_task.replace( - "\n{{additional_prompting}}", self.additional_prompting - ).strip() + full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip() else: full_task = full_task.replace("\n{{additional_prompting}}", "").strip() return full_task @@ -1107,9 +1031,7 @@ class ManagedAgent: full_task = self.write_full_task(request) output = self.agent.run(full_task, **kwargs) if self.provide_run_summary: - answer = ( - f"Here is the final answer from your managed agent '{self.name}':\n" - ) + answer = f"Here is the final answer from your managed agent '{self.name}':\n" answer += str(output) answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" for message in self.agent.write_inner_memory_from_logs(summary_mode=True): diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 14a46ae..c0fa139 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -20,8 +20,6 @@ from dataclasses import dataclass from typing import Dict, Optional from huggingface_hub import hf_hub_download, list_spaces - - from transformers.utils import is_offline_mode, is_torch_available from .local_python_executor import ( @@ -32,6 +30,7 @@ from .local_python_executor import ( from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .types import AgentAudio + if is_torch_available(): from transformers.models.whisper import ( WhisperForConditionalGeneration, @@ -61,9 +60,7 @@ def get_remote_tools(logger, organization="huggingface-tools"): tools = {} for space_info in spaces: repo_id = space_info.id - resolved_config_file = hf_hub_download( - repo_id, TOOL_CONFIG_FILE, repo_type="space" - ) + resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space") with open(resolved_config_file, encoding="utf-8") as reader: config = json.load(reader) task = repo_id.split("/")[-1] @@ -94,9 +91,7 @@ class PythonInterpreterTool(Tool): if authorized_imports is None: self.authorized_imports = list(set(BASE_BUILTIN_MODULES)) else: - self.authorized_imports = list( - set(BASE_BUILTIN_MODULES) | set(authorized_imports) - ) + self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports)) self.inputs = { "code": { "type": "string", @@ -126,9 +121,7 @@ class PythonInterpreterTool(Tool): class FinalAnswerTool(Tool): name = "final_answer" description = "Provides a final answer to the given problem." 
- inputs = { - "answer": {"type": "any", "description": "The final answer to the problem"} - } + inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} output_type = "any" def forward(self, answer): @@ -138,9 +131,7 @@ class FinalAnswerTool(Tool): class UserInputTool(Tool): name = "user_input" description = "Asks for user's input on a specific question" - inputs = { - "question": {"type": "string", "description": "The question to ask the user"} - } + inputs = {"question": {"type": "string", "description": "The question to ask the user"}} output_type = "string" def forward(self, question): @@ -151,9 +142,7 @@ class UserInputTool(Tool): class DuckDuckGoSearchTool(Tool): name = "web_search" description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.""" - inputs = { - "query": {"type": "string", "description": "The search query to perform."} - } + inputs = {"query": {"type": "string", "description": "The search query to perform."}} output_type = "string" def __init__(self, *args, max_results=10, **kwargs): @@ -169,10 +158,7 @@ class DuckDuckGoSearchTool(Tool): def forward(self, query: str) -> str: results = self.ddgs.text(query, max_results=self.max_results) - postprocessed_results = [ - f"[{result['title']}]({result['href']})\n{result['body']}" - for result in results - ] + postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results] return "## Search Results\n\n" + "\n\n".join(postprocessed_results) @@ -199,9 +185,7 @@ class GoogleSearchTool(Tool): import requests if self.serpapi_key is None: - raise ValueError( - "Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables." - ) + raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.") params = { "engine": "google", @@ -210,9 +194,7 @@ class GoogleSearchTool(Tool): "google_domain": "google.com", } if filter_year is not None: - params["tbs"] = ( - f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}" - ) + params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}" response = requests.get("https://serpapi.com/search.json", params=params) @@ -227,13 +209,9 @@ class GoogleSearchTool(Tool): f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." ) else: - raise Exception( - f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." - ) + raise Exception(f"'organic_results' key not found for query: '{query}'. Use a less restrictive query.") if len(results["organic_results"]) == 0: - year_filter_message = ( - f" with filter year={filter_year}" if filter_year is not None else "" - ) + year_filter_message = f" with filter year={filter_year}" if filter_year is not None else "" return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter." web_snippets = [] @@ -253,9 +231,7 @@ class GoogleSearchTool(Tool): redacted_version = f"{idx}. 
[{page['title']}]({page['link']}){date_published}{source}\n{snippet}" - redacted_version = redacted_version.replace( - "Your browser can't play this video.", "" - ) + redacted_version = redacted_version.replace("Your browser can't play this video.", "") web_snippets.append(redacted_version) return "## Search Results\n" + "\n\n".join(web_snippets) @@ -263,7 +239,9 @@ class GoogleSearchTool(Tool): class VisitWebpageTool(Tool): name = "visit_webpage" - description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages." + description = ( + "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages." + ) inputs = { "url": { "type": "string", @@ -277,6 +255,7 @@ class VisitWebpageTool(Tool): import requests from markdownify import markdownify from requests.exceptions import RequestException + from smolagents.utils import truncate_content except ImportError: raise ImportError( diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py index 8a20a9e..393e572 100644 --- a/src/smolagents/e2b_executor.py +++ b/src/smolagents/e2b_executor.py @@ -28,6 +28,7 @@ from .tool_validation import validate_tool_attributes from .tools import Tool from .utils import BASE_BUILTIN_MODULES, instance_to_source + load_dotenv() @@ -45,9 +46,7 @@ class E2BExecutor: self.logger = logger additional_imports = additional_imports + ["pickle5", "smolagents"] if len(additional_imports) > 0: - execution = self.sbx.commands.run( - "pip install " + " ".join(additional_imports) - ) + execution = self.sbx.commands.run("pip install " + " ".join(additional_imports)) if execution.error: raise Exception(f"Error installing dependencies: {execution.error}") else: @@ -61,9 +60,7 @@ class E2BExecutor: tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n" tool_codes.append(tool_code) - tool_definition_code = "\n".join( - [f"import {module}" for module in BASE_BUILTIN_MODULES] - ) + tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES]) tool_definition_code += textwrap.dedent(""" class Tool: def __call__(self, *args, **kwargs): @@ -122,9 +119,7 @@ locals().update({key: value for key, value in pickle_dict.items()}) for attribute_name in ["jpeg", "png"]: if getattr(result, attribute_name) is not None: image_output = getattr(result, attribute_name) - decoded_bytes = base64.b64decode( - image_output.encode("utf-8") - ) + decoded_bytes = base64.b64decode(image_output.encode("utf-8")) return Image.open(BytesIO(decoded_bytes)), execution_logs for attribute_name in [ "chart", diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index e04056d..fdb1cd5 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -13,14 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import gradio as gr -import shutil -import os import mimetypes +import os import re - +import shutil from typing import Optional +import gradio as gr + from .agents import ActionStep, AgentStepLog, MultiStepAgent from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types @@ -59,9 +59,7 @@ def stream_to_gradio( ): """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" - for step_log in agent.run( - task, stream=True, reset=reset_agent_memory, additional_args=additional_args - ): + for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): for message in pull_messages_from_step(step_log, test_mode=test_mode): yield message @@ -147,14 +145,10 @@ class GradioUI: sanitized_name = "".join(sanitized_name) # Save the uploaded file to the specified folder - file_path = os.path.join( - self.file_upload_folder, os.path.basename(sanitized_name) - ) + file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) shutil.copy(file.name, file_path) - return gr.Textbox( - f"File uploaded: {file_path}", visible=True - ), file_uploads_log + [file_path] + return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path] def log_user_message(self, text_input, file_uploads_log): return ( @@ -183,9 +177,7 @@ class GradioUI: # If an upload folder is provided, enable the upload feature if self.file_upload_folder is not None: upload_file = gr.File(label="Upload a file") - upload_status = gr.Textbox( - label="Upload Status", interactive=False, visible=False - ) + upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False) upload_file.change( self.upload_file, [upload_file, file_uploads_log], diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index 0c7b5bc..9477be9 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -42,8 +42,7 @@ class InterpreterError(ValueError): ERRORS = { name: getattr(builtins, name) for name in dir(builtins) - if isinstance(getattr(builtins, name), type) - and issubclass(getattr(builtins, name), BaseException) + if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException) } PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000 @@ -167,9 +166,7 @@ def evaluate_unaryop( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Any: - operand = evaluate_ast( - expression.operand, state, static_tools, custom_tools, authorized_imports - ) + operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports) if isinstance(expression.op, ast.USub): return -operand elif isinstance(expression.op, ast.UAdd): @@ -179,9 +176,7 @@ def evaluate_unaryop( elif isinstance(expression.op, ast.Invert): return ~operand else: - raise InterpreterError( - f"Unary operation {expression.op.__class__.__name__} is not supported." 
- ) + raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") def evaluate_lambda( @@ -217,23 +212,17 @@ def evaluate_while( ) -> None: max_iterations = 1000 iterations = 0 - while evaluate_ast( - while_loop.test, state, static_tools, custom_tools, authorized_imports - ): + while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports): for node in while_loop.body: try: - evaluate_ast( - node, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) except BreakException: return None except ContinueException: break iterations += 1 if iterations > max_iterations: - raise InterpreterError( - f"Maximum number of {max_iterations} iterations in While loop exceeded" - ) + raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded") return None @@ -248,8 +237,7 @@ def create_function( func_state = state.copy() arg_names = [arg.arg for arg in func_def.args.args] default_values = [ - evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) - for d in func_def.args.defaults + evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults ] # Apply default values @@ -286,9 +274,7 @@ def create_function( result = None try: for stmt in func_def.body: - result = evaluate_ast( - stmt, func_state, static_tools, custom_tools, authorized_imports - ) + result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports) except ReturnException as e: result = e.value @@ -307,9 +293,7 @@ def evaluate_function_def( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Callable: - custom_tools[func_def.name] = create_function( - func_def, state, static_tools, custom_tools, authorized_imports - ) + custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports) return custom_tools[func_def.name] @@ -321,17 +305,12 @@ def evaluate_class_def( authorized_imports: List[str], ) -> type: class_name = class_def.name - bases = [ - evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) - for base in class_def.bases - ] + bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases] class_dict = {} for stmt in class_def.body: if isinstance(stmt, ast.FunctionDef): - class_dict[stmt.name] = evaluate_function_def( - stmt, state, static_tools, custom_tools, authorized_imports - ) + class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports) elif isinstance(stmt, ast.Assign): for target in stmt.targets: if isinstance(target, ast.Name): @@ -351,9 +330,7 @@ def evaluate_class_def( authorized_imports, ) else: - raise InterpreterError( - f"Unsupported statement in class body: {stmt.__class__.__name__}" - ) + raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") new_class = type(class_name, tuple(bases), class_dict) state[class_name] = new_class @@ -371,38 +348,26 @@ def evaluate_augassign( if isinstance(target, ast.Name): return state.get(target.id, 0) elif isinstance(target, ast.Subscript): - obj = evaluate_ast( - target.value, state, static_tools, custom_tools, authorized_imports - ) - key = evaluate_ast( - target.slice, state, static_tools, custom_tools, authorized_imports - ) + obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports) + 
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports) return obj[key] elif isinstance(target, ast.Attribute): - obj = evaluate_ast( - target.value, state, static_tools, custom_tools, authorized_imports - ) + obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports) return getattr(obj, target.attr) elif isinstance(target, ast.Tuple): return tuple(get_current_value(elt) for elt in target.elts) elif isinstance(target, ast.List): return [get_current_value(elt) for elt in target.elts] else: - raise InterpreterError( - "AugAssign not supported for {type(target)} targets." - ) + raise InterpreterError(f"AugAssign not supported for {type(target)} targets.") current_value = get_current_value(expression.target) - value_to_add = evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) if isinstance(expression.op, ast.Add): if isinstance(current_value, list): if not isinstance(value_to_add, list): - raise InterpreterError( - f"Cannot add non-list value {value_to_add} to a list." - ) + raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") updated_value = current_value + value_to_add else: updated_value = current_value + value_to_add @@ -429,9 +394,7 @@ def evaluate_augassign( elif isinstance(expression.op, ast.RShift): updated_value = current_value >> value_to_add else: - raise InterpreterError( - f"Operation {type(expression.op).__name__} is not supported." - ) + raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") # Update the state set_value( @@ -455,16 +418,12 @@ def evaluate_boolop( ) -> bool: if isinstance(node.op, ast.And): for value in node.values: - if not evaluate_ast( - value, state, static_tools, custom_tools, authorized_imports - ): + if not evaluate_ast(value, state, static_tools, custom_tools, authorized_imports): return False return True elif isinstance(node.op, ast.Or): for value in node.values: - if evaluate_ast( - value, state, static_tools, custom_tools, authorized_imports - ): + if evaluate_ast(value, state, static_tools, custom_tools, authorized_imports): return True return False @@ -477,12 +436,8 @@ def evaluate_binop( authorized_imports: List[str], ) -> Any: # Recursively evaluate the left and right operands - left_val = evaluate_ast( - binop.left, state, static_tools, custom_tools, authorized_imports - ) - right_val = evaluate_ast( - binop.right, state, static_tools, custom_tools, authorized_imports - ) + left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports) + right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports) # Determine the operation based on the type of the operator in the BinOp if isinstance(binop.op, ast.Add): @@ -510,9 +465,7 @@ elif isinstance(binop.op, ast.RShift): return left_val >> right_val else: - raise NotImplementedError( - f"Binary operation {type(binop.op).__name__} is not implemented."
- ) + raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") def evaluate_assign( @@ -522,17 +475,13 @@ def evaluate_assign( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Any: - result = evaluate_ast( - assign.value, state, static_tools, custom_tools, authorized_imports - ) + result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports) if len(assign.targets) == 1: target = assign.targets[0] set_value(target, result, state, static_tools, custom_tools, authorized_imports) else: if len(assign.targets) != len(result): - raise InterpreterError( - f"Assign failed: expected {len(result)} values but got {len(assign.targets)}." - ) + raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") expanded_values = [] for tgt in assign.targets: if isinstance(tgt, ast.Starred): @@ -554,9 +503,7 @@ def set_value( ) -> None: if isinstance(target, ast.Name): if target.id in static_tools: - raise InterpreterError( - f"Cannot assign to name '{target.id}': doing this would erase the existing tool!" - ) + raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") state[target.id] = value elif isinstance(target, ast.Tuple): if not isinstance(value, tuple): @@ -567,21 +514,13 @@ def set_value( if len(target.elts) != len(value): raise InterpreterError("Cannot unpack tuple of wrong size") for i, elem in enumerate(target.elts): - set_value( - elem, value[i], state, static_tools, custom_tools, authorized_imports - ) + set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports) elif isinstance(target, ast.Subscript): - obj = evaluate_ast( - target.value, state, static_tools, custom_tools, authorized_imports - ) - key = evaluate_ast( - target.slice, state, static_tools, custom_tools, authorized_imports - ) + obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports) + key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports) obj[key] = value elif isinstance(target, ast.Attribute): - obj = evaluate_ast( - target.value, state, static_tools, custom_tools, authorized_imports - ) + obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports) setattr(obj, target.attr, value) @@ -593,15 +532,11 @@ def evaluate_call( authorized_imports: List[str], ) -> Any: if not ( - isinstance(call.func, ast.Attribute) - or isinstance(call.func, ast.Name) - or isinstance(call.func, ast.Subscript) + isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript) ): raise InterpreterError(f"This is not a correct function: {call.func}).") if isinstance(call.func, ast.Attribute): - obj = evaluate_ast( - call.func.value, state, static_tools, custom_tools, authorized_imports - ) + obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports) func_name = call.func.attr if not hasattr(obj, func_name): raise InterpreterError(f"Object {obj} has no attribute {func_name}") @@ -623,18 +558,12 @@ def evaluate_call( ) elif isinstance(call.func, ast.Subscript): - value = evaluate_ast( - call.func.value, state, static_tools, custom_tools, authorized_imports - ) - index = evaluate_ast( - call.func.slice, state, static_tools, custom_tools, authorized_imports - ) + value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports) + index = evaluate_ast(call.func.slice, 
state, static_tools, custom_tools, authorized_imports) if isinstance(value, (list, tuple)): func = value[index] else: - raise InterpreterError( - f"Cannot subscript object of type {type(value).__name__}" - ) + raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}") if not callable(func): raise InterpreterError(f"This is not a correct function: {call.func}).") @@ -642,20 +571,12 @@ def evaluate_call( args = [] for arg in call.args: if isinstance(arg, ast.Starred): - args.extend( - evaluate_ast( - arg.value, state, static_tools, custom_tools, authorized_imports - ) - ) + args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports)) else: - args.append( - evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports) - ) + args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)) kwargs = { - keyword.arg: evaluate_ast( - keyword.value, state, static_tools, custom_tools, authorized_imports - ) + keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports) for keyword in call.keywords } @@ -693,17 +614,11 @@ def evaluate_subscript( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Any: - index = evaluate_ast( - subscript.slice, state, static_tools, custom_tools, authorized_imports - ) - value = evaluate_ast( - subscript.value, state, static_tools, custom_tools, authorized_imports - ) + index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports) + value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports) if isinstance(value, str) and isinstance(index, str): - raise InterpreterError( - "You're trying to subscript a string with a string index, which is impossible" - ) + raise InterpreterError("You're trying to subscript a string with a string index, which is impossible") if isinstance(value, pd.core.indexing._LocIndexer): parent_object = value.obj return parent_object.loc[index] @@ -718,15 +633,11 @@ def evaluate_subscript( return value[index] elif isinstance(value, (list, tuple)): if not (-len(value) <= index < len(value)): - raise InterpreterError( - f"Index {index} out of bounds for list of length {len(value)}" - ) + raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") return value[int(index)] elif isinstance(value, str): if not (-len(value) <= index < len(value)): - raise InterpreterError( - f"Index {index} out of bounds for string of length {len(value)}" - ) + raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") return value[index] elif index in value: return value[index] @@ -765,12 +676,9 @@ def evaluate_condition( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> bool: - left = evaluate_ast( - condition.left, state, static_tools, custom_tools, authorized_imports - ) + left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) comparators = [ - evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) - for c in condition.comparators + evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators ] ops = [type(op) for op in condition.ops] @@ -818,21 +726,15 @@ def evaluate_if( authorized_imports: List[str], ) -> Any: result = None - test_result = evaluate_ast( - if_statement.test, state, static_tools, custom_tools, authorized_imports - ) + test_result = evaluate_ast(if_statement.test, state, static_tools, 
custom_tools, authorized_imports) if test_result: for line in if_statement.body: - line_result = evaluate_ast( - line, state, static_tools, custom_tools, authorized_imports - ) + line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports) if line_result is not None: result = line_result else: for line in if_statement.orelse: - line_result = evaluate_ast( - line, state, static_tools, custom_tools, authorized_imports - ) + line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports) if line_result is not None: result = line_result return result @@ -846,9 +748,7 @@ def evaluate_for( authorized_imports: List[str], ) -> Any: result = None - iterator = evaluate_ast( - for_loop.iter, state, static_tools, custom_tools, authorized_imports - ) + iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports) for counter in iterator: set_value( for_loop.target, @@ -860,9 +760,7 @@ def evaluate_for( ) for node in for_loop.body: try: - line_result = evaluate_ast( - node, state, static_tools, custom_tools, authorized_imports - ) + line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) if line_result is not None: result = line_result except BreakException: @@ -882,9 +780,7 @@ def evaluate_listcomp( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> List[Any]: - def inner_evaluate( - generators: List[ast.comprehension], index: int, current_state: Dict[str, Any] - ) -> List[Any]: + def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]: if index >= len(generators): return [ evaluate_ast( @@ -912,9 +808,7 @@ def evaluate_listcomp( else: new_state[generator.target.id] = value if all( - evaluate_ast( - if_clause, new_state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports) for if_clause in generator.ifs ): result.extend(inner_evaluate(generators, index + 1, new_state)) @@ -938,32 +832,24 @@ def evaluate_try( for handler in try_node.handlers: if handler.type is None or isinstance( e, - evaluate_ast( - handler.type, state, static_tools, custom_tools, authorized_imports - ), + evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports), ): matched = True if handler.name: state[handler.name] = e for stmt in handler.body: - evaluate_ast( - stmt, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports) break if not matched: raise e else: if try_node.orelse: for stmt in try_node.orelse: - evaluate_ast( - stmt, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports) finally: if try_node.finalbody: for stmt in try_node.finalbody: - evaluate_ast( - stmt, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports) def evaluate_raise( @@ -974,15 +860,11 @@ def evaluate_raise( authorized_imports: List[str], ) -> None: if raise_node.exc is not None: - exc = evaluate_ast( - raise_node.exc, state, static_tools, custom_tools, authorized_imports - ) + exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports) else: exc = None if raise_node.cause is not None: - cause = evaluate_ast( - raise_node.cause, state, static_tools, custom_tools, authorized_imports - ) + cause = 
evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports) else: cause = None if exc is not None: @@ -1001,14 +883,10 @@ def evaluate_assert( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> None: - test_result = evaluate_ast( - assert_node.test, state, static_tools, custom_tools, authorized_imports - ) + test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports) if not test_result: if assert_node.msg: - msg = evaluate_ast( - assert_node.msg, state, static_tools, custom_tools, authorized_imports - ) + msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports) raise AssertionError(msg) else: # Include the failing condition in the assertion message @@ -1025,9 +903,7 @@ def evaluate_with( ) -> None: contexts = [] for item in with_node.items: - context_expr = evaluate_ast( - item.context_expr, state, static_tools, custom_tools, authorized_imports - ) + context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports) if item.optional_vars: state[item.optional_vars.id] = context_expr.__enter__() contexts.append(state[item.optional_vars.id]) @@ -1069,19 +945,14 @@ def get_safe_module(unsafe_module, dangerous_patterns, visited=None): # Copy all attributes by reference, recursively checking modules for attr_name in dir(unsafe_module): # Skip dangerous patterns at any level - if any( - pattern in f"{unsafe_module.__name__}.{attr_name}" - for pattern in dangerous_patterns - ): + if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns): continue attr_value = getattr(unsafe_module, attr_name) # Recursively process nested modules, passing visited set if isinstance(attr_value, ModuleType): - attr_value = get_safe_module( - attr_value, dangerous_patterns, visited=visited - ) + attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited) setattr(safe_module, attr_name, attr_value) @@ -1116,18 +987,14 @@ def import_modules(expression, state, authorized_imports): module_path = module_name.split(".") if any([module in dangerous_patterns for module in module_path]): return False - module_subpaths = [ - ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) - ] + module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] return any(subpath in authorized_imports for subpath in module_subpaths) if isinstance(expression, ast.Import): for alias in expression.names: if check_module_authorized(alias.name): raw_module = import_module(alias.name) - state[alias.asname or alias.name] = get_safe_module( - raw_module, dangerous_patterns - ) + state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns) else: raise InterpreterError( f"Import of {alias.name} is not allowed. 
Authorized imports are: {str(authorized_imports)}" @@ -1135,9 +1002,7 @@ def import_modules(expression, state, authorized_imports): return None elif isinstance(expression, ast.ImportFrom): if check_module_authorized(expression.module): - raw_module = __import__( - expression.module, fromlist=[alias.name for alias in expression.names] - ) + raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) for alias in expression.names: state[alias.asname or alias.name] = get_safe_module( getattr(raw_module, alias.name), dangerous_patterns @@ -1156,9 +1021,7 @@ def evaluate_dictcomp( ) -> Dict[Any, Any]: result = {} for gen in dictcomp.generators: - iter_value = evaluate_ast( - gen.iter, state, static_tools, custom_tools, authorized_imports - ) + iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports) for value in iter_value: new_state = state.copy() set_value( @@ -1170,9 +1033,7 @@ def evaluate_dictcomp( authorized_imports, ) if all( - evaluate_ast( - if_clause, new_state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports) for if_clause in gen.ifs ): key = evaluate_ast( @@ -1229,202 +1090,116 @@ def evaluate_ast( if isinstance(expression, ast.Assign): # Assignment -> we evaluate the assignment which should update the state # We return the variable assigned as it may be used to determine the final result. - return evaluate_assign( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.AugAssign): - return evaluate_augassign( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Call): # Function call -> we return the value of the function call - return evaluate_call( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Constant): # Constant -> just return the value return expression.value elif isinstance(expression, ast.Tuple): return tuple( - evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) - for elt in expression.elts + evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts ) elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): - return evaluate_listcomp( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.UnaryOp): - return evaluate_unaryop( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Starred): - return evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.BoolOp): # Boolean operation -> evaluate the operation - return evaluate_boolop( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_boolop(expression, state, static_tools, custom_tools, 
authorized_imports) elif isinstance(expression, ast.Break): raise BreakException() elif isinstance(expression, ast.Continue): raise ContinueException() elif isinstance(expression, ast.BinOp): # Binary operation -> execute operation - return evaluate_binop( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Compare): # Comparison -> evaluate the comparison - return evaluate_condition( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Lambda): - return evaluate_lambda( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.FunctionDef): - return evaluate_function_def( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Dict): # Dict -> evaluate all keys and values - keys = [ - evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) - for k in expression.keys - ] - values = [ - evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) - for v in expression.values - ] + keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys] + values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values] return dict(zip(keys, values)) elif isinstance(expression, ast.Expr): # Expression -> evaluate the content - return evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.For): # For loop -> execute the loop - return evaluate_for( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.FormattedValue): # Formatted value (part of f-string) -> evaluate the content and return - return evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.If): # If -> execute the right branch - return evaluate_if( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports) elif hasattr(ast, "Index") and isinstance(expression, ast.Index): - return evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.JoinedStr): return "".join( - [ - str( - evaluate_ast( - v, state, static_tools, custom_tools, authorized_imports - ) - ) - for v in expression.values - ] + [str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values] ) elif isinstance(expression, ast.List): # List -> evaluate all elements - return [ - evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) - for 
elt in expression.elts - ] + return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts] elif isinstance(expression, ast.Name): # Name -> pick up the value in the state - return evaluate_name( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Subscript): # Subscript -> return the value of the indexing - return evaluate_subscript( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.IfExp): - test_val = evaluate_ast( - expression.test, state, static_tools, custom_tools, authorized_imports - ) + test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports) if test_val: - return evaluate_ast( - expression.body, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports) else: - return evaluate_ast( - expression.orelse, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Attribute): - value = evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) return getattr(value, expression.attr) elif isinstance(expression, ast.Slice): return slice( - evaluate_ast( - expression.lower, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports) if expression.lower is not None else None, - evaluate_ast( - expression.upper, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports) if expression.upper is not None else None, - evaluate_ast( - expression.step, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports) if expression.step is not None else None, ) elif isinstance(expression, ast.DictComp): - return evaluate_dictcomp( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.While): - return evaluate_while( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, (ast.Import, ast.ImportFrom)): return import_modules(expression, state, authorized_imports) elif isinstance(expression, ast.ClassDef): - return evaluate_class_def( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Try): - return evaluate_try( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Raise): - return evaluate_raise( - expression, state, static_tools, custom_tools, authorized_imports - ) + return 
evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Assert): - return evaluate_assert( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.With): - return evaluate_with( - expression, state, static_tools, custom_tools, authorized_imports - ) + return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports) elif isinstance(expression, ast.Set): - return { - evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) - for elt in expression.elts - } + return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts} elif isinstance(expression, ast.Return): raise ReturnException( - evaluate_ast( - expression.value, state, static_tools, custom_tools, authorized_imports - ) + evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) if expression.value else None ) @@ -1488,18 +1263,12 @@ def evaluate_python_code( try: for node in expression.body: - result = evaluate_ast( - node, state, static_tools, custom_tools, authorized_imports - ) - state["print_outputs"] = truncate_content( - PRINT_OUTPUTS, max_length=max_print_outputs_length - ) + result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) + state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) is_final_answer = False return result, is_final_answer except FinalAnswerException as e: - state["print_outputs"] = truncate_content( - PRINT_OUTPUTS, max_length=max_print_outputs_length - ) + state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) is_final_answer = True return e.value, is_final_answer except InterpreterError as e: @@ -1521,9 +1290,7 @@ class LocalPythonInterpreter: if max_print_outputs_length is None: self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT self.additional_authorized_imports = additional_authorized_imports - self.authorized_imports = list( - set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) - ) + self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) # Add base trusted tools to list self.static_tools = { **tools, @@ -1531,9 +1298,7 @@ class LocalPythonInterpreter: } # TODO: assert self.authorized imports are all installed locally - def __call__( - self, code_action: str, additional_variables: Dict - ) -> Tuple[Any, str, bool]: + def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]: self.state.update(additional_variables) output, is_final_answer = evaluate_python_code( code_action, diff --git a/src/smolagents/models.py b/src/smolagents/models.py index ca234f2..f19a291 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -14,17 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass, asdict
 import json
 import logging
 import os
 import random
 from copy import deepcopy
+from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Union, Any
+from typing import Any, Dict, List, Optional, Union
 
 from huggingface_hub import InferenceClient
-
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -35,6 +34,7 @@ from transformers import (
 
 from .tools import Tool
 
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@@ -100,10 +100,7 @@ class ChatMessage:
     def from_hf_api(cls, message) -> "ChatMessage":
         tool_calls = None
         if getattr(message, "tool_calls", None) is not None:
-            tool_calls = [
-                ChatMessageToolCall.from_hf_api(tool_call)
-                for tool_call in message.tool_calls
-            ]
+            tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
         return cls(role=message.role, content=message.content, tool_calls=tool_calls)
@@ -172,17 +169,12 @@ def get_clean_message_list(
         role = message["role"]
         if role not in MessageRole.roles():
-            raise ValueError(
-                f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
-            )
+            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
 
         if role in role_conversions:
             message["role"] = role_conversions[role]
 
-        if (
-            len(final_message_list) > 0
-            and message["role"] == final_message_list[-1]["role"]
-        ):
+        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
             final_message_list[-1]["content"] += "\n=======\n" + message["content"]
         else:
             final_message_list.append(message)
@@ -292,9 +284,7 @@ class HfApiModel(Model):
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
""" - messages = get_clean_message_list( - messages, role_conversions=tool_role_conversions - ) + messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) if tools_to_call_from: response = self.client.chat.completions.create( messages=messages, @@ -367,9 +357,7 @@ class TransformersModel(Model): default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" if model_id is None: model_id = default_model_id - logger.warning( - f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'" - ) + logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'") self.model_id = model_id self.kwargs = kwargs if device_map is None: @@ -389,9 +377,7 @@ class TransformersModel(Model): ) self.model_id = default_model_id self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) - self.model = AutoModelForCausalLM.from_pretrained( - model_id, device_map=device_map, torch_dtype=torch_dtype - ) + self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype) def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: class StopOnStrings(StoppingCriteria): @@ -404,16 +390,9 @@ class TransformersModel(Model): self.stream = "" def __call__(self, input_ids, scores, **kwargs): - generated = self.tokenizer.decode( - input_ids[0][-1], skip_special_tokens=True - ) + generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True) self.stream += generated - if any( - [ - self.stream.endswith(stop_string) - for stop_string in self.stop_strings - ] - ): + if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]): return True return False @@ -426,9 +405,7 @@ class TransformersModel(Model): grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, ) -> ChatMessage: - messages = get_clean_message_list( - messages, role_conversions=tool_role_conversions - ) + messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) if tools_to_call_from is not None: prompt_tensor = self.tokenizer.apply_chat_template( messages, @@ -448,9 +425,7 @@ class TransformersModel(Model): out = self.model.generate( **prompt_tensor, - stopping_criteria=( - self.make_stopping_criteria(stop_sequences) if stop_sequences else None - ), + stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None), **self.kwargs, ) generated_tokens = out[0, count_prompt_tokens:] @@ -475,9 +450,7 @@ class TransformersModel(Model): ChatMessageToolCall( id="".join(random.choices("0123456789", k=5)), type="function", - function=ChatMessageToolCallDefinition( - name=tool_name, arguments=tool_arguments - ), + function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments), ) ], ) @@ -525,9 +498,7 @@ class LiteLLMModel(Model): grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, ) -> ChatMessage: - messages = get_clean_message_list( - messages, role_conversions=tool_role_conversions - ) + messages = get_clean_message_list(messages, role_conversions=tool_role_conversions) import litellm if tools_to_call_from: @@ -604,11 +575,7 @@ class OpenAIServerModel(Model): ) -> ChatMessage: messages = get_clean_message_list( messages, - role_conversions=( - self.custom_role_conversions - if self.custom_role_conversions - else tool_role_conversions - ), + role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions), ) if 
tools_to_call_from:
             response = self.client.chat.completions.create(
diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py
index 13de796..722f25e 100644
--- a/src/smolagents/monitoring.py
+++ b/src/smolagents/monitoring.py
@@ -22,10 +22,7 @@ class Monitor:
         self.step_durations = []
         self.tracked_model = tracked_model
         self.logger = logger
-        if (
-            getattr(self.tracked_model, "last_input_token_count", "Not found")
-            != "Not found"
-        ):
+        if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found":
             self.total_input_token_count = 0
             self.total_output_token_count = 0
 
@@ -48,7 +45,9 @@ class Monitor:
         if getattr(self.tracked_model, "last_input_token_count", None) is not None:
             self.total_input_token_count += self.tracked_model.last_input_token_count
             self.total_output_token_count += self.tracked_model.last_output_token_count
-            console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
+            console_outputs += (
+                f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
+            )
         console_outputs += "]"
         self.logger.log(Text(console_outputs, style="dim"), level=1)
diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py
index 821c315..960331f 100644
--- a/src/smolagents/tool_validation.py
+++ b/src/smolagents/tool_validation.py
@@ -6,6 +6,7 @@
 from typing import Set
 
 from .utils import BASE_BUILTIN_MODULES
 
+
 _BUILTIN_NAMES = set(vars(builtins))
 
@@ -141,9 +142,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
     # Check that __init__ method takes no arguments
     if not cls.__init__.__qualname__ == "Tool.__init__":
         sig = inspect.signature(cls.__init__)
-        non_self_params = list(
-            [arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]
-        )
+        non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
         if len(non_self_params) > 0:
             errors.append(
                 f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
@@ -174,9 +173,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: # Check if the assignment is more complex than simple literals if not all( - isinstance( - val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set) - ) + isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)) for val in ast.walk(node.value) ): for target in node.targets: @@ -195,9 +192,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: # Run checks on all methods for node in class_node.body: if isinstance(node, ast.FunctionDef): - method_checker = MethodChecker( - class_level_checker.class_attributes, check_imports=check_imports - ) + method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports) method_checker.visit(node) errors += [f"- {node.name}: {error}" for error in method_checker.errors] diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index fc85979..57ac7b0 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -36,7 +36,6 @@ from huggingface_hub import ( upload_folder, ) from huggingface_hub.utils import RepositoryNotFoundError - from packaging import version from transformers.dynamic_module_utils import get_imports from transformers.utils import ( @@ -52,6 +51,7 @@ from .tool_validation import MethodChecker, validate_tool_attributes from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .utils import instance_to_source + logger = logging.getLogger(__name__) if is_accelerate_available(): @@ -77,9 +77,7 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs): hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) return "model" except RepositoryNotFoundError: - raise EnvironmentError( - f"`{repo_id}` does not seem to be a valid repo identifier on the Hub." - ) + raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.") except Exception: return "model" except Exception: @@ -109,9 +107,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict: properties[param_name]["nullable"] = True for param_name in signature.parameters.keys(): if signature.parameters[param_name].default != inspect.Parameter.empty: - if ( - param_name not in properties - ): # this can happen if the param has no type hint but a default value + if param_name not in properties: # this can happen if the param has no type hint but a default value properties[param_name] = {"nullable": True} return properties @@ -181,9 +177,7 @@ class Tool: f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." ) for input_name, input_content in self.inputs.items(): - assert isinstance(input_content, dict), ( - f"Input '{input_name}' should be a dictionary." - ) + assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." assert "type" in input_content and "description" in input_content, ( f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." 
) @@ -348,15 +342,7 @@ class Tool: imports = [] for module in [tool_file]: imports.extend(get_imports(module)) - imports = list( - set( - [ - el - for el in imports + ["smolagents"] - if el not in sys.stdlib_module_names - ] - ) - ) + imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names])) with open(requirements_file, "w", encoding="utf-8") as f: f.write("\n".join(imports) + "\n") @@ -410,9 +396,7 @@ class Tool: print(work_dir) with open(work_dir + "/tool.py", "r") as f: print("\n".join(f.readlines())) - logger.info( - f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" - ) + logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") return upload_folder( repo_id=repo_id, commit_message=commit_message, @@ -592,9 +576,7 @@ class Tool: self.name = name self.description = description self.client = Client(space_id, hf_token=token) - space_description = self.client.view_api( - return_format="dict", print_info=False - )["named_endpoints"] + space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] # If api_name is not defined, take the first of the available APIs for this space if api_name is None: @@ -607,9 +589,7 @@ class Tool: try: space_description_api = space_description[api_name] except KeyError: - raise KeyError( - f"Could not find specified {api_name=} among available api names." - ) + raise KeyError(f"Could not find specified {api_name=} among available api names.") self.inputs = {} for parameter in space_description_api["parameters"]: @@ -683,8 +663,7 @@ class Tool: self._gradio_tool = _gradio_tool func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) self.inputs = { - key: {"type": CONVERSION_DICT[value.annotation], "description": ""} - for key, value in func_args + key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args } self.forward = self._gradio_tool.run @@ -726,9 +705,7 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ """ -def get_tool_description_with_args( - tool: Tool, description_template: Optional[str] = None -) -> str: +def get_tool_description_with_args(tool: Tool, description_template: Optional[str] = None) -> str: if description_template is None: description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE compiled_template = compile_jinja_template(description_template) @@ -748,10 +725,7 @@ def compile_jinja_template(template): raise ImportError("template requires jinja2 to be installed.") if version.parse(jinja2.__version__) < version.parse("3.1.0"): - raise ImportError( - "template requires jinja2>=3.1.0 to be installed. Your version is " - f"{jinja2.__version__}." - ) + raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.") def raise_exception(message): raise TemplateError(message) @@ -772,9 +746,7 @@ def launch_gradio_demo(tool: Tool): try: import gradio as gr except ImportError: - raise ImportError( - "Gradio should be installed in order to launch a gradio demo." 
- ) + raise ImportError("Gradio should be installed in order to launch a gradio demo.") TYPE_TO_COMPONENT_CLASS_MAPPING = { "image": gr.Image, @@ -791,9 +763,7 @@ def launch_gradio_demo(tool: Tool): gradio_inputs = [] for input_name, input_details in tool.inputs.items(): - input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ - input_details["type"] - ] + input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] new_component = input_gradio_component_class(label=input_name) gradio_inputs.append(new_component) @@ -922,14 +892,9 @@ class ToolCollection: ``` """ _collection = get_collection(collection_slug, token=token) - _hub_repo_ids = { - item.item_id for item in _collection.items if item.item_type == "space" - } + _hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"} - tools = { - Tool.from_hub(repo_id, token, trust_remote_code) - for repo_id in _hub_repo_ids - } + tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids} return cls(tools) @@ -986,9 +951,7 @@ def tool(tool_function: Callable) -> Tool: """ parameters = get_json_schema(tool_function)["function"] if "return" not in parameters: - raise TypeHintParsingException( - "Tool return type not found: make sure your function has a return type hint!" - ) + raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") class SimpleTool(Tool): def __init__(self, name, description, inputs, output_type, function): @@ -1007,9 +970,9 @@ def tool(tool_function: Callable) -> Tool: function=tool_function, ) original_signature = inspect.signature(tool_function) - new_parameters = [ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY) - ] + list(original_signature.parameters.values()) + new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list( + original_signature.parameters.values() + ) new_signature = original_signature.replace(parameters=new_parameters) simple_tool.forward.__signature__ = new_signature return simple_tool @@ -1082,9 +1045,7 @@ class PipelineTool(Tool): if model is None: if self.default_checkpoint is None: - raise ValueError( - "This tool does not implement a default checkpoint, you need to pass one." - ) + raise ValueError("This tool does not implement a default checkpoint, you need to pass one.") model = self.default_checkpoint if pre_processor is None: pre_processor = model @@ -1107,21 +1068,15 @@ class PipelineTool(Tool): Instantiates the `pre_processor`, `model` and `post_processor` if necessary. 
""" if isinstance(self.pre_processor, str): - self.pre_processor = self.pre_processor_class.from_pretrained( - self.pre_processor, **self.hub_kwargs - ) + self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs) if isinstance(self.model, str): - self.model = self.model_class.from_pretrained( - self.model, **self.model_kwargs, **self.hub_kwargs - ) + self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs) if self.post_processor is None: self.post_processor = self.pre_processor elif isinstance(self.post_processor, str): - self.post_processor = self.post_processor_class.from_pretrained( - self.post_processor, **self.hub_kwargs - ) + self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs) if self.device is None: if self.device_map is not None: @@ -1165,12 +1120,8 @@ class PipelineTool(Tool): encoded_inputs = self.encode(*args, **kwargs) - tensor_inputs = { - k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) - } - non_tensor_inputs = { - k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor) - } + tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)} + non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)} encoded_inputs = send_to_device(tensor_inputs, self.device) outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) diff --git a/src/smolagents/types.py b/src/smolagents/types.py index 038885f..e18de51 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -27,6 +27,7 @@ from transformers.utils import ( is_vision_available, ) + logger = logging.getLogger(__name__) if is_vision_available(): @@ -113,9 +114,7 @@ class AgentImage(AgentType, ImageType): elif isinstance(value, np.ndarray): self._tensor = torch.from_numpy(value) else: - raise TypeError( - f"Unsupported type for {self.__class__.__name__}: {type(value)}" - ) + raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}") def _ipython_display_(self, include=None, exclude=None): """ @@ -264,9 +263,7 @@ if is_torch_available(): def handle_agent_input_types(*args, **kwargs): args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] - kwargs = { - k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items() - } + kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()} return args, kwargs @@ -279,9 +276,7 @@ def handle_agent_output_types(output, output_type=None): # If the class does not have defined output, then we map according to the type for _k, _v in INSTANCE_TYPE_MAPPING.items(): if isinstance(output, _k): - if ( - _k is not object - ): # avoid converting to audio if torch is not installed + if _k is not object: # avoid converting to audio if torch is not installed return _v(output) return output diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index ac4565f..196f21c 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -83,9 +83,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: try: first_accolade_index = json_blob.find("{") last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] - json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace( - '\\"', "'" - ) + json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'") json_data = json.loads(json_blob, 
strict=False)
         return json_data
     except json.JSONDecodeError as e:
@@ -162,9 +160,7 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
 MAX_LENGTH_TRUNCATE_CONTENT = 20000
 
 
-def truncate_content(
-    content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
-) -> str:
+def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
     if len(content) <= max_length:
         return content
     else:
@@ -206,12 +202,8 @@ def is_same_method(method1, method2):
         source2 = get_method_source(method2)
 
         # Remove method decorators if any
-        source1 = "\n".join(
-            line for line in source1.split("\n") if not line.strip().startswith("@")
-        )
-        source2 = "\n".join(
-            line for line in source2.split("\n") if not line.strip().startswith("@")
-        )
+        source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
+        source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
 
         return source1 == source2
     except (TypeError, OSError):
@@ -248,9 +240,7 @@ def instance_to_source(instance, base_cls=None):
         for name, value in cls.__dict__.items()
         if not name.startswith("__")
         and not callable(value)
-        and not (
-            base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
-        )
+        and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
     }
 
     for name, value in class_attrs.items():
@@ -271,9 +261,7 @@ def instance_to_source(instance, base_cls=None):
         for name, func in cls.__dict__.items()
         if callable(func)
         and not (
-            base_cls
-            and hasattr(base_cls, name)
-            and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
+            base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
         )
     }
 
@@ -284,9 +272,7 @@ def instance_to_source(instance, base_cls=None):
         first_line = method_lines[0]
         indent = len(first_line) - len(first_line.lstrip())
         method_lines = [line[indent:] for line in method_lines]
-        method_source = "\n".join(
-            [" " + line if line.strip() else line for line in method_lines]
-        )
+        method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
         class_lines.append(method_source)
         class_lines.append("")
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 4a03137..2a56a0f 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -28,13 +28,13 @@ from smolagents.agents import (
     ToolCallingAgent,
 )
 from smolagents.default_tools import PythonInterpreterTool
-from smolagents.tools import tool
-from smolagents.types import AgentImage, AgentText
 from smolagents.models import (
     ChatMessage,
     ChatMessageToolCall,
     ChatMessageToolCallDefinition,
 )
+from smolagents.tools import tool
+from smolagents.types import AgentImage, AgentText
 from smolagents.utils import BASE_BUILTIN_MODULES
 
 
@@ -44,9 +44,7 @@ def get_new_path(suffix="") -> str:
 
 
 class FakeToolCallModel:
-    def __call__(
-        self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
-    ):
+    def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
         if len(messages) < 3:
             return ChatMessage(
                 role="assistant",
@@ -69,18 +67,14 @@ class FakeToolCallModel:
                     ChatMessageToolCall(
                         id="call_1",
                         type="function",
-                        function=ChatMessageToolCallDefinition(
-                            name="final_answer", arguments={"answer": "7.2904"}
-                        ),
+                        function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "7.2904"}),
                     )
                 ],
             )
 
 
 class FakeToolCallModelImage:
-    def __call__(
-        self, messages, tools_to_call_from=None, stop_sequences=None,
grammar=None - ): + def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None): if len(messages) < 3: return ChatMessage( role="assistant", @@ -104,9 +98,7 @@ class FakeToolCallModelImage: ChatMessageToolCall( id="call_1", type="function", - function=ChatMessageToolCallDefinition( - name="final_answer", arguments="image.png" - ), + function=ChatMessageToolCallDefinition(name="final_answer", arguments="image.png"), ) ], ) @@ -271,17 +263,13 @@ print(result) class AgentTests(unittest.TestCase): def test_fake_single_step_code_agent(self): - agent = CodeAgent( - tools=[PythonInterpreterTool()], model=fake_code_model_single_step - ) + agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_single_step) output = agent.run("What is 2 multiplied by 3.6452?", single_step=True) assert isinstance(output, str) assert "7.2904" in output def test_fake_toolcalling_agent(self): - agent = ToolCallingAgent( - tools=[PythonInterpreterTool()], model=FakeToolCallModel() - ) + agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel()) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, str) assert "7.2904" in output @@ -301,9 +289,7 @@ class AgentTests(unittest.TestCase): """ return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png") - agent = ToolCallingAgent( - tools=[fake_image_generation_tool], model=FakeToolCallModelImage() - ) + agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage()) output = agent.run("Make me an image.") assert isinstance(output, AgentImage) assert isinstance(agent.state["image.png"], Image.Image) @@ -315,9 +301,7 @@ class AgentTests(unittest.TestCase): assert output == 7.2904 assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" assert agent.logs[3].tool_calls == [ - ToolCall( - name="python_interpreter", arguments="final_answer(7.2904)", id="call_3" - ) + ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3") ] def test_additional_args_added_to_task(self): @@ -351,9 +335,7 @@ class AgentTests(unittest.TestCase): assert "Code execution failed at line 'print = 2' because of" in str(agent.logs) def test_code_agent_syntax_error_show_offending_lines(self): - agent = CodeAgent( - tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error - ) + agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error) output = agent.run("What is 2 multiplied by 3.6452?") assert isinstance(output, AgentText) assert output == "got an error" @@ -391,9 +373,7 @@ class AgentTests(unittest.TestCase): def test_init_agent_with_different_toolsets(self): toolset_1 = [] agent = CodeAgent(tools=toolset_1, model=fake_code_model) - assert ( - len(agent.tools) == 1 - ) # when no tools are provided, only the final_answer tool is added by default + assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] agent = CodeAgent(tools=toolset_2, model=fake_code_model) @@ -436,9 +416,7 @@ class AgentTests(unittest.TestCase): assert "You can also give requests to team members." not in agent.system_prompt print("ok1") assert "{{managed_agents_descriptions}}" not in agent.system_prompt - assert ( - "You can also give requests to team members." in manager_agent.system_prompt - ) + assert "You can also give requests to team members." 
in manager_agent.system_prompt def test_code_agent_missing_import_triggers_advice_in_error_log(self): agent = CodeAgent(tools=[], model=fake_code_model_import) diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index d1adabd..68a88d3 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -136,9 +136,7 @@ class TestDocs: try: code_blocks = [ ( - block.replace( - "", os.getenv("HF_TOKEN") - ) + block.replace("", os.getenv("HF_TOKEN")) .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY")) .replace("{your_username}", "m-ric") ) @@ -150,9 +148,7 @@ class TestDocs: except SubprocessCallException as e: pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") except Exception: - pytest.fail( - f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}" - ) + pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}") @pytest.fixture(autouse=True) def _setup(self): @@ -174,6 +170,4 @@ def pytest_generate_tests(metafunc): test_class.setup_class() # Parameterize with the markdown files - metafunc.parametrize( - "doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files] - ) + metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]) diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py index d966b84..d92b387 100644 --- a/tests/test_default_tools.py +++ b/tests/test_default_tools.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest + import pytest from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool @@ -23,14 +24,10 @@ from .test_tools import ToolTesterMixin class DefaultToolTests(unittest.TestCase): def test_visit_webpage(self): - arguments = { - "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security" - } + arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"} result = VisitWebpageTool()(arguments) assert isinstance(result, str) - assert ( - "* [About Wikipedia](/wiki/Wikipedia:About)" in result - ) # Proper wikipedia pages have an About + assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): @@ -59,12 +56,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): for _input, expected_input in zip(inputs, self.tool.inputs.values()): input_type = expected_input["type"] if isinstance(input_type, list): - _inputs.append( - [ - AGENT_TYPE_MAPPING[_input_type](_input) - for _input_type in input_type - ] - ) + _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type]) else: _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 873dcdc..89f4ffc 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -26,6 +26,7 @@ from smolagents.types import AGENT_TYPE_MAPPING from .test_tools import ToolTesterMixin + if is_torch_available(): import torch @@ -45,11 +46,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): def create_inputs(self): inputs_text = {"answer": "Text input"} - inputs_image = { - "answer": Image.open( - Path(get_tests_dir("fixtures")) / "000000039769.png" - ).resize((512, 512)) - } + inputs_image = {"answer": Image.open(Path(get_tests_dir("fixtures")) / 
"000000039769.png").resize((512, 512))} inputs_audio = {"answer": torch.Tensor(np.ones(3000))} return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} diff --git a/tests/test_models.py b/tests/test_models.py index 9921631..8faae92 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest import json +import unittest from typing import Optional -from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel +from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool class ModelTests(unittest.TestCase): @@ -33,12 +33,7 @@ class ModelTests(unittest.TestCase): """ return "The weather is UNGODLY with torrential rains and temperatures below -10°C" - assert ( - "nullable" - in models.get_json_schema(get_weather)["function"]["parameters"][ - "properties" - ]["celsius"] - ) + assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"] def test_chatmessage_has_model_dumps_json(self): message = ChatMessage("user", "Hello!") diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index bd8b148..50afde6 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -43,9 +43,7 @@ class FakeLLMModel: ChatMessageToolCall( id="fake_id", type="function", - function=ChatMessageToolCallDefinition( - name="final_answer", arguments={"answer": "image"} - ), + function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}), ) ], ) @@ -122,9 +120,7 @@ class MonitoringTester(unittest.TestCase): ) agent.run("Fake task") - self.assertEqual( - agent.monitor.total_input_token_count, 20 - ) # Should have done two monitoring callbacks + self.assertEqual(agent.monitor.total_input_token_count, 20) # Should have done two monitoring callbacks self.assertEqual(agent.monitor.total_output_token_count, 0) def test_streaming_agent_text_output(self): diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 4976c56..0e83177 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -55,10 +55,7 @@ class PythonInterpreterTester(unittest.TestCase): code = "print = '3'" with pytest.raises(InterpreterError) as e: evaluate_python_code(code, {"print": print}, state={}) - assert ( - "Cannot assign to name 'print': doing this would erase the existing tool!" - in str(e) - ) + assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e) def test_subscript_call(self): code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)""" @@ -92,9 +89,7 @@ class PythonInterpreterTester(unittest.TestCase): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) self.assertDictEqual(result, {"x": 3, "y": 5}) - self.assertDictEqual( - state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} - ) + self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) def test_evaluate_expression(self): code = "x = 3\ny = 5" @@ -110,9 +105,7 @@ class PythonInterpreterTester(unittest.TestCase): result, _ = evaluate_python_code(code, {}, state=state) # evaluate returns the value of the last assignment. assert result == "This is x: 3." 
- self.assertDictEqual( - state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""} - ) + self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}) def test_evaluate_if(self): code = "if x <= 3:\n y = 2\nelse:\n y = 5" @@ -153,15 +146,11 @@ class PythonInterpreterTester(unittest.TestCase): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) assert result == 5 - self.assertDictEqual( - state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} - ) + self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" state = {} - evaluate_python_code( - code, {"min": min, "print": print, "round": round}, state=state - ) + evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} def test_subscript_string_with_string_index_raises_appropriate_error(self): @@ -317,9 +306,7 @@ print(check_digits) assert result == {0: 0, 1: 1, 2: 4} code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" - result, _ = evaluate_python_code( - code, {"print": print}, state={}, authorized_imports=["pandas"] - ) + result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) assert result == {102: "b"} code = """ @@ -367,9 +354,7 @@ else: best_city = "Manhattan" best_city """ - result, _ = evaluate_python_code( - code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} - ) + result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) assert result == "Brooklyn" code = """if d > e and a < b: @@ -380,9 +365,7 @@ else: best_city = "Manhattan" best_city """ - result, _ = evaluate_python_code( - code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} - ) + result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) assert result == "Sacramento" def test_if_conditions(self): @@ -398,9 +381,7 @@ if char.isalpha(): result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == 2.0 - code = ( - "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" - ) + code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result == "lose" @@ -434,14 +415,10 @@ if char.isalpha(): # Test submodules are handled properly, thus not raising error code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" - result, _ = evaluate_python_code( - code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] - ) + result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" - result, _ = evaluate_python_code( - code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] - ) + result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) def test_additional_imports(self): code = "import numpy as np" @@ -613,9 +590,7 @@ except ValueError as e: def test_types_as_objects(self): code = "type_a = float(2); type_b = str; type_c = int" state = {} - result, is_final_answer = evaluate_python_code( - code, {"float": float, "str": 
str, "int": int}, state=state - ) + result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) assert result is int def test_tuple_id(self): @@ -733,9 +708,7 @@ while True: break i""" - result, is_final_answer = evaluate_python_code( - code, {"print": print, "round": round}, state={} - ) + result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={}) assert result == 3 assert not is_final_answer @@ -781,9 +754,7 @@ out = [i for sublist in all_res for i in sublist] out[:10] """ state = {} - result, is_final_answer = evaluate_python_code( - code, {"print": print, "range": range}, state=state - ) + result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state) assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] def test_pandas(self): @@ -798,9 +769,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0] parts_with_5_set_count[['Quantity', 'SetCount']].values[1] """ state = {} - result, _ = evaluate_python_code( - code, {}, state=state, authorized_imports=["pandas"] - ) + result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"]) assert np.array_equal(result, [-1, 5]) code = """ @@ -811,9 +780,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]}) # Filter the DataFrame to get only the rows with outdated atomic numbers filtered_df = df.loc[df['AtomicNumber'].isin([104])] """ - result, _ = evaluate_python_code( - code, {"print": print}, state={}, authorized_imports=["pandas"] - ) + result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) assert np.array_equal(result.values[0], [104, 1]) # Test groupby @@ -825,9 +792,7 @@ data = pd.DataFrame.from_dict([ ]) survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() """ - result, _ = evaluate_python_code( - code, {}, state={}, authorized_imports=["pandas"] - ) + result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) assert result.values[1] == 0.5 # Test loc and iloc @@ -839,11 +804,9 @@ data = pd.DataFrame.from_dict([ ]) survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean() survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean() -survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0] +survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0] """ - result, _ = evaluate_python_code( - code, {}, state={}, authorized_imports=["pandas"] - ) + result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"]) def test_starred(self): code = """ @@ -864,9 +827,7 @@ coords_barcelona = (41.3869, 2.1660) distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) """ - result, _ = evaluate_python_code( - code, {"print": print, "map": map}, state={}, authorized_imports=["math"] - ) + result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"]) assert round(result, 1) == 622395.4 def test_for(self): diff --git a/tests/test_tools.py b/tests/test_tools.py index 5b2dc0e..917bcf1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -16,7 +16,7 @@ import unittest from pathlib import Path from textwrap import dedent from typing import Dict, Optional, Union -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import mcp import numpy as np @@ -32,6 +32,7 @@ from smolagents.types import ( AgentText, ) + if 
is_torch_available(): import torch @@ -48,9 +49,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): if input_type == "string": inputs[input_name] = "Text input" elif input_type == "image": - inputs[input_name] = Image.open( - Path(get_tests_dir("fixtures")) / "000000039769.png" - ).resize((512, 512)) + inputs[input_name] = Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512)) elif input_type == "audio": inputs[input_name] = np.ones(3000) else: @@ -224,9 +223,7 @@ class ToolTests(unittest.TestCase): class FailTool(Tool): name = "specific" description = "test description" - inputs = { - "string_input": {"type": "string", "description": "input description"} - } + inputs = {"string_input": {"type": "string", "description": "input description"}} output_type = "string" def __init__(self, url): @@ -248,9 +245,7 @@ class ToolTests(unittest.TestCase): class FailTool(Tool): name = "specific" description = "test description" - inputs = { - "string_input": {"type": "string", "description": "input description"} - } + inputs = {"string_input": {"type": "string", "description": "input description"}} output_type = "string" def useless_method(self): @@ -269,9 +264,7 @@ class ToolTests(unittest.TestCase): class SuccessTool(Tool): name = "specific" description = "test description" - inputs = { - "string_input": {"type": "string", "description": "input description"} - } + inputs = {"string_input": {"type": "string", "description": "input description"}} output_type = "string" def useless_method(self): @@ -300,9 +293,7 @@ class ToolTests(unittest.TestCase): }, } - def forward( - self, location: str, celsius: Optional[bool] = False - ) -> str: + def forward(self, location: str, celsius: Optional[bool] = False) -> str: return "The weather is UNGODLY with torrential rains and temperatures below -10°C" GetWeatherTool() @@ -340,9 +331,7 @@ class ToolTests(unittest.TestCase): } output_type = "string" - def forward( - self, location: str, celsius: Optional[bool] = False - ) -> str: + def forward(self, location: str, celsius: Optional[bool] = False) -> str: return "The weather is UNGODLY with torrential rains and temperatures below -10°C" GetWeatherTool() @@ -410,9 +399,7 @@ def mock_smolagents_adapter(): class TestToolCollection: - def test_from_mcp( - self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter - ): + def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter): with ToolCollection.from_mcp(mock_server_parameters) as tool_collection: assert isinstance(tool_collection, ToolCollection) assert len(tool_collection.tools) == 2 @@ -440,9 +427,5 @@ class TestToolCollection: with ToolCollection.from_mcp(mcp_server_params) as tool_collection: assert len(tool_collection.tools) == 1, "Expected 1 tool" - assert tool_collection.tools[0].name == "echo_tool", ( - "Expected tool name to be 'echo_tool'" - ) - assert tool_collection.tools[0](text="Hello") == "Hello", ( - "Expected tool to echo the input text" - ) + assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'" + assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text" diff --git a/tests/test_utils.py b/tests/test_utils.py index 0a661a3..91064e8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import unittest
+
 import pytest
 
 from smolagents.utils import parse_code_blobs
diff --git a/utils/check_tests_in_ci.py b/utils/check_tests_in_ci.py
index 4c55ef0..65ebca7 100644
--- a/utils/check_tests_in_ci.py
+++ b/utils/check_tests_in_ci.py
@@ -16,6 +16,7 @@
 
 from pathlib import Path
 
+
 ROOT = Path(__file__).parent.parent
 TESTS_FOLDER = ROOT / "tests"
 
@@ -37,11 +38,7 @@ def check_tests_in_ci():
         if path.name.startswith("test_")
     ]
     ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
-    missing_test_files = [
-        test_file
-        for test_file in test_files
-        if test_file not in ci_workflow_file_content
-    ]
+    missing_test_files = [test_file for test_file in test_files if test_file not in ci_workflow_file_content]
     if missing_test_files:
         print(
             "❌ Some test files seem to be ignored in the CI:\n"