Make default tools more robust (#186)
This commit is contained in:
		
							parent
							
								
									12a2e6f4b4
								
							
						
					
					
						commit
						5f32373551
					
				|  | @ -36,6 +36,9 @@ jobs: | |||
|       - name: Agent tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_agents.py | ||||
|       - name: Default tools tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_default_tools.py | ||||
|       - name: Final answer tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_final_answer.py | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|  | @ -29,8 +29,7 @@ | |||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||||
|       "  from .autonotebook import tqdm as notebook_tqdm\n", | ||||
|       "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n" | ||||
|       "  from .autonotebook import tqdm as notebook_tqdm\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|  | @ -173,7 +172,7 @@ | |||
|        "[132 rows x 4 columns]" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 5, | ||||
|      "execution_count": 4, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|  | @ -196,19 +195,9 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n", | ||||
|       "* 'fields' has been removed\n", | ||||
|       "  warnings.warn(message, UserWarning)\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import time\n", | ||||
|     "import json\n", | ||||
|  | @ -408,100 +397,9 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "open_model_ids = [\n", | ||||
|     "    \"meta-llama/Llama-3.3-70B-Instruct\",\n", | ||||
|  | @ -554,42 +452,9 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'gpt-4o'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n", | ||||
|       "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from smolagents import LiteLLMModel\n", | ||||
|     "\n", | ||||
|  | @ -614,7 +479,7 @@ | |||
|     "    agent = CodeAgent(\n", | ||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", | ||||
|     "        model=LiteLLMModel(model_id),\n", | ||||
|     "        additional_authorized_imports=[\"numpy\"],\n", | ||||
|     "        additional_authorized_imports=[\"numpy\", \"sympy\"],\n", | ||||
|     "        max_steps=10,\n", | ||||
|     "    )\n", | ||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||
|  | @ -631,34 +496,39 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 9, | ||||
|    "execution_count": 8, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# import glob\n", | ||||
|     "# import json\n", | ||||
|     "\n", | ||||
|     "# jsonl_files = glob.glob(f\"output/*.jsonl\")\n", | ||||
|     "\n", | ||||
|     "# for file_path in jsonl_files:\n", | ||||
|     "#     print(file_path)\n", | ||||
|     "#     # Read all lines and filter out SimpleQA sources\n", | ||||
|     "#     filtered_lines = []\n", | ||||
|     "#     removed = 0\n", | ||||
|     "#     with open(file_path, 'r', encoding='utf-8') as f:\n", | ||||
|     "#         for line in f:\n", | ||||
|     "#             try:\n", | ||||
|     "#                 data = json.loads(line.strip())\n", | ||||
|     "#                 if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", | ||||
|     "#                     removed +=1\n", | ||||
|     "#                 else:\n", | ||||
|     "#                     filtered_lines.append(line)\n", | ||||
|     "#             except json.JSONDecodeError:\n", | ||||
|     "#                 print(\"Invalid line:\", line)\n", | ||||
|     "#                 continue  # Skip invalid JSON lines\n", | ||||
|     "#     print(f\"Removed {removed} lines.\")\n", | ||||
|     "#     # Write filtered content back to the same file\n", | ||||
|     "#     with open(file_path, 'w', encoding='utf-8') as f:\n", | ||||
|     "#         f.writelines(filtered_lines)" | ||||
|     "#     if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n", | ||||
|     "#         print(file_path)\n", | ||||
|     "#         # Read all lines and filter out SimpleQA sources\n", | ||||
|     "#         filtered_lines = []\n", | ||||
|     "#         removed = 0\n", | ||||
|     "#         with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", | ||||
|     "#             for line in f:\n", | ||||
|     "#                 try:\n", | ||||
|     "#                     data = json.loads(line.strip())\n", | ||||
|     "#                     data[\"answer\"] = data[\"answer\"][\"content\"]\n", | ||||
|     "#                     # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", | ||||
|     "#                     #     removed +=1\n", | ||||
|     "#                     # else:\n", | ||||
|     "#                     filtered_lines.append(json.dumps(data) + \"\\n\")\n", | ||||
|     "#                 except json.JSONDecodeError:\n", | ||||
|     "#                     print(\"Invalid line:\", line)\n", | ||||
|     "#                     continue  # Skip invalid JSON lines\n", | ||||
|     "#         print(f\"Removed {removed} lines.\")\n", | ||||
|     "#         # Write filtered content back to the same file\n", | ||||
|     "#         with open(\n", | ||||
|     "#             str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n", | ||||
|     "#         ) as f:\n", | ||||
|     "#             f.writelines(filtered_lines)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|  | @ -670,14 +540,14 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 9, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", | ||||
|       "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", | ||||
|       "  warnings.warn(\n" | ||||
|      ] | ||||
|     } | ||||
|  | @ -731,7 +601,7 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 14, | ||||
|    "execution_count": 10, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|  | @ -752,7 +622,7 @@ | |||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 15, | ||||
|    "execution_count": 11, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|  | @ -794,28 +664,28 @@ | |||
|        "      <th>1</th>\n", | ||||
|        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>74.0</td>\n", | ||||
|        "      <td>76.0</td>\n", | ||||
|        "      <td>30.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>2</th>\n", | ||||
|        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>70.0</td>\n", | ||||
|        "      <td>88.0</td>\n", | ||||
|        "      <td>10.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>3</th>\n", | ||||
|        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>18.8</td>\n", | ||||
|        "      <td>25.0</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>4</th>\n", | ||||
|        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>76.0</td>\n", | ||||
|        "      <td>86.0</td>\n", | ||||
|        "      <td>60.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|  | @ -829,63 +699,63 @@ | |||
|        "      <th>6</th>\n", | ||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>40.6</td>\n", | ||||
|        "      <td>NaN</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>7</th>\n", | ||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>67.0</td>\n", | ||||
|        "      <td>NaN</td>\n", | ||||
|        "      <td>50.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>8</th>\n", | ||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>90.0</td>\n", | ||||
|        "      <td>NaN</td>\n", | ||||
|        "      <td>34.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>9</th>\n", | ||||
|        "      <td>gpt-4o</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>28.1</td>\n", | ||||
|        "      <td>25.6</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>10</th>\n", | ||||
|        "      <td>gpt-4o</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>70.0</td>\n", | ||||
|        "      <td>58.0</td>\n", | ||||
|        "      <td>40.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>11</th>\n", | ||||
|        "      <td>gpt-4o</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>88.0</td>\n", | ||||
|        "      <td>86.0</td>\n", | ||||
|        "      <td>6.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>12</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>0.0</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "      <td>0.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>13</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>42.0</td>\n", | ||||
|        "      <td>14.0</td>\n", | ||||
|        "      <td>18.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>14</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>54.0</td>\n", | ||||
|        "      <td>2.0</td>\n", | ||||
|        "      <td>6.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|  | @ -899,49 +769,49 @@ | |||
|        "      <th>16</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>32.0</td>\n", | ||||
|        "      <td>40.0</td>\n", | ||||
|        "      <td>12.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>17</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>4.0</td>\n", | ||||
|        "      <td>20.0</td>\n", | ||||
|        "      <td>0.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>18</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>34.4</td>\n", | ||||
|        "      <td>31.2</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>19</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>82.0</td>\n", | ||||
|        "      <td>72.0</td>\n", | ||||
|        "      <td>40.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>20</th>\n", | ||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>84.0</td>\n", | ||||
|        "      <td>78.0</td>\n", | ||||
|        "      <td>12.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>21</th>\n", | ||||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||
|        "      <td>GAIA</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "      <td>0.0</td>\n", | ||||
|        "      <td>3.1</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|        "      <th>22</th>\n", | ||||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||
|        "      <td>MATH</td>\n", | ||||
|        "      <td>20.0</td>\n", | ||||
|        "      <td>30.0</td>\n", | ||||
|        "      <td>22.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "    <tr>\n", | ||||
|  | @ -949,7 +819,7 @@ | |||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||
|        "      <td>SimpleQA</td>\n", | ||||
|        "      <td>30.0</td>\n", | ||||
|        "      <td>0.0</td>\n", | ||||
|        "      <td>6.0</td>\n", | ||||
|        "    </tr>\n", | ||||
|        "  </tbody>\n", | ||||
|        "</table>\n", | ||||
|  | @ -958,29 +828,29 @@ | |||
|       "text/plain": [ | ||||
|        "action_type                              model_id    source  code  vanilla\n", | ||||
|        "0                       Qwen/Qwen2.5-72B-Instruct      GAIA  28.1      6.2\n", | ||||
|        "1                       Qwen/Qwen2.5-72B-Instruct      MATH  74.0     30.0\n", | ||||
|        "2                       Qwen/Qwen2.5-72B-Instruct  SimpleQA  70.0     10.0\n", | ||||
|        "3                 Qwen/Qwen2.5-Coder-32B-Instruct      GAIA  18.8      3.1\n", | ||||
|        "4                 Qwen/Qwen2.5-Coder-32B-Instruct      MATH  76.0     60.0\n", | ||||
|        "1                       Qwen/Qwen2.5-72B-Instruct      MATH  76.0     30.0\n", | ||||
|        "2                       Qwen/Qwen2.5-72B-Instruct  SimpleQA  88.0     10.0\n", | ||||
|        "3                 Qwen/Qwen2.5-Coder-32B-Instruct      GAIA  25.0      3.1\n", | ||||
|        "4                 Qwen/Qwen2.5-Coder-32B-Instruct      MATH  86.0     60.0\n", | ||||
|        "5                 Qwen/Qwen2.5-Coder-32B-Instruct  SimpleQA  86.0      8.0\n", | ||||
|        "6              anthropic/claude-3-5-sonnet-latest      GAIA  40.6      3.1\n", | ||||
|        "7              anthropic/claude-3-5-sonnet-latest      MATH  67.0     50.0\n", | ||||
|        "8              anthropic/claude-3-5-sonnet-latest  SimpleQA  90.0     34.0\n", | ||||
|        "9                                          gpt-4o      GAIA  28.1      3.1\n", | ||||
|        "10                                         gpt-4o      MATH  70.0     40.0\n", | ||||
|        "11                                         gpt-4o  SimpleQA  88.0      6.0\n", | ||||
|        "12               meta-llama/Llama-3.1-8B-Instruct      GAIA   0.0      0.0\n", | ||||
|        "13               meta-llama/Llama-3.1-8B-Instruct      MATH  42.0     18.0\n", | ||||
|        "14               meta-llama/Llama-3.1-8B-Instruct  SimpleQA  54.0      6.0\n", | ||||
|        "6              anthropic/claude-3-5-sonnet-latest      GAIA   NaN      3.1\n", | ||||
|        "7              anthropic/claude-3-5-sonnet-latest      MATH   NaN     50.0\n", | ||||
|        "8              anthropic/claude-3-5-sonnet-latest  SimpleQA   NaN     34.0\n", | ||||
|        "9                                          gpt-4o      GAIA  25.6      3.1\n", | ||||
|        "10                                         gpt-4o      MATH  58.0     40.0\n", | ||||
|        "11                                         gpt-4o  SimpleQA  86.0      6.0\n", | ||||
|        "12               meta-llama/Llama-3.1-8B-Instruct      GAIA   3.1      0.0\n", | ||||
|        "13               meta-llama/Llama-3.1-8B-Instruct      MATH  14.0     18.0\n", | ||||
|        "14               meta-llama/Llama-3.1-8B-Instruct  SimpleQA   2.0      6.0\n", | ||||
|        "15               meta-llama/Llama-3.2-3B-Instruct      GAIA   3.1      0.0\n", | ||||
|        "16               meta-llama/Llama-3.2-3B-Instruct      MATH  32.0     12.0\n", | ||||
|        "17               meta-llama/Llama-3.2-3B-Instruct  SimpleQA   4.0      0.0\n", | ||||
|        "18              meta-llama/Llama-3.3-70B-Instruct      GAIA  34.4      3.1\n", | ||||
|        "19              meta-llama/Llama-3.3-70B-Instruct      MATH  82.0     40.0\n", | ||||
|        "20              meta-llama/Llama-3.3-70B-Instruct  SimpleQA  84.0     12.0\n", | ||||
|        "21           mistralai/Mistral-Nemo-Instruct-2407      GAIA   3.1      0.0\n", | ||||
|        "22           mistralai/Mistral-Nemo-Instruct-2407      MATH  20.0     22.0\n", | ||||
|        "23           mistralai/Mistral-Nemo-Instruct-2407  SimpleQA  30.0      0.0" | ||||
|        "16               meta-llama/Llama-3.2-3B-Instruct      MATH  40.0     12.0\n", | ||||
|        "17               meta-llama/Llama-3.2-3B-Instruct  SimpleQA  20.0      0.0\n", | ||||
|        "18              meta-llama/Llama-3.3-70B-Instruct      GAIA  31.2      3.1\n", | ||||
|        "19              meta-llama/Llama-3.3-70B-Instruct      MATH  72.0     40.0\n", | ||||
|        "20              meta-llama/Llama-3.3-70B-Instruct  SimpleQA  78.0     12.0\n", | ||||
|        "21           mistralai/Mistral-Nemo-Instruct-2407      GAIA   0.0      3.1\n", | ||||
|        "22           mistralai/Mistral-Nemo-Instruct-2407      MATH  30.0     22.0\n", | ||||
|        "23           mistralai/Mistral-Nemo-Instruct-2407  SimpleQA  30.0      6.0" | ||||
|       ] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|  | @ -1005,6 +875,15 @@ | |||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "ename": "", | ||||
|      "evalue": "", | ||||
|      "output_type": "error", | ||||
|      "traceback": [ | ||||
|       "\u001b[1;31mnotebook controller is DISPOSED. \n", | ||||
|       "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details." | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|  |  | |||
|  | @ -15,7 +15,6 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import time | ||||
| import json | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||
| 
 | ||||
|  | @ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent): | |||
|                 tools_to_call_from=list(self.tools.values()), | ||||
|                 stop_sequences=["Observation:"], | ||||
|             ) | ||||
| 
 | ||||
