diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 5ff7e93..a595bed 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -56,4 +56,7 @@ jobs:
uv run pytest -sv ./tests/test_tools.py
- name: Types tests
run: |
- uv run pytest -sv ./tests/test_types.py
\ No newline at end of file
+ uv run pytest -sv ./tests/test_types.py
+ - name: Utils tests
+ run: |
+ uv run pytest -sv ./tests/test_utils.py
\ No newline at end of file
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index dcb3532..09beeb6 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -1,14 +1,188 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -e .. sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using the latest cached version of the dataset since m-ric/smolagentsbenchmark couldn't be found on the Hugging Face Hub\n",
+ "Found the latest cached dataset configuration 'default' at /Users/aymeric/.cache/huggingface/datasets/m-ric___smolagentsbenchmark/default/0.0.0/0ad5fb2293ab185eece723a4ac0e4a7188f71add (last modified on Wed Jan 8 17:50:13 2025).\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " question | \n",
+ " source | \n",
+ " true_answer | \n",
+ " true_reasoning | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " If Eliud Kipchoge could maintain his record-ma... | \n",
+ " GAIA | \n",
+ " 17 | \n",
+ " None | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " How many studio albums were published by Merce... | \n",
+ " GAIA | \n",
+ " 3 | \n",
+ " None | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Here's a fun riddle that I think you'll enjoy.... | \n",
+ " GAIA | \n",
+ " 3 | \n",
+ " None | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " My family reunion is this week, and I was assi... | \n",
+ " GAIA | \n",
+ " 2 | \n",
+ " None | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " In Emily Midkiff's June 2014 article in a jour... | \n",
+ " GAIA | \n",
+ " fluffy | \n",
+ " None | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 127 | \n",
+ " What year was the municipality of San Carlos, ... | \n",
+ " SimpleQA | \n",
+ " 1786 | \n",
+ " ['https://en.wikipedia.org/wiki/San_Carlos,_An... | \n",
+ "
\n",
+ " \n",
+ " 128 | \n",
+ " In which year was Maria Elena Walsh named Illu... | \n",
+ " SimpleQA | \n",
+ " 1985 | \n",
+ " ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... | \n",
+ "
\n",
+ " \n",
+ " 129 | \n",
+ " What is the durability of the Istarelle spear ... | \n",
+ " SimpleQA | \n",
+ " 800 | \n",
+ " ['http://demonssouls.wikidot.com/spear', 'http... | \n",
+ "
\n",
+ " \n",
+ " 130 | \n",
+ " What is the number of the executive order that... | \n",
+ " SimpleQA | \n",
+ " 7034 | \n",
+ " ['https://www.loc.gov/collections/federal-thea... | \n",
+ "
\n",
+ " \n",
+ " 131 | \n",
+ " Within plus or minus one minute, when was Marq... | \n",
+ " SimpleQA | \n",
+ " 77 | \n",
+ " ['https://www.fifa.com/fifaplus/en/match-centr... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
132 rows × 4 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " question source true_answer \\\n",
+ "0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
+ "1 How many studio albums were published by Merce... GAIA 3 \n",
+ "2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
+ "3 My family reunion is this week, and I was assi... GAIA 2 \n",
+ "4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
+ ".. ... ... ... \n",
+ "127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
+ "128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
+ "129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
+ "130 What is the number of the executive order that... SimpleQA 7034 \n",
+ "131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
+ "\n",
+ " true_reasoning \n",
+ "0 None \n",
+ "1 None \n",
+ "2 None \n",
+ "3 None \n",
+ "4 None \n",
+ ".. ... \n",
+ "127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
+ "128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
+ "129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
+ "130 ['https://www.loc.gov/collections/federal-thea... \n",
+ "131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
+ "\n",
+ "[132 rows x 4 columns]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"import datasets\n",
+ "import pandas as pd\n",
"\n",
- "eval_ds = datasets.load_dataset(\"m-ric/agents_medium_benchmark_2\")[\"train\"]"
+ "eval_ds = datasets.load_dataset(\"m-ric/smolagentsbenchmark\")[\"train\"]\n",
+ "pd.DataFrame(eval_ds)"
]
},
{
@@ -69,6 +243,8 @@
" question = example[\"question\"]\n",
" if example[\"source\"] == \"SimpleQA\":\n",
" question += \" Answer with only the final number.\"\n",
+ " if example[\"source\"] == \"MATH\":\n",
+ " question += \" Write code, not latex.\"\n",
" if question in answered_questions:\n",
" continue\n",
" start_time = time.time()\n",
@@ -223,26 +399,27 @@
" \"Qwen/Qwen2.5-72B-Instruct\",\n",
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
+ " \"meta-llama/Llama-3.1-8B-Instruct\",\n",
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
"\n",
"for model_id in open_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
- " action_type = \"tool_calling\"\n",
- " agent = ToolCallingAgent(\n",
- " tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
- " model=HfApiModel(model_id),\n",
- " max_steps=10,\n",
- " )\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
+ " # action_type = \"tool_calling\"\n",
+ " # agent = ToolCallingAgent(\n",
+ " # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
+ " # model=HfApiModel(model_id),\n",
+ " # max_steps=10,\n",
+ " # )\n",
+ " # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
+ " # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=HfApiModel(model_id),\n",
- " additional_authorized_imports=[\"numpy\"],\n",
+ " additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
@@ -270,7 +447,11 @@
" print(f\"Evaluating '{model_id}'...\")\n",
" action_type = \"tool_calling\"\n",
" agent = ToolCallingAgent(\n",
- " tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
+ " tools=[\n",
+ " GoogleSearchTool(),\n",
+ " VisitWebpageTool(),\n",
+ " PythonInterpreterTool([\"numpy\", \"sympy\"]),\n",
+ " ],\n",
" model=LiteLLMModel(model_id),\n",
" max_steps=10,\n",
" )\n",
@@ -292,7 +473,38 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "output/Qwen_Qwen2.5-Coder-32B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/meta-llama_Llama-3.3-70B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 124 lines.\n",
+ "output/Qwen_Qwen2.5-72B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/anthropic_claude-3-5-sonnet-latest-tool_calling-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/meta-llama_Llama-3.3-70B-Instruct-tool_calling-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/anthropic_claude-3-5-sonnet-latest-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/Qwen_Qwen2.5-72B-Instruct-tool_calling-26-dec-2024.jsonl\n",
+ "Removed 99 lines.\n",
+ "output/HuggingFaceTB_SmolLM2-1.7B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/gpt-4o-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/meta-llama_Llama-3.1-70B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/meta-llama_Llama-3.2-3B-Instruct-code-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n",
+ "output/gpt-4o-tool_calling-26-dec-2024.jsonl\n",
+ "Removed 109 lines.\n"
+ ]
+ }
+ ],
"source": [
"# import glob\n",
"# import json\n",
@@ -307,7 +519,7 @@
"# for line in f:\n",
"# try:\n",
"# data = json.loads(line.strip())\n",
- "# if data[\"source\"] == \"SimpleQA\" and \"Answer with only the final number.\" not in data[\"question\"]:\n",
+ "# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
"# removed +=1\n",
"# else:\n",
"# filtered_lines.append(line)\n",
@@ -329,15 +541,17 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_37227/1724525657.py:154: UserWarning: Answer lists have different lengths, returning False.\n",
- " warnings.warn(\n"
+ "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_17219/1724525657.py:154: UserWarning:\n",
+ "\n",
+ "Answer lists have different lengths, returning False.\n",
+ "\n"
]
}
],
@@ -352,7 +566,7 @@
"\n",
"\n",
"def get_correct(row):\n",
- " if row[\"source\"] == \"GSM8K\":\n",
+ " if row[\"source\"] == \"MATH\":\n",
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
" if len(numbers_answer) == 0:\n",
" return False\n",
@@ -385,7 +599,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@@ -397,6 +611,16 @@
" [\"gpt-4o\", \"GSM8K\", 94.3],\n",
" [\"anthropic/claude-3-5-sonnet-latest\", \"GSM8K\", 96.4],\n",
" [\"meta-llama/Llama-3.3-70B-Instruct\", \"GSM8K\", 95.1],\n",
+ " [\n",
+ " \"meta-llama/Llama-3.3-70B-Instruct\",\n",
+ " \"MATH\",\n",
+ " 30.7,\n",
+ " ], # As per Open LLM Leaderboard for 3.1, score for 3.3 is too low. https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/?search=llama-3.1\n",
+ " [\n",
+ " \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
+ " \"MATH\",\n",
+ " 30.6,\n",
+ " ], # As per Open LLM Leaderboard for the base model, score for instruct too low. https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/?search=llama-3.1\n",
"]\n",
"\n",
"df2 = pd.DataFrame(vanilla_data, columns=[\"model_id\", \"source\", \"correct\"])\n",
@@ -412,6 +636,15 @@
").reset_index()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pivot_df = pivot_df.loc[~pivot_df[\"source\"].isin([\"GSM8K\"])]"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -421,7 +654,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 69,
"metadata": {},
"outputs": [
{
@@ -460,101 +693,101 @@
" NaN | \n",
" \n",
" \n",
- " 1 | \n",
+ " 2 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
- " GSM8K | \n",
- " 82.9 | \n",
+ " MATH | \n",
+ " 77.5 | \n",
" NaN | \n",
"
\n",
" \n",
- " 2 | \n",
+ " 3 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
" SimpleQA | \n",
" 42.5 | \n",
" 9.1 | \n",
"
\n",
" \n",
- " 3 | \n",
+ " 4 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" GAIA | \n",
" 28.1 | \n",
" NaN | \n",
"
\n",
" \n",
- " 4 | \n",
+ " 6 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
- " GSM8K | \n",
- " 92.9 | \n",
- " NaN | \n",
+ " MATH | \n",
+ " 85.0 | \n",
+ " 30.6 | \n",
"
\n",
" \n",
- " 5 | \n",
+ " 7 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" SimpleQA | \n",
" 42.5 | \n",
" NaN | \n",
"
\n",
" \n",
- " 6 | \n",
+ " 8 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" GAIA | \n",
" 43.8 | \n",
" NaN | \n",
"
\n",
" \n",
- " 7 | \n",
+ " 10 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
- " GSM8K | \n",
- " 91.4 | \n",
- " 96.4 | \n",
+ " MATH | \n",
+ " 85.0 | \n",
+ " NaN | \n",
"
\n",
" \n",
- " 8 | \n",
+ " 11 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" SimpleQA | \n",
" 47.5 | \n",
" 28.4 | \n",
"
\n",
" \n",
- " 9 | \n",
+ " 12 | \n",
" gpt-4o | \n",
" GAIA | \n",
" 25.0 | \n",
" 9.3 | \n",
"
\n",
" \n",
- " 10 | \n",
+ " 14 | \n",
" gpt-4o | \n",
- " GSM8K | \n",
- " 91.4 | \n",
- " 94.3 | \n",
+ " MATH | \n",
+ " 77.5 | \n",
+ " NaN | \n",
"
\n",
" \n",
- " 11 | \n",
+ " 15 | \n",
" gpt-4o | \n",
" SimpleQA | \n",
" 60.0 | \n",
" 38.2 | \n",
"
\n",
" \n",
- " 12 | \n",
+ " 16 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" GAIA | \n",
" 21.9 | \n",
" NaN | \n",
"
\n",
" \n",
- " 13 | \n",
+ " 18 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
- " GSM8K | \n",
- " 95.7 | \n",
- " 95.1 | \n",
+ " MATH | \n",
+ " 82.1 | \n",
+ " 30.7 | \n",
"
\n",
" \n",
- " 14 | \n",
+ " 19 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" SimpleQA | \n",
- " 30.0 | \n",
+ " 30.9 | \n",
" NaN | \n",
"
\n",
" \n",
@@ -564,20 +797,20 @@
"text/plain": [
"type model_id source agent vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 12.5 NaN\n",
- "1 Qwen/Qwen2.5-72B-Instruct GSM8K 82.9 NaN\n",
- "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 42.5 9.1\n",
- "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 28.1 NaN\n",
- "4 Qwen/Qwen2.5-Coder-32B-Instruct GSM8K 92.9 NaN\n",
- "5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 42.5 NaN\n",
- "6 anthropic/claude-3-5-sonnet-latest GAIA 43.8 NaN\n",
- "7 anthropic/claude-3-5-sonnet-latest GSM8K 91.4 96.4\n",
- "8 anthropic/claude-3-5-sonnet-latest SimpleQA 47.5 28.4\n",
- "9 gpt-4o GAIA 25.0 9.3\n",
- "10 gpt-4o GSM8K 91.4 94.3\n",
- "11 gpt-4o SimpleQA 60.0 38.2\n",
- "12 meta-llama/Llama-3.3-70B-Instruct GAIA 21.9 NaN\n",
- "13 meta-llama/Llama-3.3-70B-Instruct GSM8K 95.7 95.1\n",
- "14 meta-llama/Llama-3.3-70B-Instruct SimpleQA 30.0 NaN"
+ "2 Qwen/Qwen2.5-72B-Instruct MATH 77.5 NaN\n",
+ "3 Qwen/Qwen2.5-72B-Instruct SimpleQA 42.5 9.1\n",
+ "4 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 28.1 NaN\n",
+ "6 Qwen/Qwen2.5-Coder-32B-Instruct MATH 85.0 30.6\n",
+ "7 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 42.5 NaN\n",
+ "8 anthropic/claude-3-5-sonnet-latest GAIA 43.8 NaN\n",
+ "10 anthropic/claude-3-5-sonnet-latest MATH 85.0 NaN\n",
+ "11 anthropic/claude-3-5-sonnet-latest SimpleQA 47.5 28.4\n",
+ "12 gpt-4o GAIA 25.0 9.3\n",
+ "14 gpt-4o MATH 77.5 NaN\n",
+ "15 gpt-4o SimpleQA 60.0 38.2\n",
+ "16 meta-llama/Llama-3.3-70B-Instruct GAIA 21.9 NaN\n",
+ "18 meta-llama/Llama-3.3-70B-Instruct MATH 82.1 30.7\n",
+ "19 meta-llama/Llama-3.3-70B-Instruct SimpleQA 30.9 NaN"
]
},
"metadata": {},
@@ -590,37 +823,129 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 84,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\\begin{array}{llcc}\n",
- "\\text{Model} & \\text{Task} & \\text{Agent} & \\text{Vanilla} \\\\\n",
- "\\hline\n",
- "\\textit{Qwen/Qwen2.5-72B-Instruct} & GAIA & 12.500 & - \\\\\n",
- "\\; & GSM8K & 82.900 & - \\\\\n",
- "\\; & SimpleQA & \\textbf{42.500} & 9.100 \\\\\n",
- "\\hline\n",
- "\\textit{Qwen/Qwen2.5-Coder-32B-Instruct} & GAIA & 28.100 & - \\\\\n",
- "\\; & GSM8K & 92.900 & - \\\\\n",
- "\\; & SimpleQA & 42.500 & - \\\\\n",
- "\\hline\n",
- "\\textit{anthropic/claude-3-5-sonnet-latest} & GAIA & 43.800 & - \\\\\n",
- "\\; & GSM8K & 91.400 & \\textbf{96.400} \\\\\n",
- "\\; & SimpleQA & \\textbf{47.500} & 28.400 \\\\\n",
- "\\hline\n",
- "gpt-4o & GAIA & \\textbf{25.000} & 9.300 \\\\\n",
- "\\; & GSM8K & 91.400 & \\textbf{94.300} \\\\\n",
- "\\; & SimpleQA & \\textbf{60.000} & 38.200 \\\\\n",
- "\\hline\n",
- "meta-llama/Llama-3.3-70B-Instruct & GAIA & 21.900 & - \\\\\n",
- "\\; & GSM8K & \\textbf{95.700} & 95.100 \\\\\n",
- "\\; & SimpleQA & 30.000 & - \\\\\n",
- "\\hline\n",
- "\\end{array}\n"
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from matplotlib.legend_handler import HandlerTuple # Added import\n",
+ "\n",
+ "# Assuming pivot_df is your original dataframe\n",
+ "models = pivot_df[\"model_id\"].unique()\n",
+ "sources = pivot_df[\"source\"].unique()\n",
+ "\n",
+ "# Create figure and axis\n",
+ "plt.style.use(\"seaborn-v0_8-white\")\n",
+ "fig, ax = plt.subplots(figsize=(15, 6))\n",
+ "\n",
+ "# Set the width of each bar group and positions of the bars\n",
+ "width = 0.15 # width of each bar\n",
+ "spacing = 0.02 # space between bars within a group\n",
+ "group_spacing = 0.2 # space between model groups\n",
+ "\n",
+ "# Calculate positions for the bars\n",
+ "num_sources = len(sources)\n",
+ "total_width_per_group = (width + spacing) * num_sources * 2 # *2 for agent and vanilla\n",
+ "x = np.arange(len(models)) * (total_width_per_group + group_spacing)\n",
+ "\n",
+ "# Plot bars for each source\n",
+ "for i, source in enumerate(sources):\n",
+ " source_data = pivot_df[pivot_df[\"source\"] == source]\n",
+ " agent_scores = [\n",
+ " source_data[source_data[\"model_id\"] == model][\"agent\"].values[0]\n",
+ " if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
+ " else np.nan\n",
+ " for model in models\n",
+ " ]\n",
+ " vanilla_scores = [\n",
+ " source_data[source_data[\"model_id\"] == model][\"vanilla\"].values[0]\n",
+ " if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
+ " else np.nan\n",
+ " for model in models\n",
+ " ]\n",
+ "\n",
+ " # Position calculation for each pair of bars\n",
+ " pos = x + i * (width * 2 + spacing)\n",
+ "\n",
+ " agent_bars = ax.bar(pos, agent_scores, width, label=f\"{source} (Agent)\", alpha=0.8)\n",
+ " vanilla_bars = ax.bar(\n",
+ " pos + width * 0.6,\n",
+ " vanilla_scores,\n",
+ " width,\n",
+ " hatch=\"////\",\n",
+ " alpha=0.5,\n",
+ " hatch_linewidth=2,\n",
+ " label=f\"{source} (Vanilla)\",\n",
+ " color=\"white\",\n",
+ " edgecolor=agent_bars[0].get_facecolor(),\n",
+ " )\n",
+ "\n",
+ "# Customize the plot\n",
+ "ax.set_ylabel(\"Score\")\n",
+ "ax.set_title(\"Model Performance Comparison\")\n",
+ "\n",
+ "# Set x-axis ticks in the middle of each group\n",
+ "group_centers = x + (total_width_per_group - spacing) / 2\n",
+ "ax.set_xticks(group_centers)\n",
+ "\n",
+ "# Wrap long model names to prevent overlap\n",
+ "wrapped_labels = [\"\\n\".join(model.split(\"/\")) for model in models]\n",
+ "ax.set_xticklabels(wrapped_labels, rotation=0, ha=\"center\")\n",
+ "\n",
+ "# Modify legend to combine agent and vanilla entries\n",
+ "handles, labels = ax.get_legend_handles_labels()\n",
+ "unique_sources = sources\n",
+ "legend_elements = [\n",
+ " (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n",
+ " for i in range(len(unique_sources))\n",
+ "]\n",
+ "custom_legend = ax.legend(\n",
+ " [\n",
+ " (agent_handle, vanilla_handle)\n",
+ " for agent_handle, vanilla_handle, _ in legend_elements\n",
+ " ],\n",
+ " [label for _, _, label in legend_elements],\n",
+ " handler_map={tuple: HandlerTuple(ndivide=None)},\n",
+ " bbox_to_anchor=(1.05, 1),\n",
+ " loc=\"upper left\",\n",
+ ")\n",
+ "\n",
+ "ax.yaxis.grid(True, linestyle=\"--\", alpha=0.3)\n",
+ "ax.set_ylim(bottom=0)\n",
+ "plt.tight_layout()\n",
+ "ax.spines[\"top\"].set_visible(False)\n",
+ "ax.spines[\"right\"].set_visible(False)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'formatted_df' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 45\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mathjax_table\n\u001b[1;32m 44\u001b[0m \u001b[38;5;66;03m# Usage (after running your previous data processing code):\u001b[39;00m\n\u001b[0;32m---> 45\u001b[0m mathjax_table \u001b[38;5;241m=\u001b[39m create_mathjax_table(pivot_df, \u001b[43mformatted_df\u001b[49m)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28mprint\u001b[39m(mathjax_table)\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'formatted_df' is not defined"
]
}
],
@@ -676,7 +1001,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "compare-agents",
+ "display_name": "test",
"language": "python",
"name": "python3"
},
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index eac9231..05be772 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -26,7 +26,11 @@ from rich.text import Text
from .default_tools import FinalAnswerTool
from .e2b_executor import E2BExecutor
-from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
+from .local_python_executor import (
+ BASE_BUILTIN_MODULES,
+ LocalPythonInterpreter,
+ fix_final_answer_code,
+)
from .models import MessageRole
from .monitoring import Monitor
from .prompts import (
@@ -895,7 +899,6 @@ class CodeAgent(MultiStepAgent):
)
log_entry.llm_output = llm_output
except Exception as e:
- console.print_exception()
raise AgentGenerationError(f"Error in generating model output:\n{e}")
if self.verbose:
@@ -917,10 +920,11 @@ class CodeAgent(MultiStepAgent):
# Parse
try:
- code_action = parse_code_blob(llm_output)
+ code_action = fix_final_answer_code(parse_code_blob(llm_output))
except Exception as e:
- console.print_exception()
- error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
+ error_msg = (
+ f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
+ )
raise AgentParsingError(error_msg)
log_entry.tool_call = ToolCall(
@@ -944,8 +948,9 @@ class CodeAgent(MultiStepAgent):
)
)
observation = ""
+ is_final_answer = False
try:
- output, execution_logs = self.python_executor(
+ output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
)
@@ -976,12 +981,6 @@ class CodeAgent(MultiStepAgent):
observation += "Last output from code snippet:\n" + truncated_output
log_entry.observations = observation
- is_final_answer = False
- for line in code_action.split("\n"):
- if line[: len("final_answer")] == "final_answer":
- is_final_answer = True
- break
-
execution_outputs_console += [
Text(
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 5628930..5959cda 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -112,7 +112,7 @@ class PythonInterpreterTool(Tool):
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
- )
+ )[0] # Keep only the result; the second element is the boolean is_final_answer
)
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
except Exception as e:
diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py
index 6c7cb4a..17095a1 100644
--- a/src/smolagents/local_python_executor.py
+++ b/src/smolagents/local_python_executor.py
@@ -18,6 +18,7 @@ import ast
import builtins
import difflib
import math
+import re
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -129,6 +130,34 @@ def get_iterable(obj):
raise InterpreterError("Object is not iterable")
+def fix_final_answer_code(code: str) -> str:
+ """
+ Sometimes an LLM assigns a value to a variable named final_answer, which would break the final_answer() tool.
+ This function fixes this behaviour by replacing variable assignments to final_answer with final_answer_variable,
+ while preserving function calls to final_answer().
+ """
+ # First, find if there's a direct assignment to final_answer
+ # Use word boundary and negative lookbehind to ensure it's not an object attribute
+ assignment_pattern = r"(? 1
- ): # Check for user-defined classes
- # Instantiate the class using its constructor
- obj = func.__new__(func) # Create a new instance of the class
- if hasattr(obj, "__init__"): # Check if the class has an __init__ method
- obj.__init__(*args, **kwargs) # Call the __init__ method correctly
- return obj
- else:
- if func_name == "super":
- if not args:
- if "__class__" in state and "self" in state:
- return super(state["__class__"], state["self"])
- else:
- raise InterpreterError("super() needs at least one argument")
- cls = args[0]
- if not isinstance(cls, type):
- raise InterpreterError("super() argument 1 must be type")
- if len(args) == 1:
- return super(cls)
- elif len(args) == 2:
- instance = args[1]
- return super(cls, instance)
+ if func_name == "super":
+ if not args:
+ if "__class__" in state and "self" in state:
+ return super(state["__class__"], state["self"])
else:
- raise InterpreterError("super() takes at most 2 arguments")
+ raise InterpreterError("super() needs at least one argument")
+ cls = args[0]
+ if not isinstance(cls, type):
+ raise InterpreterError("super() argument 1 must be type")
+ if len(args) == 1:
+ return super(cls)
+ elif len(args) == 2:
+ instance = args[1]
+ return super(cls, instance)
else:
- if func_name == "print":
- output = " ".join(map(str, args))
- global PRINT_OUTPUTS
- PRINT_OUTPUTS += output + "\n"
- # cap the number of lines
- return None
- else: # Assume it's a callable object
- output = func(*args, **kwargs)
- return output
+ raise InterpreterError("super() takes at most 2 arguments")
+ else:
+ if func_name == "print":
+ output = " ".join(map(str, args))
+ global PRINT_OUTPUTS
+ PRINT_OUTPUTS += output + "\n"
+ # cap the number of lines
+ return None
+ else: # Assume it's a callable object
+ return func(*args, **kwargs)
def evaluate_subscript(subscript, state, static_tools, custom_tools):
@@ -990,6 +1013,11 @@ def truncate_print_outputs(
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
+class FinalAnswerException(Exception):
+ def __init__(self, value):
+ self.value = value
+
+
def evaluate_python_code(
code: str,
static_tools: Optional[Dict[str, Callable]] = None,
@@ -1029,6 +1057,12 @@ def evaluate_python_code(
PRINT_OUTPUTS = ""
global OPERATIONS_COUNT
OPERATIONS_COUNT = 0
+
+ def final_answer(value):
+ raise FinalAnswerException(value)
+
+ static_tools["final_answer"] = final_answer
+
try:
for node in expression.body:
result = evaluate_ast(
@@ -1037,7 +1071,14 @@ def evaluate_python_code(
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
- return result
+ is_final_answer = False
+ return result, is_final_answer
+ except FinalAnswerException as e:
+ state["print_outputs"] = truncate_content(
+ PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
+ )
+ is_final_answer = True
+ return e.value, is_final_answer
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
@@ -1059,9 +1100,11 @@ class LocalPythonInterpreter:
}
# TODO: assert self.authorized imports are all installed locally
- def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]:
+ def __call__(
+ self, code_action: str, additional_variables: Dict
+ ) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
- output = evaluate_python_code(
+ output, is_final_answer = evaluate_python_code(
code_action,
static_tools=self.static_tools,
custom_tools=self.custom_tools,
@@ -1069,7 +1112,7 @@ class LocalPythonInterpreter:
authorized_imports=self.authorized_imports,
)
logs = self.state["print_outputs"]
- return output, logs
+ return output, logs, is_final_answer
__all__ = ["evaluate_python_code", "LocalPythonInterpreter"]
diff --git a/src/smolagents/prompts.py b/src/smolagents/prompts.py
index e85ab19..af68b27 100644
--- a/src/smolagents/prompts.py
+++ b/src/smolagents/prompts.py
@@ -373,7 +373,7 @@ Here are the rules you should always follow to solve your task:
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
-7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
+7. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py
index 902ebb7..fe006fd 100644
--- a/src/smolagents/utils.py
+++ b/src/smolagents/utils.py
@@ -106,26 +106,35 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str:
- try:
- pattern = r"```(?:py|python)?\n(.*?)\n```"
- match = re.search(pattern, code_blob, re.DOTALL)
- if match is None:
- raise ValueError(
- f"No match ground for regex pattern {pattern} in {code_blob=}."
- )
- return match.group(1).strip()
+ """Parses the LLM's output to get any code blob inside. Will return the code directly if it's code."""
+ pattern = r"```(?:py|python)?\n(.*?)\n```"
+ match = re.search(pattern, code_blob, re.DOTALL)
+ if match is None:
+ try: # Maybe the LLM outputted a code blob directly
+ ast.parse(code_blob)
+ return code_blob
+ except SyntaxError:
+ pass
- except Exception as e:
+ if "final" in code_blob and "answer" in code_blob:
+ raise ValueError(
+ f"""
+The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. It seems like you're trying to return the final answer, you can do it as follows:
+Code:
+```py
+final_answer("YOUR FINAL ANSWER HERE")
+```""".strip()
+ )
raise ValueError(
f"""
-The code blob you used is invalid: due to the following error: {e}
-This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
+The code blob is invalid, because the regex pattern {pattern} was not found in {code_blob=}. Make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
-```"""
+```""".strip()
)
+ return match.group(1).strip()
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 06d9554..2d666e6 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -444,3 +444,18 @@ final_answer("Final report.")
report = manager_toolcalling_agent.run("Fake question.")
assert report == "Final report."
+
+ def test_code_nontrivial_final_answer_works(self):
+ def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
+ return """Code:
+```py
+def nested_answer():
+ final_answer("Correct!")
+
+nested_answer()
+```"""
+
+ agent = CodeAgent(tools=[], model=fake_code_model_final_answer)
+
+ output = agent.run("Count to 3")
+ assert output == "Correct!"
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 3508161..5f4ffc4 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -23,6 +23,7 @@ from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
+ fix_final_answer_code,
)
from smolagents.types import AGENT_TYPE_MAPPING
@@ -79,19 +80,19 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"
state = {}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
code = "x = y"
state = {"y": 5}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
code = "a=1;b=None"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
# evaluate returns the value of the last assignment.
assert result is None
@@ -107,7 +108,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_call(self):
code = "y = add_two(x)"
state = {"x": 3}
- result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@@ -119,14 +120,14 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_constant(self):
code = "x = 3"
state = {}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
def test_evaluate_dict(self):
code = "test_dict = {'x': x, 'y': add_two(x)}"
state = {"x": 3}
- result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@@ -135,7 +136,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
state = {}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
@@ -143,7 +144,7 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_f_string(self):
code = "text = f'This is x: {x}.'"
state = {"x": 3}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(
@@ -153,13 +154,13 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 2
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
state = {"x": 8}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
@@ -167,27 +168,27 @@ class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_list(self):
code = "test_list = [x, add_two(x)]"
state = {"x": 3}
- result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5])
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
def test_evaluate_name(self):
code = "y = x"
state = {"x": 3}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
def test_evaluate_subscript(self):
code = "test_list = [x, add_two(x)]\ntest_list[1]"
state = {"x": 3}
- result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3}
- result = evaluate_python_code(code, {"add_two": add_two}, state=state)
+ result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
@@ -215,14 +216,14 @@ for result in search_results:
def test_evaluate_for(self):
code = "x = 0\nfor i in range(3):\n x = i"
state = {}
- result = evaluate_python_code(code, {"range": range}, state=state)
+ result, _ = evaluate_python_code(code, {"range": range}, state=state)
assert result == 2
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
def test_evaluate_binop(self):
code = "y + x"
state = {"x": 3, "y": 6}
- result = evaluate_python_code(code, {}, state=state)
+ result, _ = evaluate_python_code(code, {}, state=state)
assert result == 9
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
@@ -234,27 +235,27 @@ def recur_fibo(n):
else:
return(recur_fibo(n-1) + recur_fibo(n-2))
recur_fibo(6)"""
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == 8
def test_evaluate_string_methods(self):
code = "'hello'.replace('h', 'o').split('e')"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == ["o", "llo"]
def test_evaluate_slicing(self):
code = "'hello'[1:3][::-1]"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == "le"
def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result is int
def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == "t-h-e-s-e-a-g-u-l-l"
def test_string_indexing(self):
@@ -267,12 +268,12 @@ for block in text_block:
for col in range(len(text_block[0])):
sentence += block[col]
"""
- result = evaluate_python_code(code, {"len": len, "range": range}, state={})
+ result, _ = evaluate_python_code(code, {"len": len, "range": range}, state={})
assert result == "THESEAGULL"
def test_tuples(self):
code = "x = (1, 2, 3)\nx[1]"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == 2
code = """
@@ -325,35 +326,35 @@ print(check_digits)
def test_listcomp(self):
code = "x = [i for i in range(3)]"
- result = evaluate_python_code(code, {"range": range}, state={})
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == [0, 1, 2]
def test_break_continue(self):
code = "for i in range(10):\n if i == 5:\n break\ni"
- result = evaluate_python_code(code, {"range": range}, state={})
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 5
code = "for i in range(10):\n if i == 5:\n continue\ni"
- result = evaluate_python_code(code, {"range": range}, state={})
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == 9
def test_call_int(self):
code = "import math\nstr(math.ceil(149))"
- result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
+ result, _ = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
assert result == "149"
def test_lambda(self):
code = "f = lambda x: x + 2\nf(3)"
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == 5
def test_dictcomp(self):
code = "x = {i: i**2 for i in range(3)}"
- result = evaluate_python_code(code, {"range": range}, state={})
+ result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert result == {102: "b"}
@@ -362,17 +363,17 @@ print(check_digits)
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
- result = evaluate_python_code(code, {}, state={})
+ result, _ = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self):
code = "a, b = 0, 1\nb"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1
def test_while(self):
code = "i = 0\nwhile i < 3:\n i += 1\ni"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 3
# test infinite loop
@@ -393,7 +394,7 @@ while i < n and house_positions[i] <= loc:
def test_generator(self):
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [1, 4, 9, 16, 25]
def test_boolops(self):
@@ -403,7 +404,7 @@ else:
best_city = "Manhattan"
best_city
"""
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Brooklyn"
@@ -416,7 +417,7 @@ else:
best_city = "Manhattan"
best_city
"""
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Sacramento"
@@ -431,51 +432,51 @@ if char.isalpha():
def test_imports(self):
code = "import math\nmath.sqrt(4)"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0
code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
code = "import time, re\ntime.sleep(0.1)"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 1
code = "import itertools\nlist(itertools.islice(range(10), 3))"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == [0, 1, 2]
code = "import re\nre.search('a', 'abc').group()"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "a"
code = "import stat\nstat.S_ISREG(0o100644)"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.8
code = "import unicodedata\nunicodedata.name('A')"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
@@ -491,25 +492,25 @@ if char.isalpha():
def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 1 < 4 and 0 <= -5 < 4"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 4 < 4 and 0 <= 3 < 4"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 3 < 4 and 0 <= 3 < 4"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result
def test_print_output(self):
code = "print('Hello world!')\nprint('Ok no one cares')"
state = {}
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert result is None
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
@@ -525,7 +526,7 @@ function()"""
def test_tuple_target_in_iterator(self):
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
- result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "Samuel"
def test_classes(self):
@@ -618,7 +619,7 @@ def var_args_method(self, *args, **kwargs):
var_args_method(1, 2, 3, x=4, y=5)
"""
state = {}
- result = evaluate_python_code(code, {"sum": sum}, state=state)
+ result, _ = evaluate_python_code(code, {"sum": sum}, state=state)
assert result == 15
def test_exceptions(self):
@@ -648,7 +649,7 @@ except ValueError as e:
def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
- result = evaluate_python_code(
+ result, is_final_answer = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
assert result is int
@@ -659,7 +660,7 @@ food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
"""
state = {}
- result = evaluate_python_code(code, {}, state=state)
+ result, is_final_answer = evaluate_python_code(code, {}, state=state)
assert result == ["orange", "pear"]
def test_nonsimple_augassign(self):
@@ -742,8 +743,9 @@ def f(a, b=333, n=1000):
return b + n
n = f(1, n=667)
"""
- res = evaluate_python_code(code, {}, {})
+ res, is_final_answer = evaluate_python_code(code, {}, {})
assert res == 1000
+ assert not is_final_answer
def test_set(self):
code = """
@@ -767,8 +769,11 @@ while True:
break
i"""
- result = evaluate_python_code(code, {"print": print, "round": round}, state={})
+ result, is_final_answer = evaluate_python_code(
+ code, {"print": print, "round": round}, state={}
+ )
assert result == 3
+ assert not is_final_answer
def test_return(self):
# test early returns
@@ -781,7 +786,7 @@ def add_one(n, shift):
add_one(1, 1)
"""
state = {}
- result = evaluate_python_code(
+ result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result == 2
@@ -794,7 +799,7 @@ def returns_none(a):
returns_none(1)
"""
state = {}
- result = evaluate_python_code(
+ result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result is None
@@ -812,7 +817,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
- result = evaluate_python_code(
+ result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
@@ -829,7 +834,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
assert np.array_equal(result, [-1, 5])
@@ -842,7 +847,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert np.array_equal(result.values[0], [104, 1])
@@ -855,7 +860,9 @@ data = pd.DataFrame.from_dict([
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
- result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
+ result, _ = evaluate_python_code(
+ code, {}, state={}, authorized_imports=["pandas"]
+ )
assert result.values[1] == 0.5
def test_starred(self):
@@ -877,7 +884,7 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
- result = evaluate_python_code(
+ result, _ = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
assert round(result, 1) == 622395.4
@@ -894,5 +901,42 @@ for worker, (start, end) in shifts.items():
shift_intervals[worker] = end
shift_intervals
"""
- result = evaluate_python_code(code, {"print": print, "map": map}, state={})
+ result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
+
+ def test_fix_final_answer_code(self):
+ test_cases = [
+ (
+ "final_answer = 3.21\nfinal_answer(final_answer)",
+ "final_answer_variable = 3.21\nfinal_answer(final_answer_variable)",
+ ),
+ (
+ "x = final_answer(5)\nfinal_answer = x + 1\nfinal_answer(final_answer)",
+ "x = final_answer(5)\nfinal_answer_variable = x + 1\nfinal_answer(final_answer_variable)",
+ ),
+ (
+ "def func():\n final_answer = 42\n return final_answer(final_answer)",
+ "def func():\n final_answer_variable = 42\n return final_answer(final_answer_variable)",
+ ),
+ (
+ "final_answer(5) # Should not change function calls",
+ "final_answer(5) # Should not change function calls",
+ ),
+ (
+ "obj.final_answer = 5 # Should not change object attributes",
+ "obj.final_answer = 5 # Should not change object attributes",
+ ),
+ (
+ "final_answer=3.21;final_answer(final_answer)",
+ "final_answer_variable=3.21;final_answer(final_answer_variable)",
+ ),
+ ]
+
+ for i, (input_code, expected) in enumerate(test_cases, 1):
+ result = fix_final_answer_code(input_code)
+ assert result == expected, f"""
+ Test case {i} failed:
+ Input: {input_code}
+ Expected: {expected}
+ Got: {result}
+ """
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4bd0f81..1ec6343 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,86 +1,39 @@
-import os
-import shutil
-import tempfile
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import unittest
-from pathlib import Path
+import pytest
+
+from smolagents.utils import parse_code_blob
-def str_to_bool(value) -> int:
- """
- Converts a string representation of truth to `True` (1) or `False` (0).
+class AgentTextTests(unittest.TestCase):
+ def test_parse_code_blob(self):
+ with pytest.raises(ValueError):
+ parse_code_blob("Wrong blob!")
- True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
- """
- value = value.lower()
- if value in ("y", "yes", "t", "true", "on", "1"):
- return 1
- elif value in ("n", "no", "f", "false", "off", "0"):
- return 0
- else:
- raise ValueError(f"invalid truth value {value}")
+ # Parsing markdown with code blobs should work
+ output = parse_code_blob("""
+Here is how to solve the problem:
+Code:
+```py
+import numpy as np
+```
+""")
+ assert output == "import numpy as np"
-
-def get_int_from_env(env_keys, default):
- """Returns the first positive env value found in the `env_keys` list or the default."""
- for e in env_keys:
- val = int(os.environ.get(e, -1))
- if val >= 0:
- return val
- return default
-
-
-def parse_flag_from_env(key, default=False):
- """Returns truthy value for `key` from the env if available else the default."""
- value = os.environ.get(key, str(default))
- return (
- str_to_bool(value) == 1
- ) # As its name indicates `str_to_bool` actually returns an int...
-
-
-_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
-
-
-def skip(test_case):
- "Decorator that skips a test unconditionally"
- return unittest.skip("Test was skipped")(test_case)
-
-
-def slow(test_case):
- """
- Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
- truthy value to run them.
- """
- return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
-
-
-class TempDirTestCase(unittest.TestCase):
- """
- A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
- data at the start of a test, and then destroyes it at the end of the TestCase.
-
- Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
-
- The temporary directory location will be stored in `self.tmpdir`
- """
-
- clear_on_setup = True
-
- @classmethod
- def setUpClass(cls):
- "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`"
- cls.tmpdir = Path(tempfile.mkdtemp())
-
- @classmethod
- def tearDownClass(cls):
- "Remove `cls.tmpdir` after test suite has finished"
- if os.path.exists(cls.tmpdir):
- shutil.rmtree(cls.tmpdir)
-
- def setUp(self):
- "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
- if self.clear_on_setup:
- for path in self.tmpdir.glob("**/*"):
- if path.is_file():
- path.unlink()
- elif path.is_dir():
- shutil.rmtree(path)
+ # Parsing code blobs should work
+ code_blob = "import numpy as np"
+ output = parse_code_blob(code_blob)
+ assert output == code_blob