Make default tools more robust (#186)
This commit is contained in:
parent
12a2e6f4b4
commit
5f32373551
|
@ -36,6 +36,9 @@ jobs:
|
|||
- name: Agent tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_agents.py
|
||||
- name: Default tools tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_default_tools.py
|
||||
- name: Final answer tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_final_answer.py
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -29,8 +29,7 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -173,7 +172,7 @@
|
|||
"[132 rows x 4 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -196,19 +195,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
|
||||
"* 'fields' has been removed\n",
|
||||
" warnings.warn(message, UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import json\n",
|
||||
|
@ -408,100 +397,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"open_model_ids = [\n",
|
||||
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
|
||||
|
@ -554,42 +452,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'gpt-4o'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
|
||||
"100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from smolagents import LiteLLMModel\n",
|
||||
"\n",
|
||||
|
@ -614,7 +479,7 @@
|
|||
" agent = CodeAgent(\n",
|
||||
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
|
||||
" model=LiteLLMModel(model_id),\n",
|
||||
" additional_authorized_imports=[\"numpy\"],\n",
|
||||
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
||||
" max_steps=10,\n",
|
||||
" )\n",
|
||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
||||
|
@ -631,34 +496,39 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import glob\n",
|
||||
"# import json\n",
|
||||
"\n",
|
||||
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
|
||||
"\n",
|
||||
"# for file_path in jsonl_files:\n",
|
||||
"# print(file_path)\n",
|
||||
"# # Read all lines and filter out SimpleQA sources\n",
|
||||
"# filtered_lines = []\n",
|
||||
"# removed = 0\n",
|
||||
"# with open(file_path, 'r', encoding='utf-8') as f:\n",
|
||||
"# for line in f:\n",
|
||||
"# try:\n",
|
||||
"# data = json.loads(line.strip())\n",
|
||||
"# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
|
||||
"# removed +=1\n",
|
||||
"# else:\n",
|
||||
"# filtered_lines.append(line)\n",
|
||||
"# except json.JSONDecodeError:\n",
|
||||
"# print(\"Invalid line:\", line)\n",
|
||||
"# continue # Skip invalid JSON lines\n",
|
||||
"# print(f\"Removed {removed} lines.\")\n",
|
||||
"# # Write filtered content back to the same file\n",
|
||||
"# with open(file_path, 'w', encoding='utf-8') as f:\n",
|
||||
"# f.writelines(filtered_lines)"
|
||||
"# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
|
||||
"# print(file_path)\n",
|
||||
"# # Read all lines and filter out SimpleQA sources\n",
|
||||
"# filtered_lines = []\n",
|
||||
"# removed = 0\n",
|
||||
"# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
|
||||
"# for line in f:\n",
|
||||
"# try:\n",
|
||||
"# data = json.loads(line.strip())\n",
|
||||
"# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
|
||||
"# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
|
||||
"# # removed +=1\n",
|
||||
"# # else:\n",
|
||||
"# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
|
||||
"# except json.JSONDecodeError:\n",
|
||||
"# print(\"Invalid line:\", line)\n",
|
||||
"# continue # Skip invalid JSON lines\n",
|
||||
"# print(f\"Removed {removed} lines.\")\n",
|
||||
"# # Write filtered content back to the same file\n",
|
||||
"# with open(\n",
|
||||
"# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
|
||||
"# ) as f:\n",
|
||||
"# f.writelines(filtered_lines)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -670,14 +540,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
|
@ -731,7 +601,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -752,7 +622,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -794,28 +664,28 @@
|
|||
" <th>1</th>\n",
|
||||
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>74.0</td>\n",
|
||||
" <td>76.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>70.0</td>\n",
|
||||
" <td>88.0</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>18.8</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>76.0</td>\n",
|
||||
" <td>86.0</td>\n",
|
||||
" <td>60.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
|
@ -829,63 +699,63 @@
|
|||
" <th>6</th>\n",
|
||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>40.6</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>50.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>gpt-4o</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>28.1</td>\n",
|
||||
" <td>25.6</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>gpt-4o</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>70.0</td>\n",
|
||||
" <td>58.0</td>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>gpt-4o</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>88.0</td>\n",
|
||||
" <td>86.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>42.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>18.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>54.0</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
|
@ -899,49 +769,49 @@
|
|||
" <th>16</th>\n",
|
||||
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>32.0</td>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" <td>12.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>17</th>\n",
|
||||
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>18</th>\n",
|
||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>34.4</td>\n",
|
||||
" <td>31.2</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>19</th>\n",
|
||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>82.0</td>\n",
|
||||
" <td>72.0</td>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>20</th>\n",
|
||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>84.0</td>\n",
|
||||
" <td>78.0</td>\n",
|
||||
" <td>12.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>21</th>\n",
|
||||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||
" <td>GAIA</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>22</th>\n",
|
||||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||
" <td>MATH</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>22.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
|
@ -949,7 +819,7 @@
|
|||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||
" <td>SimpleQA</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
|
@ -958,29 +828,29 @@
|
|||
"text/plain": [
|
||||
"action_type model_id source code vanilla\n",
|
||||
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
|
||||
"1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
|
||||
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
|
||||
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
|
||||
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
|
||||
"1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
|
||||
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
|
||||
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
|
||||
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
|
||||
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
|
||||
"6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
|
||||
"7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
|
||||
"8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
|
||||
"9 gpt-4o GAIA 28.1 3.1\n",
|
||||
"10 gpt-4o MATH 70.0 40.0\n",
|
||||
"11 gpt-4o SimpleQA 88.0 6.0\n",
|
||||
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
|
||||
"13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
|
||||
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
|
||||
"6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
|
||||
"7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
|
||||
"8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
|
||||
"9 gpt-4o GAIA 25.6 3.1\n",
|
||||
"10 gpt-4o MATH 58.0 40.0\n",
|
||||
"11 gpt-4o SimpleQA 86.0 6.0\n",
|
||||
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
|
||||
"13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
|
||||
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
|
||||
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
|
||||
"16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
|
||||
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
|
||||
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
|
||||
"19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
|
||||
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
|
||||
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
|
||||
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
|
||||
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
|
||||
"16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
|
||||
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
|
||||
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
|
||||
"19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
|
||||
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
|
||||
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
|
||||
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
|
||||
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
@ -1005,6 +875,15 @@
|
|||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
||||
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tools_to_call_from=list(self.tools.values()),
|
||||
stop_sequences=["Observation:"],
|
||||
)
|
||||
|
||||
# Extract tool call from model output
|
||||
if (
|
||||
type(model_message.tool_calls) is list
|
||||
and len(model_message.tool_calls) > 0
|
||||
):
|
||||
tool_calls = model_message.tool_calls[0]
|
||||
tool_arguments = tool_calls.function.arguments
|
||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||
else:
|
||||
start, end = (
|
||||
model_message.content.find("{"),
|
||||
model_message.content.rfind("}") + 1,
|
||||
)
|
||||
tool_calls = json.loads(model_message.content[start:end])
|
||||
tool_arguments = tool_calls["tool_arguments"]
|
||||
tool_name, tool_call_id = (
|
||||
tool_calls["tool_name"],
|
||||
f"call_{len(self.logs)}",
|
||||
)
|
||||
tool_call = model_message.tool_calls[0]
|
||||
tool_name, tool_call_id = tool_call.function.name, tool_call.id
|
||||
tool_arguments = tool_call.function.arguments
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(
|
||||
|
@ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
updated_information = f"Stored '{observation_name}' in memory."
|
||||
else:
|
||||
updated_information = str(observation).strip()
|
||||
self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
|
||||
self.logger.log(
|
||||
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
|
||||
level=LogLevel.INFO,
|
||||
)
|
||||
log_entry.observations = updated_information
|
||||
return None
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ from .local_python_executor import (
|
|||
)
|
||||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||
from .types import AgentAudio
|
||||
from .utils import truncate_content
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.models.whisper import (
|
||||
|
@ -112,18 +113,15 @@ class PythonInterpreterTool(Tool):
|
|||
|
||||
def forward(self, code: str) -> str:
|
||||
state = {}
|
||||
try:
|
||||
output = str(
|
||||
self.python_evaluator(
|
||||
code,
|
||||
state=state,
|
||||
static_tools=self.base_python_tools,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)[0] # The second element is boolean is_final_answer
|
||||
)
|
||||
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
output = str(
|
||||
self.python_evaluator(
|
||||
code,
|
||||
state=state,
|
||||
static_tools=self.base_python_tools,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)[0] # The second element is boolean is_final_answer
|
||||
)
|
||||
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
||||
|
||||
|
||||
class FinalAnswerTool(Tool):
|
||||
|
@ -295,7 +293,7 @@ class VisitWebpageTool(Tool):
|
|||
# Remove multiple line breaks
|
||||
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
|
||||
|
||||
return markdown_content
|
||||
return truncate_content(markdown_content)
|
||||
|
||||
except RequestException as e:
|
||||
return f"Error fetching the webpage: {str(e)}"
|
||||
|
|
|
@ -14,20 +14,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
|
||||
from huggingface_hub import (
|
||||
InferenceClient,
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
)
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
|
@ -58,6 +54,27 @@ if _is_package_available("litellm"):
|
|||
import litellm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessageToolCallDefinition:
|
||||
arguments: Any
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessageToolCall:
|
||||
function: ChatMessageToolCallDefinition
|
||||
id: str
|
||||
type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ChatMessageToolCall]] = None
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
@ -140,6 +157,17 @@ def get_clean_message_list(
|
|||
return final_message_list
|
||||
|
||||
|
||||
def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
|
||||
try:
|
||||
start, end = (
|
||||
possible_dictionary.find("{"),
|
||||
possible_dictionary.rfind("}") + 1,
|
||||
)
|
||||
return json.loads(possible_dictionary[start:end])
|
||||
except Exception:
|
||||
return possible_dictionary
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = None
|
||||
|
@ -157,7 +185,7 @@ class Model:
|
|||
stop_sequences: Optional[List[str]] = None,
|
||||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
) -> ChatCompletionOutputMessage:
|
||||
) -> ChatMessage:
|
||||
"""Process the input messages and return the model's response.
|
||||
|
||||
Parameters:
|
||||
|
@ -228,7 +256,7 @@ class HfApiModel(Model):
|
|||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> ChatCompletionOutputMessage:
|
||||
) -> ChatMessage:
|
||||
"""
|
||||
Gets an LLM output message for the given list of input messages.
|
||||
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
||||
|
@ -329,7 +357,7 @@ class TransformersModel(Model):
|
|||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> ChatCompletionOutputMessage:
|
||||
) -> ChatMessage:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
@ -365,21 +393,21 @@ class TransformersModel(Model):
|
|||
if stop_sequences is not None:
|
||||
output = remove_stop_sequences(output, stop_sequences)
|
||||
if tools_to_call_from is None:
|
||||
return ChatCompletionOutputMessage(role="assistant", content=output)
|
||||
return ChatMessage(role="assistant", content=output)
|
||||
else:
|
||||
if "Action:" in output:
|
||||
output = output.split("Action:", 1)[1].strip()
|
||||
parsed_output = json.loads(output)
|
||||
tool_name = parsed_output.get("tool_name")
|
||||
tool_arguments = parsed_output.get("tool_arguments")
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="".join(random.choices("0123456789", k=5)),
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name=tool_name, arguments=tool_arguments
|
||||
),
|
||||
)
|
||||
|
@ -414,7 +442,7 @@ class LiteLLMModel(Model):
|
|||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> ChatCompletionOutputMessage:
|
||||
) -> ChatMessage:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
@ -485,7 +513,7 @@ class OpenAIServerModel(Model):
|
|||
grammar: Optional[str] = None,
|
||||
max_tokens: int = 1500,
|
||||
tools_to_call_from: Optional[List[Tool]] = None,
|
||||
) -> ChatCompletionOutputMessage:
|
||||
) -> ChatMessage:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
|
|
@ -221,6 +221,16 @@ class Tool:
|
|||
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
|
||||
if not self.is_initialized:
|
||||
self.setup()
|
||||
|
||||
# Handle the arguments might be passed as a single dictionary
|
||||
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
|
||||
potential_kwargs = args[0]
|
||||
|
||||
# If the dictionary keys match our input parameters, convert it to kwargs
|
||||
if all(key in self.inputs for key in potential_kwargs):
|
||||
args = ()
|
||||
kwargs = potential_kwargs
|
||||
|
||||
if sanitize_inputs_outputs:
|
||||
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
||||
outputs = self.forward(*args, **kwargs)
|
||||
|
|
|
@ -30,10 +30,10 @@ from smolagents.agents import (
|
|||
from smolagents.default_tools import PythonInterpreterTool
|
||||
from smolagents.tools import tool
|
||||
from smolagents.types import AgentImage, AgentText
|
||||
from huggingface_hub import (
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
from smolagents.models import (
|
||||
ChatMessage,
|
||||
ChatMessageToolCall,
|
||||
ChatMessageToolCallDefinition,
|
||||
)
|
||||
|
||||
|
||||
|
@ -47,28 +47,28 @@ class FakeToolCallModel:
|
|||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="python_interpreter", arguments={"code": "2*3.6452"}
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments={"answer": "7.2904"}
|
||||
),
|
||||
)
|
||||
|
@ -81,14 +81,14 @@ class FakeToolCallModelImage:
|
|||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||
):
|
||||
if len(messages) < 3:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="fake_image_generation_tool",
|
||||
arguments={"prompt": "An image of a cat"},
|
||||
),
|
||||
|
@ -96,14 +96,14 @@ class FakeToolCallModelImage:
|
|||
],
|
||||
)
|
||||
else:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments="image.png"
|
||||
),
|
||||
)
|
||||
|
@ -114,7 +114,7 @@ class FakeToolCallModelImage:
|
|||
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
|
@ -125,7 +125,7 @@ result = 2**3.6452
|
|||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
|
@ -140,7 +140,7 @@ final_answer(7.2904)
|
|||
def fake_code_model_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
|
@ -154,7 +154,7 @@ print("Ok, calculation done!")
|
|||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
|
@ -169,7 +169,7 @@ final_answer("got an error")
|
|||
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
|
@ -183,7 +183,7 @@ print("Ok, calculation done!")
|
|||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
|
@ -196,7 +196,7 @@ final_answer("got an error")
|
|||
|
||||
|
||||
def fake_code_model_import(messages, stop_sequences=None) -> str:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can answer the question
|
||||
|
@ -212,7 +212,7 @@ final_answer("got an error")
|
|||
def fake_code_functiondef(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's define the function. special_marker
|
||||
|
@ -226,7 +226,7 @@ def moving_average(x, w):
|
|||
""",
|
||||
)
|
||||
else: # We're at step 2
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I can now answer the initial question
|
||||
|
@ -241,7 +241,7 @@ final_answer(res)
|
|||
|
||||
|
||||
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
|
@ -255,7 +255,7 @@ final_answer(result)
|
|||
|
||||
|
||||
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
|
@ -454,14 +454,14 @@ class AgentTests(unittest.TestCase):
|
|||
):
|
||||
if tools_to_call_from is not None:
|
||||
if len(messages) < 3:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="search_agent",
|
||||
arguments="Who is the current US president?",
|
||||
),
|
||||
|
@ -470,14 +470,14 @@ class AgentTests(unittest.TestCase):
|
|||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments="Final report."
|
||||
),
|
||||
)
|
||||
|
@ -485,7 +485,7 @@ class AgentTests(unittest.TestCase):
|
|||
)
|
||||
else:
|
||||
if len(messages) < 3:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's call our search agent.
|
||||
|
@ -497,7 +497,7 @@ result = search_agent("Who is the current US president?")
|
|||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's return the report.
|
||||
|
@ -518,14 +518,14 @@ final_answer("Final report.")
|
|||
stop_sequences=None,
|
||||
grammar=None,
|
||||
):
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer",
|
||||
arguments="Report on the current US president",
|
||||
),
|
||||
|
@ -568,7 +568,7 @@ final_answer("Final report.")
|
|||
|
||||
def test_code_nontrivial_final_answer_works(self):
|
||||
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""Code:
|
||||
```py
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
class DefaultToolTests(unittest.TestCase):
|
||||
def test_visit_webpage(self):
|
||||
arguments = {
|
||||
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
|
||||
}
|
||||
result = VisitWebpageTool()(arguments)
|
||||
assert isinstance(result, str)
|
||||
assert (
|
||||
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
|
||||
) # Proper wikipedia pages have an About
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append(
|
||||
[
|
||||
AGENT_TYPE_MAPPING[_input_type](_input)
|
||||
for _input_type in input_type
|
||||
]
|
||||
)
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_imports_work(self):
|
||||
result = self.tool("import numpy as np")
|
||||
assert "import from numpy is not allowed" not in result.lower()
|
||||
|
||||
def test_unauthorized_imports_fail(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
self.tool("import sympy as sp")
|
||||
assert "sympy" in str(e).lower()
|
|
@ -23,9 +23,9 @@ from smolagents import (
|
|||
stream_to_gradio,
|
||||
)
|
||||
from huggingface_hub import (
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionOutputToolCall,
|
||||
ChatCompletionOutputFunctionDefinition,
|
||||
ChatMessage,
|
||||
ChatMessageToolCall,
|
||||
ChatMessageToolCallDefinition,
|
||||
)
|
||||
|
||||
|
||||
|
@ -36,21 +36,21 @@ class FakeLLMModel:
|
|||
|
||||
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
|
||||
if tools_to_call_from is not None:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionOutputToolCall(
|
||||
ChatMessageToolCall(
|
||||
id="fake_id",
|
||||
type="function",
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments={"answer": "image"}
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
return ChatCompletionOutputMessage(
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Code:
|
||||
|
@ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase):
|
|||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return ChatCompletionOutputMessage(
|
||||
role="assistant", content="Malformed answer"
|
||||
)
|
||||
return ChatMessage(role="assistant", content="Malformed answer")
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
|
|
|
@ -18,15 +18,12 @@ import unittest
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||
from smolagents.local_python_executor import (
|
||||
InterpreterError,
|
||||
evaluate_python_code,
|
||||
fix_final_answer_code,
|
||||
)
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
|
@ -34,47 +31,6 @@ def add_two(x):
|
|||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append(
|
||||
[
|
||||
AGENT_TYPE_MAPPING[_input_type](_input)
|
||||
for _input_type in input_type
|
||||
]
|
||||
)
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
|
|
Loading…
Reference in New Issue