|             # Extract tool call from model output | ||||
|             if ( | ||||
|                 type(model_message.tool_calls) is list | ||||
|                 and len(model_message.tool_calls) > 0 | ||||
|             ): | ||||
|                 tool_calls = model_message.tool_calls[0] | ||||
|                 tool_arguments = tool_calls.function.arguments | ||||
|                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id | ||||
|             else: | ||||
|                 start, end = ( | ||||
|                     model_message.content.find("{"), | ||||
|                     model_message.content.rfind("}") + 1, | ||||
|                 ) | ||||
|                 tool_calls = json.loads(model_message.content[start:end]) | ||||
|                 tool_arguments = tool_calls["tool_arguments"] | ||||
|                 tool_name, tool_call_id = ( | ||||
|                     tool_calls["tool_name"], | ||||
|                     f"call_{len(self.logs)}", | ||||
|                 ) | ||||
|             tool_call = model_message.tool_calls[0] | ||||
|             tool_name, tool_call_id = tool_call.function.name, tool_call.id | ||||
|             tool_arguments = tool_call.function.arguments | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             raise AgentGenerationError( | ||||
|  | @ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent): | |||
|                 updated_information = f"Stored '{observation_name}' in memory." | ||||
|             else: | ||||
|                 updated_information = str(observation).strip() | ||||
|             self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO) | ||||
|             self.logger.log( | ||||
|                 f"Observations: {updated_information.replace('[', '|')}",  # escape potential rich-tag-like components | ||||
|                 level=LogLevel.INFO, | ||||
|             ) | ||||
|             log_entry.observations = updated_information | ||||
|             return None | ||||
| 
 | ||||
