smolagents/examples/benchmark.ipynb

1018 lines
111 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install -e .. sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using the latest cached version of the dataset since m-ric/smolagentsbenchmark couldn't be found on the Hugging Face Hub\n",
"Found the latest cached dataset configuration 'default' at /Users/aymeric/.cache/huggingface/datasets/m-ric___smolagentsbenchmark/default/0.0.0/0ad5fb2293ab185eece723a4ac0e4a7188f71add (last modified on Wed Jan 8 17:50:13 2025).\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>source</th>\n",
" <th>true_answer</th>\n",
" <th>true_reasoning</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>If Eliud Kipchoge could maintain his record-ma...</td>\n",
" <td>GAIA</td>\n",
" <td>17</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>How many studio albums were published by Merce...</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Here's a fun riddle that I think you'll enjoy....</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>My family reunion is this week, and I was assi...</td>\n",
" <td>GAIA</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>In Emily Midkiff's June 2014 article in a jour...</td>\n",
" <td>GAIA</td>\n",
" <td>fluffy</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>127</th>\n",
" <td>What year was the municipality of San Carlos, ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1786</td>\n",
" <td>['https://en.wikipedia.org/wiki/San_Carlos,_An...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>128</th>\n",
" <td>In which year was Maria Elena Walsh named Illu...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1985</td>\n",
" <td>['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>129</th>\n",
" <td>What is the durability of the Istarelle spear ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>800</td>\n",
" <td>['http://demonssouls.wikidot.com/spear', 'http...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>130</th>\n",
" <td>What is the number of the executive order that...</td>\n",
" <td>SimpleQA</td>\n",
" <td>7034</td>\n",
" <td>['https://www.loc.gov/collections/federal-thea...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131</th>\n",
" <td>Within plus or minus one minute, when was Marq...</td>\n",
" <td>SimpleQA</td>\n",
" <td>77</td>\n",
" <td>['https://www.fifa.com/fifaplus/en/match-centr...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>132 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" question source true_answer \\\n",
"0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
"1 How many studio albums were published by Merce... GAIA 3 \n",
"2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
"3 My family reunion is this week, and I was assi... GAIA 2 \n",
"4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
".. ... ... ... \n",
"127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
"128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
"129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
"130 What is the number of the executive order that... SimpleQA 7034 \n",
"131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
"\n",
" true_reasoning \n",
"0 None \n",
"1 None \n",
"2 None \n",
"3 None \n",
"4 None \n",
".. ... \n",
"127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
"128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
"129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
"130 ['https://www.loc.gov/collections/federal-thea... \n",
"131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
"\n",
"[132 rows x 4 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import datasets\n",
"import pandas as pd\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smolagentsbenchmark\")[\"train\"]\n",
"pd.DataFrame(eval_ds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define utilities and tools\n",
"To run the SERPAPI tool, you will need to have a [SerpAPI](https://serpapi.com/dashboard) API key: for this you need a paid account."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import json\n",
"import os\n",
"import re\n",
"import string\n",
"import warnings\n",
"from tqdm import tqdm\n",
"from typing import List\n",
"\n",
"from smolagents import (\n",
" GoogleSearchTool,\n",
" CodeAgent,\n",
" ToolCallingAgent,\n",
" HfApiModel,\n",
" AgentError,\n",
" VisitWebpageTool,\n",
" PythonInterpreterTool,\n",
")\n",
"from smolagents.agents import ActionStep\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"os.makedirs(\"output\", exist_ok=True)\n",
"\n",
"\n",
"def serialize_agent_error(obj):\n",
" if isinstance(obj, AgentError):\n",
" return {\"error_type\": obj.__class__.__name__, \"message\": obj.message}\n",
" else:\n",
" return str(obj)\n",
"\n",
"\n",
"def answer_questions(\n",
" eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n",
"):\n",
" answered_questions = []\n",
" if os.path.exists(file_name):\n",
" with open(file_name, \"r\") as f:\n",
" for line in f:\n",
" answered_questions.append(json.loads(line)[\"question\"])\n",
"\n",
" for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n",
" try:\n",
" question = example[\"question\"]\n",
" if example[\"source\"] == \"SimpleQA\":\n",
" question += \" Answer with only the final number.\"\n",
" if example[\"source\"] == \"MATH\":\n",
" question += \" Write code, not latex.\"\n",
" if question in answered_questions:\n",
" continue\n",
" start_time = time.time()\n",
"\n",
" if is_vanilla_llm:\n",
" llm = agent\n",
" answer = llm([{\"role\": \"user\", \"content\": question}])\n",
" token_count = llm.last_input_token_count + llm.last_output_token_count\n",
" intermediate_steps = []\n",
" else:\n",
" answer = agent.run(question)\n",
" token_count = agent.monitor.get_total_token_counts()\n",
" intermediate_steps = str(agent.logs)\n",
" # Remove memory from logs to make them more compact.\n",
" for step in agent.logs:\n",
" if isinstance(step, ActionStep):\n",
" step.agent_memory = None\n",
"\n",
" end_time = time.time()\n",
" annotated_example = {\n",
" \"model_id\": model_id,\n",
" \"agent_action_type\": action_type,\n",
" \"question\": question,\n",
" \"answer\": answer,\n",
" \"true_answer\": example[\"true_answer\"],\n",
" \"source\": example[\"source\"],\n",
" \"intermediate_steps\": intermediate_steps,\n",
" \"start_time\": start_time,\n",
" \"end_time\": end_time,\n",
" \"token_counts\": token_count,\n",
" }\n",
"\n",
" with open(file_name, \"a\") as f:\n",
" json.dump(annotated_example, f, default=serialize_agent_error)\n",
" f.write(\"\\n\") # add a newline for JSONL format\n",
" except Exception as e:\n",
" print(\"Failed:\", e)\n",
"\n",
"\n",
"def normalize_number_str(number_str: str) -> float:\n",
" # we replace these common units and commas to allow\n",
" # conversion to float\n",
" for char in [\"$\", \"%\", \",\"]:\n",
" number_str = number_str.replace(char, \"\")\n",
" try:\n",
" return float(number_str)\n",
" except ValueError:\n",
" return float(\"inf\")\n",
"\n",
"\n",
"def split_string(\n",
" s: str,\n",
" char_list: list[str] = [\",\", \";\"],\n",
") -> list[str]:\n",
" pattern = f\"[{''.join(char_list)}]\"\n",
" return re.split(pattern, s)\n",
"\n",
"\n",
"def is_float(element: any) -> bool:\n",
" try:\n",
" float(element)\n",
" return True\n",
" except ValueError:\n",
" return False\n",
"\n",
"\n",
"def normalize_str(input_str, remove_punct=True) -> str:\n",
" \"\"\"\n",
" Normalize a string by:\n",
" - Removing all white spaces\n",
" - Optionally removing punctuation (if remove_punct is True)\n",
" - Converting to lowercase\n",
" Parameters:\n",
" - input_str: str, the string to normalize\n",
" - remove_punct: bool, whether to remove punctuation (default: True)\n",
" Returns:\n",
" - str, the normalized string\n",
" \"\"\"\n",
" # Remove all white spaces. Required e.g for seagull vs. sea gull\n",
" no_spaces = re.sub(r\"\\s\", \"\", input_str)\n",
"\n",
" # Remove punctuation, if specified.\n",
" if remove_punct:\n",
" translator = str.maketrans(\"\", \"\", string.punctuation)\n",
" return no_spaces.lower().translate(translator)\n",
" else:\n",
" return no_spaces.lower()\n",
"\n",
"\n",
"def extract_numbers(text: str) -> List[str]:\n",
" \"\"\"This pattern matches:\n",
" - Optional negative sign\n",
" - Numbers with optional comma thousand separators\n",
" - Optional decimal points with decimal numbers\n",
" \"\"\"\n",
" pattern = r\"-?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?\"\n",
"\n",
" return [el.replace(\",\", \"\") for el in re.findall(pattern, text)]\n",
"\n",
"\n",
"def get_question_score_gaia(\n",
" model_answer: str,\n",
" ground_truth: str,\n",
") -> bool:\n",
" if is_float(ground_truth):\n",
" normalized_answer = normalize_number_str(str(model_answer))\n",
" return normalized_answer == float(ground_truth)\n",
"\n",
" elif any(char in ground_truth for char in [\",\", \";\"]): # if gt is a list\n",
" # question with the fish: normalization removes punct\n",
" gt_elems = split_string(ground_truth)\n",
" ma_elems = split_string(model_answer)\n",
"\n",
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
" warnings.warn(\n",
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
" )\n",
" return False\n",
"\n",
" comparisons = []\n",
" for ma_elem, gt_elem in zip(\n",
" ma_elems, gt_elems\n",
" ): # compare each element as float or str\n",
" if is_float(gt_elem):\n",
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
" else:\n",
" # we do not remove punct since comparisons can include punct\n",
" comparisons.append(\n",
" normalize_str(ma_elem, remove_punct=False)\n",
" == normalize_str(gt_elem, remove_punct=False)\n",
" )\n",
" return all(comparisons)\n",
"\n",
" else: # if gt is a str\n",
" return normalize_str(model_answer) == normalize_str(ground_truth)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark agents\n",
"\n",
"### Open models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
" # \"Qwen/QwQ-32B-Preview\",\n",
" \"Qwen/Qwen2.5-72B-Instruct\",\n",
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
" \"meta-llama/Llama-3.1-8B-Instruct\",\n",
" \"mistralai/Mistral-Nemo-Instruct-2407\",\n",
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
"\n",
"for model_id in open_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
" # action_type = \"tool_calling\"\n",
" # agent = ToolCallingAgent(\n",
" # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
" # model=HfApiModel(model_id),\n",
" # max_steps=10,\n",
" # )\n",
" # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=HfApiModel(model_id),\n",
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n",
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Closed models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n",
"for model_id in litellm_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
" action_type = \"tool_calling\"\n",
" agent = ToolCallingAgent(\n",
" tools=[\n",
" GoogleSearchTool(),\n",
" VisitWebpageTool(),\n",
" PythonInterpreterTool([\"numpy\", \"sympy\"]),\n",
" ],\n",
" model=LiteLLMModel(model_id),\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=LiteLLMModel(model_id),\n",
" additional_authorized_imports=[\"numpy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n",
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# import glob\n",
"# import json\n",
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
"\n",
"# for file_path in jsonl_files:\n",
"# print(file_path)\n",
"# # Read all lines and filter out SimpleQA sources\n",
"# filtered_lines = []\n",
"# removed = 0\n",
"# with open(file_path, 'r', encoding='utf-8') as f:\n",
"# for line in f:\n",
"# try:\n",
"# data = json.loads(line.strip())\n",
"# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
"# removed +=1\n",
"# else:\n",
"# filtered_lines.append(line)\n",
"# except json.JSONDecodeError:\n",
"# print(\"Invalid line:\", line)\n",
"# continue # Skip invalid JSON lines\n",
"# print(f\"Removed {removed} lines.\")\n",
"# # Write filtered content back to the same file\n",
"# with open(file_path, 'w', encoding='utf-8') as f:\n",
"# f.writelines(filtered_lines)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Score answers"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_74415/3026956094.py:163: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import pandas as pd\n",
"import glob\n",
"\n",
"res = []\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
" smoldf = pd.read_json(file_path, lines=True)\n",
" smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
" res.append(smoldf)\n",
"result_df = pd.concat(res)\n",
"\n",
"\n",
"def get_correct(row):\n",
" if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
" if len(numbers_answer) == 0:\n",
" return False\n",
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
" else:\n",
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
"\n",
"\n",
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
"\n",
"result_df = (\n",
" (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n",
" .round(1)\n",
" .reset_index()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"pivot_df = result_df.pivot_table(\n",
" index=[\"model_id\", \"source\"],\n",
" columns=[\"action_type\"],\n",
" values=\"correct\",\n",
" fill_value=float(\"nan\"),\n",
").reset_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Display results"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>action_type</th>\n",
" <th>model_id</th>\n",
" <th>source</th>\n",
" <th>code</th>\n",
" <th>vanilla</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>28.1</td>\n",
" <td>6.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>74.0</td>\n",
" <td>31.9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>70.0</td>\n",
" <td>10.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>18.8</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>76.0</td>\n",
" <td>60.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>86.0</td>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>GAIA</td>\n",
" <td>40.6</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>MATH</td>\n",
" <td>67.0</td>\n",
" <td>50.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>SimpleQA</td>\n",
" <td>90.0</td>\n",
" <td>34.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>gpt-4o</td>\n",
" <td>GAIA</td>\n",
" <td>28.1</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>gpt-4o</td>\n",
" <td>MATH</td>\n",
" <td>70.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>gpt-4o</td>\n",
" <td>SimpleQA</td>\n",
" <td>88.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>42.0</td>\n",
" <td>18.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>54.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>3.1</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>32.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>34.4</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>82.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>84.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"action_type model_id source code vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
"1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 31.9\n",
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
"6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
"7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
"8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
"9 gpt-4o GAIA 28.1 3.1\n",
"10 gpt-4o MATH 70.0 40.0\n",
"11 gpt-4o SimpleQA 88.0 6.0\n",
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
"13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
"16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
"19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(pivot_df)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
"\n",
"# Assuming pivot_df is your original dataframe\n",
"models = pivot_df[\"model_id\"].unique()\n",
"sources = pivot_df[\"source\"].unique()\n",
"\n",
"# Create figure and axis\n",
"plt.style.use(\"seaborn-v0_8-white\")\n",
"fig, ax = plt.subplots(figsize=(15, 6))\n",
"\n",
"# Set the width of each bar group and positions of the bars\n",
"width = 0.15 # width of each bar\n",
"spacing = 0.02 # space between bars within a group\n",
"group_spacing = 0.2 # space between model groups\n",
"\n",
"# Calculate positions for the bars\n",
"num_sources = len(sources)\n",
"total_width_per_group = (width + spacing) * num_sources * 2 # *2 for agent and vanilla\n",
"x = np.arange(len(models)) * (total_width_per_group + group_spacing)\n",
"\n",
"# Plot bars for each source\n",
"for i, source in enumerate(sources):\n",
" source_data = pivot_df[pivot_df[\"source\"] == source]\n",
" agent_scores = [\n",
" source_data[source_data[\"model_id\"] == model][\"code\"].values[0]\n",
" if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
" else np.nan\n",
" for model in models\n",
" ]\n",
" vanilla_scores = [\n",
" source_data[source_data[\"model_id\"] == model][\"vanilla\"].values[0]\n",
" if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
" else np.nan\n",
" for model in models\n",
" ]\n",
"\n",
" # Position calculation for each pair of bars\n",
" pos = x + i * (width * 2 + spacing)\n",
"\n",
" agent_bars = ax.bar(pos, agent_scores, width, label=f\"{source} (Agent)\", alpha=0.8)\n",
" vanilla_bars = ax.bar(\n",
" pos + width * 0.6,\n",
" vanilla_scores,\n",
" width,\n",
" hatch=\"////\",\n",
" alpha=0.5,\n",
" hatch_linewidth=2,\n",
" label=f\"{source} (Vanilla)\",\n",
" color=\"white\",\n",
" edgecolor=agent_bars[0].get_facecolor(),\n",
" )\n",
"\n",
"# Customize the plot\n",
"ax.set_ylabel(\"Score\")\n",
"ax.set_title(\"Model Performance Comparison\")\n",
"\n",
"# Set x-axis ticks in the middle of each group\n",
"group_centers = x + (total_width_per_group - spacing) / 2\n",
"ax.set_xticks(group_centers)\n",
"\n",
"# Wrap long model names to prevent overlap\n",
"wrapped_labels = [\"\\n\".join(model.split(\"/\")) for model in models]\n",
"ax.set_xticklabels(wrapped_labels, rotation=0, ha=\"center\")\n",
"\n",
"# Modify legend to combine agent and vanilla entries\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"unique_sources = sources\n",
"legend_elements = [\n",
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n",
" for i in range(len(unique_sources))\n",
"]\n",
"custom_legend = ax.legend(\n",
" [\n",
" (agent_handle, vanilla_handle)\n",
" for agent_handle, vanilla_handle, _ in legend_elements\n",
" ],\n",
" [label for _, _, label in legend_elements],\n",
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
" bbox_to_anchor=(1.05, 1),\n",
" loc=\"upper left\",\n",
")\n",
"\n",
"ax.yaxis.grid(True, linestyle=\"--\", alpha=0.3)\n",
"ax.set_ylim(bottom=0)\n",
"plt.tight_layout()\n",
"ax.spines[\"top\"].set_visible(False)\n",
"ax.spines[\"right\"].set_visible(False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'formatted_df' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 45\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mathjax_table\n\u001b[1;32m 44\u001b[0m \u001b[38;5;66;03m# Usage (after running your previous data processing code):\u001b[39;00m\n\u001b[0;32m---> 45\u001b[0m mathjax_table \u001b[38;5;241m=\u001b[39m create_mathjax_table(pivot_df, \u001b[43mformatted_df\u001b[49m)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28mprint\u001b[39m(mathjax_table)\n",
"\u001b[0;31mNameError\u001b[0m: name 'formatted_df' is not defined"
]
}
],
"source": [
"def create_mathjax_table(pivot_df, formatted_df):\n",
" # Start the matrix environment with 4 columns\n",
" # l for left-aligned model and task, c for centered numbers\n",
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
" mathjax_table += (\n",
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" )\n",
" mathjax_table += \"\\\\hline\\n\"\n",
"\n",
" # Sort the DataFrame by model_id and source\n",
" formatted_df = formatted_df.sort_values([\"model_id\", \"source\"])\n",
"\n",
" current_model = None\n",
" for _, row in formatted_df.iterrows():\n",
" model = row[\"model_id\"]\n",
" source = row[\"source\"]\n",
"\n",
" # Add a horizontal line between different models\n",
" if current_model is not None and current_model != model:\n",
" mathjax_table += \"\\\\hline\\n\"\n",
"\n",
" # Format model name\n",
" model_display = model.replace(\"_\", \"\\\\_\")\n",
" if \"Qwen\" in model or \"anthropic\" in model:\n",
" model_display = f\"\\\\textit{{{model_display}}}\"\n",
"\n",
" # If it's the same model as previous row, use empty space\n",
" if current_model == model:\n",
" model_display = \"\\\\;\"\n",
"\n",
" # Add the data row\n",
" mathjax_table += (\n",
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
" )\n",
"\n",
" current_model = model\n",
"\n",
" mathjax_table += \"\\\\hline\\n\"\n",
" mathjax_table += \"\\\\end{array}\"\n",
"\n",
" return mathjax_table\n",
"\n",
"\n",
"# Usage (after running your previous data processing code):\n",
"mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
"print(mathjax_table)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "compare-agents",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}