|  |  | |||
|  | @ -31,6 +31,7 @@ from .local_python_executor import ( | |||
| ) | ||||
| from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | ||||
| from .types import AgentAudio | ||||
| from .utils import truncate_content | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|     from transformers.models.whisper import ( | ||||
|  | @ -112,18 +113,15 @@ class PythonInterpreterTool(Tool): | |||
| 
 | ||||
|     def forward(self, code: str) -> str: | ||||
|         state = {} | ||||
|         try: | ||||
|             output = str( | ||||
|                 self.python_evaluator( | ||||
|                     code, | ||||
|                     state=state, | ||||
|                     static_tools=self.base_python_tools, | ||||
|                     authorized_imports=self.authorized_imports, | ||||
|                 )[0]  # The second element is boolean is_final_answer | ||||
|             ) | ||||
|             return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" | ||||
|         except Exception as e: | ||||
|             return f"Error: {str(e)}" | ||||
|         output = str( | ||||
|             self.python_evaluator( | ||||
|                 code, | ||||
|                 state=state, | ||||
|                 static_tools=self.base_python_tools, | ||||
|                 authorized_imports=self.authorized_imports, | ||||
|             )[0]  # The second element is boolean is_final_answer | ||||
|         ) | ||||
|         return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" | ||||
| 
 | ||||
| 
 | ||||
| class FinalAnswerTool(Tool): | ||||
|  | @ -295,7 +293,7 @@ class VisitWebpageTool(Tool): | |||
|             # Remove multiple line breaks | ||||
|             markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) | ||||
| 
 | ||||
|             return markdown_content | ||||
|             return truncate_content(markdown_content) | ||||
| 
 | ||||
|         except RequestException as e: | ||||
|             return f"Error fetching the webpage: {str(e)}" | ||||
|  |  | |||
|  | @ -14,20 +14,16 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from dataclasses import dataclass | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| import random | ||||
| from copy import deepcopy | ||||
| from enum import Enum | ||||
| from typing import Dict, List, Optional | ||||
| from typing import Dict, List, Optional, Union, Any | ||||
| 
 | ||||
| from huggingface_hub import ( | ||||
|     InferenceClient, | ||||
|     ChatCompletionOutputMessage, | ||||
|     ChatCompletionOutputToolCall, | ||||
|     ChatCompletionOutputFunctionDefinition, | ||||
| ) | ||||
| from huggingface_hub import InferenceClient | ||||
| 
 | ||||
| from transformers import ( | ||||
|     AutoModelForCausalLM, | ||||
|  | @ -58,6 +54,27 @@ if _is_package_available("litellm"): | |||
|     import litellm | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessageToolCallDefinition: | ||||
|     arguments: Any | ||||
|     name: str | ||||
|     description: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessageToolCall: | ||||
|     function: ChatMessageToolCallDefinition | ||||
|     id: str | ||||
|     type: str | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessage: | ||||
|     role: str | ||||
|     content: Optional[str] = None | ||||
|     tool_calls: Optional[List[ChatMessageToolCall]] = None | ||||
| 
 | ||||
| 
 | ||||
| class MessageRole(str, Enum): | ||||
|     USER = "user" | ||||
|     ASSISTANT = "assistant" | ||||
|  | @ -140,6 +157,17 @@ def get_clean_message_list( | |||
|     return final_message_list | ||||
| 
 | ||||
| 
 | ||||
| def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: | ||||
|     try: | ||||
|         start, end = ( | ||||
|             possible_dictionary.find("{"), | ||||
|             possible_dictionary.rfind("}") + 1, | ||||
|         ) | ||||
|         return json.loads(possible_dictionary[start:end]) | ||||
|     except Exception: | ||||
|         return possible_dictionary | ||||
| 
 | ||||
| 
 | ||||
| class Model: | ||||
|     def __init__(self): | ||||
|         self.last_input_token_count = None | ||||
|  | @ -157,7 +185,7 @@ class Model: | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|     ) -> ChatMessage: | ||||
|         """Process the input messages and return the model's response. | ||||
| 
 | ||||
|         Parameters: | ||||
|  | @ -228,7 +256,7 @@ class HfApiModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|     ) -> ChatMessage: | ||||
|         """ | ||||
|         Gets an LLM output message for the given list of input messages. | ||||
|         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. | ||||
|  | @ -329,7 +357,7 @@ class TransformersModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|     ) -> ChatMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  | @ -365,21 +393,21 @@ class TransformersModel(Model): | |||
|         if stop_sequences is not None: | ||||
|             output = remove_stop_sequences(output, stop_sequences) | ||||
|         if tools_to_call_from is None: | ||||
|             return ChatCompletionOutputMessage(role="assistant", content=output) | ||||
|             return ChatMessage(role="assistant", content=output) | ||||
|         else: | ||||
|             if "Action:" in output: | ||||
|                 output = output.split("Action:", 1)[1].strip() | ||||
|             parsed_output = json.loads(output) | ||||
|             tool_name = parsed_output.get("tool_name") | ||||
|             tool_arguments = parsed_output.get("tool_arguments") | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="".join(random.choices("0123456789", k=5)), | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name=tool_name, arguments=tool_arguments | ||||
|                         ), | ||||
|                     ) | ||||
|  | @ -414,7 +442,7 @@ class LiteLLMModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|     ) -> ChatMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  | @ -485,7 +513,7 @@ class OpenAIServerModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|     ) -> ChatMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  |  | |||
|  | @ -221,6 +221,16 @@ class Tool: | |||
|     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): | ||||
|         if not self.is_initialized: | ||||
|             self.setup() | ||||
| 
 | ||||
|         # Handle the arguments might be passed as a single dictionary | ||||
|         if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict): | ||||
|             potential_kwargs = args[0] | ||||
| 
 | ||||
|             # If the dictionary keys match our input parameters, convert it to kwargs | ||||
|             if all(key in self.inputs for key in potential_kwargs): | ||||
|                 args = () | ||||
|                 kwargs = potential_kwargs | ||||
| 
 | ||||
|         if sanitize_inputs_outputs: | ||||
|             args, kwargs = handle_agent_input_types(*args, **kwargs) | ||||
|         outputs = self.forward(*args, **kwargs) | ||||
|  |  | |||
|  | @ -30,10 +30,10 @@ from smolagents.agents import ( | |||
| from smolagents.default_tools import PythonInterpreterTool | ||||
| from smolagents.tools import tool | ||||
| from smolagents.types import AgentImage, AgentText | ||||
| from huggingface_hub import ( | ||||
|     ChatCompletionOutputMessage, | ||||
|     ChatCompletionOutputToolCall, | ||||
|     ChatCompletionOutputFunctionDefinition, | ||||
| from smolagents.models import ( | ||||
|     ChatMessage, | ||||
|     ChatMessageToolCall, | ||||
|     ChatMessageToolCallDefinition, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -47,28 +47,28 @@ class FakeToolCallModel: | |||
|         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None | ||||
|     ): | ||||
|         if len(messages) < 3: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="call_0", | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name="python_interpreter", arguments={"code": "2*3.6452"} | ||||
|                         ), | ||||
|                     ) | ||||
|                 ], | ||||
|             ) | ||||
|         else: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="call_1", | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name="final_answer", arguments={"answer": "7.2904"} | ||||
|                         ), | ||||
|                     ) | ||||
|  | @ -81,14 +81,14 @@ class FakeToolCallModelImage: | |||
|         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None | ||||
|     ): | ||||
|         if len(messages) < 3: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="call_0", | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name="fake_image_generation_tool", | ||||
|                             arguments={"prompt": "An image of a cat"}, | ||||
|                         ), | ||||
|  | @ -96,14 +96,14 @@ class FakeToolCallModelImage: | |||
|                 ], | ||||
|             ) | ||||
|         else: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="call_1", | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name="final_answer", arguments="image.png" | ||||
|                         ), | ||||
|                     ) | ||||
|  | @ -114,7 +114,7 @@ class FakeToolCallModelImage: | |||
| def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: | ||||
|     prompt = str(messages) | ||||
|     if "special_marker" not in prompt: | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I should multiply 2 by 3.6452. special_marker | ||||
|  | @ -125,7 +125,7 @@ result = 2**3.6452 | |||
| """, | ||||
|         ) | ||||
|     else:  # We're at step 2 | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I can now answer the initial question | ||||
|  | @ -140,7 +140,7 @@ final_answer(7.2904) | |||
| def fake_code_model_error(messages, stop_sequences=None) -> str: | ||||
|     prompt = str(messages) | ||||
|     if "special_marker" not in prompt: | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I should multiply 2 by 3.6452. special_marker | ||||
|  | @ -154,7 +154,7 @@ print("Ok, calculation done!") | |||
| """, | ||||
|         ) | ||||
|     else:  # We're at step 2 | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I can now answer the initial question | ||||
|  | @ -169,7 +169,7 @@ final_answer("got an error") | |||
| def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: | ||||
|     prompt = str(messages) | ||||
|     if "special_marker" not in prompt: | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I should multiply 2 by 3.6452. special_marker | ||||
|  | @ -183,7 +183,7 @@ print("Ok, calculation done!") | |||
| """, | ||||
|         ) | ||||
|     else:  # We're at step 2 | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I can now answer the initial question | ||||
|  | @ -196,7 +196,7 @@ final_answer("got an error") | |||
| 
 | ||||
| 
 | ||||
| def fake_code_model_import(messages, stop_sequences=None) -> str: | ||||
|     return ChatCompletionOutputMessage( | ||||
|     return ChatMessage( | ||||
|         role="assistant", | ||||
|         content=""" | ||||
| Thought: I can answer the question | ||||
|  | @ -212,7 +212,7 @@ final_answer("got an error") | |||
| def fake_code_functiondef(messages, stop_sequences=None) -> str: | ||||
|     prompt = str(messages) | ||||
|     if "special_marker" not in prompt: | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: Let's define the function. special_marker | ||||
|  | @ -226,7 +226,7 @@ def moving_average(x, w): | |||
| """, | ||||
|         ) | ||||
|     else:  # We're at step 2 | ||||
|         return ChatCompletionOutputMessage( | ||||
|         return ChatMessage( | ||||
|             role="assistant", | ||||
|             content=""" | ||||
| Thought: I can now answer the initial question | ||||
|  | @ -241,7 +241,7 @@ final_answer(res) | |||
| 
 | ||||
| 
 | ||||
| def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str: | ||||
|     return ChatCompletionOutputMessage( | ||||
|     return ChatMessage( | ||||
|         role="assistant", | ||||
|         content=""" | ||||
| Thought: I should multiply 2 by 3.6452. special_marker | ||||
|  | @ -255,7 +255,7 @@ final_answer(result) | |||
| 
 | ||||
| 
 | ||||
| def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str: | ||||
|     return ChatCompletionOutputMessage( | ||||
|     return ChatMessage( | ||||
|         role="assistant", | ||||
|         content=""" | ||||
| Thought: I should multiply 2 by 3.6452. special_marker | ||||
|  | @ -454,14 +454,14 @@ class AgentTests(unittest.TestCase): | |||
|             ): | ||||
|                 if tools_to_call_from is not None: | ||||
|                     if len(messages) < 3: | ||||
|                         return ChatCompletionOutputMessage( | ||||
|                         return ChatMessage( | ||||
|                             role="assistant", | ||||
|                             content="", | ||||
|                             tool_calls=[ | ||||
|                                 ChatCompletionOutputToolCall( | ||||
|                                 ChatMessageToolCall( | ||||
|                                     id="call_0", | ||||
|                                     type="function", | ||||
|                                     function=ChatCompletionOutputFunctionDefinition( | ||||
|                                     function=ChatMessageToolCallDefinition( | ||||
|                                         name="search_agent", | ||||
|                                         arguments="Who is the current US president?", | ||||
|                                     ), | ||||
|  | @ -470,14 +470,14 @@ class AgentTests(unittest.TestCase): | |||
|                         ) | ||||
|                     else: | ||||
|                         assert "Report on the current US president" in str(messages) | ||||
|                         return ChatCompletionOutputMessage( | ||||
|                         return ChatMessage( | ||||
|                             role="assistant", | ||||
|                             content="", | ||||
|                             tool_calls=[ | ||||
|                                 ChatCompletionOutputToolCall( | ||||
|                                 ChatMessageToolCall( | ||||
|                                     id="call_0", | ||||
|                                     type="function", | ||||
|                                     function=ChatCompletionOutputFunctionDefinition( | ||||
|                                     function=ChatMessageToolCallDefinition( | ||||
|                                         name="final_answer", arguments="Final report." | ||||
|                                     ), | ||||
|                                 ) | ||||
|  | @ -485,7 +485,7 @@ class AgentTests(unittest.TestCase): | |||
|                         ) | ||||
|                 else: | ||||
|                     if len(messages) < 3: | ||||
|                         return ChatCompletionOutputMessage( | ||||
|                         return ChatMessage( | ||||
|                             role="assistant", | ||||
|                             content=""" | ||||
| Thought: Let's call our search agent. | ||||
|  | @ -497,7 +497,7 @@ result = search_agent("Who is the current US president?") | |||
|                         ) | ||||
|                     else: | ||||
|                         assert "Report on the current US president" in str(messages) | ||||
|                         return ChatCompletionOutputMessage( | ||||
|                         return ChatMessage( | ||||
|                             role="assistant", | ||||
|                             content=""" | ||||
| Thought: Let's return the report. | ||||
|  | @ -518,14 +518,14 @@ final_answer("Final report.") | |||
|                 stop_sequences=None, | ||||
|                 grammar=None, | ||||
|             ): | ||||
|                 return ChatCompletionOutputMessage( | ||||
|                 return ChatMessage( | ||||
|                     role="assistant", | ||||
|                     content="", | ||||
|                     tool_calls=[ | ||||
|                         ChatCompletionOutputToolCall( | ||||
|                         ChatMessageToolCall( | ||||
|                             id="call_0", | ||||
|                             type="function", | ||||
|                             function=ChatCompletionOutputFunctionDefinition( | ||||
|                             function=ChatMessageToolCallDefinition( | ||||
|                                 name="final_answer", | ||||
|                                 arguments="Report on the current US president", | ||||
|                             ), | ||||
|  | @ -568,7 +568,7 @@ final_answer("Final report.") | |||
| 
 | ||||
|     def test_code_nontrivial_final_answer_works(self): | ||||
|         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="""Code: | ||||
| ```py | ||||
|  |  | |||
|  | @ -0,0 +1,83 @@ | |||
| # coding=utf-8 | ||||
| # Copyright 2024 HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import unittest | ||||
| import pytest | ||||
| 
 | ||||
| from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| 
 | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| class DefaultToolTests(unittest.TestCase): | ||||
|     def test_visit_webpage(self): | ||||
|         arguments = { | ||||
|             "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security" | ||||
|         } | ||||
|         result = VisitWebpageTool()(arguments) | ||||
|         assert isinstance(result, str) | ||||
|         assert ( | ||||
|             "* [About Wikipedia](/wiki/Wikipedia:About)" in result | ||||
|         )  # Proper wikipedia pages have an About | ||||
| 
 | ||||
| 
 | ||||
| class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||
|     def setUp(self): | ||||
|         self.tool = PythonInterpreterTool(authorized_imports=["numpy"]) | ||||
|         self.tool.setup() | ||||
| 
 | ||||
|     def test_exact_match_arg(self): | ||||
|         result = self.tool("(2 / 2) * 4") | ||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||
| 
 | ||||
|     def test_exact_match_kwarg(self): | ||||
|         result = self.tool(code="(2 / 2) * 4") | ||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||
| 
 | ||||
|     def test_agent_type_output(self): | ||||
|         inputs = ["2 * 2"] | ||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||
|         self.assertTrue(isinstance(output, output_type)) | ||||
| 
 | ||||
|     def test_agent_types_inputs(self): | ||||
|         inputs = ["2 * 2"] | ||||
|         _inputs = [] | ||||
| 
 | ||||
|         for _input, expected_input in zip(inputs, self.tool.inputs.values()): | ||||
|             input_type = expected_input["type"] | ||||
|             if isinstance(input_type, list): | ||||
|                 _inputs.append( | ||||
|                     [ | ||||
|                         AGENT_TYPE_MAPPING[_input_type](_input) | ||||
|                         for _input_type in input_type | ||||
|                     ] | ||||
|                 ) | ||||
|             else: | ||||
|                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) | ||||
| 
 | ||||
|         # Should not raise an error | ||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||
|         self.assertTrue(isinstance(output, output_type)) | ||||
| 
 | ||||
|     def test_imports_work(self): | ||||
|         result = self.tool("import numpy as np") | ||||
|         assert "import from numpy is not allowed" not in result.lower() | ||||
| 
 | ||||
|     def test_unauthorized_imports_fail(self): | ||||
|         with pytest.raises(Exception) as e: | ||||
|             self.tool("import sympy as sp") | ||||
|         assert "sympy" in str(e).lower() | ||||
|  | @ -23,9 +23,9 @@ from smolagents import ( | |||
|     stream_to_gradio, | ||||
| ) | ||||
| from huggingface_hub import ( | ||||
|     ChatCompletionOutputMessage, | ||||
|     ChatCompletionOutputToolCall, | ||||
|     ChatCompletionOutputFunctionDefinition, | ||||
|     ChatMessage, | ||||
|     ChatMessageToolCall, | ||||
|     ChatMessageToolCallDefinition, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -36,21 +36,21 @@ class FakeLLMModel: | |||
| 
 | ||||
|     def __call__(self, prompt, tools_to_call_from=None, **kwargs): | ||||
|         if tools_to_call_from is not None: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content="", | ||||
|                 tool_calls=[ | ||||
|                     ChatCompletionOutputToolCall( | ||||
|                     ChatMessageToolCall( | ||||
|                         id="fake_id", | ||||
|                         type="function", | ||||
|                         function=ChatCompletionOutputFunctionDefinition( | ||||
|                         function=ChatMessageToolCallDefinition( | ||||
|                             name="final_answer", arguments={"answer": "image"} | ||||
|                         ), | ||||
|                     ) | ||||
|                 ], | ||||
|             ) | ||||
|         else: | ||||
|             return ChatCompletionOutputMessage( | ||||
|             return ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content=""" | ||||
| Code: | ||||
|  | @ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase): | |||
|                 self.last_output_token_count = 20 | ||||
| 
 | ||||
|             def __call__(self, prompt, **kwargs): | ||||
|                 return ChatCompletionOutputMessage( | ||||
|                     role="assistant", content="Malformed answer" | ||||
|                 ) | ||||
|                 return ChatMessage(role="assistant", content="Malformed answer") | ||||
| 
 | ||||
|         agent = CodeAgent( | ||||
|             tools=[], | ||||
|  |  | |||
|  | @ -18,15 +18,12 @@ import unittest | |||
| import numpy as np | ||||
| import pytest | ||||
| 
 | ||||
| from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool | ||||
| from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||
| from smolagents.local_python_executor import ( | ||||
|     InterpreterError, | ||||
|     evaluate_python_code, | ||||
|     fix_final_answer_code, | ||||
| ) | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| 
 | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| # Fake function we will use as tool | ||||
|  | @ -34,47 +31,6 @@ def add_two(x): | |||
|     return x + 2 | ||||
| 
 | ||||
| 
 | ||||
| class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||
|     def setUp(self): | ||||
|         self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"]) | ||||
|         self.tool.setup() | ||||
| 
 | ||||
|     def test_exact_match_arg(self): | ||||
|         result = self.tool("(2 / 2) * 4") | ||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||
| 
 | ||||
|     def test_exact_match_kwarg(self): | ||||
|         result = self.tool(code="(2 / 2) * 4") | ||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||
| 
 | ||||
|     def test_agent_type_output(self): | ||||
|         inputs = ["2 * 2"] | ||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||
|         self.assertTrue(isinstance(output, output_type)) | ||||
| 
 | ||||
|     def test_agent_types_inputs(self): | ||||
|         inputs = ["2 * 2"] | ||||
|         _inputs = [] | ||||
| 
 | ||||
|         for _input, expected_input in zip(inputs, self.tool.inputs.values()): | ||||
|             input_type = expected_input["type"] | ||||
|             if isinstance(input_type, list): | ||||
|                 _inputs.append( | ||||
|                     [ | ||||
|                         AGENT_TYPE_MAPPING[_input_type](_input) | ||||
|                         for _input_type in input_type | ||||
|                     ] | ||||
|                 ) | ||||
|             else: | ||||
|                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) | ||||
| 
 | ||||
|         # Should not raise an error | ||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||
|         self.assertTrue(isinstance(output, output_type)) | ||||
| 
 | ||||
| 
 | ||||
| class PythonInterpreterTester(unittest.TestCase): | ||||
|     def test_evaluate_assign(self): | ||||
|         code = "x = 3" | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue