Make default tools more robust (#186)
This commit is contained in:
		
							parent
							
								
									12a2e6f4b4
								
							
						
					
					
						commit
						5f32373551
					
				|  | @ -36,6 +36,9 @@ jobs: | ||||||
|       - name: Agent tests |       - name: Agent tests | ||||||
|         run: | |         run: | | ||||||
|           uv run pytest -sv ./tests/test_agents.py |           uv run pytest -sv ./tests/test_agents.py | ||||||
|  |       - name: Default tools tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_default_tools.py | ||||||
|       - name: Final answer tests |       - name: Final answer tests | ||||||
|         run: | |         run: | | ||||||
|           uv run pytest -sv ./tests/test_final_answer.py |           uv run pytest -sv ./tests/test_final_answer.py | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 5, |    "execution_count": 4, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [ | ||||||
|     { |     { | ||||||
|  | @ -29,8 +29,7 @@ | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |       "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||||||
|       "  from .autonotebook import tqdm as notebook_tqdm\n", |       "  from .autonotebook import tqdm as notebook_tqdm\n" | ||||||
|       "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n" |  | ||||||
|      ] |      ] | ||||||
|     }, |     }, | ||||||
|     { |     { | ||||||
|  | @ -173,7 +172,7 @@ | ||||||
|        "[132 rows x 4 columns]" |        "[132 rows x 4 columns]" | ||||||
|       ] |       ] | ||||||
|      }, |      }, | ||||||
|      "execution_count": 5, |      "execution_count": 4, | ||||||
|      "metadata": {}, |      "metadata": {}, | ||||||
|      "output_type": "execute_result" |      "output_type": "execute_result" | ||||||
|     } |     } | ||||||
|  | @ -196,19 +195,9 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 6, |    "execution_count": 7, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [], | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n", |  | ||||||
|       "* 'fields' has been removed\n", |  | ||||||
|       "  warnings.warn(message, UserWarning)\n" |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |    "source": [ | ||||||
|     "import time\n", |     "import time\n", | ||||||
|     "import json\n", |     "import json\n", | ||||||
|  | @ -408,100 +397,9 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 7, |    "execution_count": null, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [], | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n" |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |    "source": [ | ||||||
|     "open_model_ids = [\n", |     "open_model_ids = [\n", | ||||||
|     "    \"meta-llama/Llama-3.3-70B-Instruct\",\n", |     "    \"meta-llama/Llama-3.3-70B-Instruct\",\n", | ||||||
|  | @ -554,42 +452,9 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 8, |    "execution_count": null, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [], | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'gpt-4o'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "name": "stderr", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n", |  | ||||||
|       "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n" |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |    "source": [ | ||||||
|     "from smolagents import LiteLLMModel\n", |     "from smolagents import LiteLLMModel\n", | ||||||
|     "\n", |     "\n", | ||||||
|  | @ -614,7 +479,7 @@ | ||||||
|     "    agent = CodeAgent(\n", |     "    agent = CodeAgent(\n", | ||||||
|     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", |     "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", | ||||||
|     "        model=LiteLLMModel(model_id),\n", |     "        model=LiteLLMModel(model_id),\n", | ||||||
|     "        additional_authorized_imports=[\"numpy\"],\n", |     "        additional_authorized_imports=[\"numpy\", \"sympy\"],\n", | ||||||
|     "        max_steps=10,\n", |     "        max_steps=10,\n", | ||||||
|     "    )\n", |     "    )\n", | ||||||
|     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", |     "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", | ||||||
|  | @ -631,34 +496,39 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 9, |    "execution_count": 8, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "# import glob\n", |     "# import glob\n", | ||||||
|     "# import json\n", |     "# import json\n", | ||||||
|  |     "\n", | ||||||
|     "# jsonl_files = glob.glob(f\"output/*.jsonl\")\n", |     "# jsonl_files = glob.glob(f\"output/*.jsonl\")\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "# for file_path in jsonl_files:\n", |     "# for file_path in jsonl_files:\n", | ||||||
|     "#     print(file_path)\n", |     "#     if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n", | ||||||
|     "#     # Read all lines and filter out SimpleQA sources\n", |     "#         print(file_path)\n", | ||||||
|     "#     filtered_lines = []\n", |     "#         # Read all lines and filter out SimpleQA sources\n", | ||||||
|     "#     removed = 0\n", |     "#         filtered_lines = []\n", | ||||||
|     "#     with open(file_path, 'r', encoding='utf-8') as f:\n", |     "#         removed = 0\n", | ||||||
|     "#         for line in f:\n", |     "#         with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", | ||||||
|     "#             try:\n", |     "#             for line in f:\n", | ||||||
|     "#                 data = json.loads(line.strip())\n", |     "#                 try:\n", | ||||||
|     "#                 if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", |     "#                     data = json.loads(line.strip())\n", | ||||||
|     "#                     removed +=1\n", |     "#                     data[\"answer\"] = data[\"answer\"][\"content\"]\n", | ||||||
|     "#                 else:\n", |     "#                     # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", | ||||||
|     "#                     filtered_lines.append(line)\n", |     "#                     #     removed +=1\n", | ||||||
|     "#             except json.JSONDecodeError:\n", |     "#                     # else:\n", | ||||||
|     "#                 print(\"Invalid line:\", line)\n", |     "#                     filtered_lines.append(json.dumps(data) + \"\\n\")\n", | ||||||
|     "#                 continue  # Skip invalid JSON lines\n", |     "#                 except json.JSONDecodeError:\n", | ||||||
|     "#     print(f\"Removed {removed} lines.\")\n", |     "#                     print(\"Invalid line:\", line)\n", | ||||||
|     "#     # Write filtered content back to the same file\n", |     "#                     continue  # Skip invalid JSON lines\n", | ||||||
|     "#     with open(file_path, 'w', encoding='utf-8') as f:\n", |     "#         print(f\"Removed {removed} lines.\")\n", | ||||||
|     "#         f.writelines(filtered_lines)" |     "#         # Write filtered content back to the same file\n", | ||||||
|  |     "#         with open(\n", | ||||||
|  |     "#             str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n", | ||||||
|  |     "#         ) as f:\n", | ||||||
|  |     "#             f.writelines(filtered_lines)" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|  | @ -670,14 +540,14 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": null, |    "execution_count": 9, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [ | ||||||
|     { |     { | ||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", |       "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", | ||||||
|       "  warnings.warn(\n" |       "  warnings.warn(\n" | ||||||
|      ] |      ] | ||||||
|     } |     } | ||||||
|  | @ -731,7 +601,7 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 14, |    "execution_count": 10, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
|    "source": [ |    "source": [ | ||||||
|  | @ -752,7 +622,7 @@ | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 15, |    "execution_count": 11, | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [ | ||||||
|     { |     { | ||||||
|  | @ -794,28 +664,28 @@ | ||||||
|        "      <th>1</th>\n", |        "      <th>1</th>\n", | ||||||
|        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", |        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>74.0</td>\n", |        "      <td>76.0</td>\n", | ||||||
|        "      <td>30.0</td>\n", |        "      <td>30.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>2</th>\n", |        "      <th>2</th>\n", | ||||||
|        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", |        "      <td>Qwen/Qwen2.5-72B-Instruct</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>70.0</td>\n", |        "      <td>88.0</td>\n", | ||||||
|        "      <td>10.0</td>\n", |        "      <td>10.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>3</th>\n", |        "      <th>3</th>\n", | ||||||
|        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", |        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>18.8</td>\n", |        "      <td>25.0</td>\n", | ||||||
|        "      <td>3.1</td>\n", |        "      <td>3.1</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>4</th>\n", |        "      <th>4</th>\n", | ||||||
|        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", |        "      <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>76.0</td>\n", |        "      <td>86.0</td>\n", | ||||||
|        "      <td>60.0</td>\n", |        "      <td>60.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|  | @ -829,63 +699,63 @@ | ||||||
|        "      <th>6</th>\n", |        "      <th>6</th>\n", | ||||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", |        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>40.6</td>\n", |        "      <td>NaN</td>\n", | ||||||
|        "      <td>3.1</td>\n", |        "      <td>3.1</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>7</th>\n", |        "      <th>7</th>\n", | ||||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", |        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>67.0</td>\n", |        "      <td>NaN</td>\n", | ||||||
|        "      <td>50.0</td>\n", |        "      <td>50.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>8</th>\n", |        "      <th>8</th>\n", | ||||||
|        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", |        "      <td>anthropic/claude-3-5-sonnet-latest</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>90.0</td>\n", |        "      <td>NaN</td>\n", | ||||||
|        "      <td>34.0</td>\n", |        "      <td>34.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>9</th>\n", |        "      <th>9</th>\n", | ||||||
|        "      <td>gpt-4o</td>\n", |        "      <td>gpt-4o</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>28.1</td>\n", |        "      <td>25.6</td>\n", | ||||||
|        "      <td>3.1</td>\n", |        "      <td>3.1</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>10</th>\n", |        "      <th>10</th>\n", | ||||||
|        "      <td>gpt-4o</td>\n", |        "      <td>gpt-4o</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>70.0</td>\n", |        "      <td>58.0</td>\n", | ||||||
|        "      <td>40.0</td>\n", |        "      <td>40.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>11</th>\n", |        "      <th>11</th>\n", | ||||||
|        "      <td>gpt-4o</td>\n", |        "      <td>gpt-4o</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>88.0</td>\n", |        "      <td>86.0</td>\n", | ||||||
|        "      <td>6.0</td>\n", |        "      <td>6.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>12</th>\n", |        "      <th>12</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>0.0</td>\n", |        "      <td>3.1</td>\n", | ||||||
|        "      <td>0.0</td>\n", |        "      <td>0.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>13</th>\n", |        "      <th>13</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>42.0</td>\n", |        "      <td>14.0</td>\n", | ||||||
|        "      <td>18.0</td>\n", |        "      <td>18.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>14</th>\n", |        "      <th>14</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.1-8B-Instruct</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>54.0</td>\n", |        "      <td>2.0</td>\n", | ||||||
|        "      <td>6.0</td>\n", |        "      <td>6.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|  | @ -899,49 +769,49 @@ | ||||||
|        "      <th>16</th>\n", |        "      <th>16</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>32.0</td>\n", |        "      <td>40.0</td>\n", | ||||||
|        "      <td>12.0</td>\n", |        "      <td>12.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>17</th>\n", |        "      <th>17</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.2-3B-Instruct</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>4.0</td>\n", |        "      <td>20.0</td>\n", | ||||||
|        "      <td>0.0</td>\n", |        "      <td>0.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>18</th>\n", |        "      <th>18</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>34.4</td>\n", |        "      <td>31.2</td>\n", | ||||||
|        "      <td>3.1</td>\n", |        "      <td>3.1</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>19</th>\n", |        "      <th>19</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>82.0</td>\n", |        "      <td>72.0</td>\n", | ||||||
|        "      <td>40.0</td>\n", |        "      <td>40.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>20</th>\n", |        "      <th>20</th>\n", | ||||||
|        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", |        "      <td>meta-llama/Llama-3.3-70B-Instruct</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>84.0</td>\n", |        "      <td>78.0</td>\n", | ||||||
|        "      <td>12.0</td>\n", |        "      <td>12.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>21</th>\n", |        "      <th>21</th>\n", | ||||||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", |        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||||
|        "      <td>GAIA</td>\n", |        "      <td>GAIA</td>\n", | ||||||
|        "      <td>3.1</td>\n", |  | ||||||
|        "      <td>0.0</td>\n", |        "      <td>0.0</td>\n", | ||||||
|  |        "      <td>3.1</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|        "      <th>22</th>\n", |        "      <th>22</th>\n", | ||||||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", |        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||||
|        "      <td>MATH</td>\n", |        "      <td>MATH</td>\n", | ||||||
|        "      <td>20.0</td>\n", |        "      <td>30.0</td>\n", | ||||||
|        "      <td>22.0</td>\n", |        "      <td>22.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "    <tr>\n", |        "    <tr>\n", | ||||||
|  | @ -949,7 +819,7 @@ | ||||||
|        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", |        "      <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n", | ||||||
|        "      <td>SimpleQA</td>\n", |        "      <td>SimpleQA</td>\n", | ||||||
|        "      <td>30.0</td>\n", |        "      <td>30.0</td>\n", | ||||||
|        "      <td>0.0</td>\n", |        "      <td>6.0</td>\n", | ||||||
|        "    </tr>\n", |        "    </tr>\n", | ||||||
|        "  </tbody>\n", |        "  </tbody>\n", | ||||||
|        "</table>\n", |        "</table>\n", | ||||||
|  | @ -958,29 +828,29 @@ | ||||||
|       "text/plain": [ |       "text/plain": [ | ||||||
|        "action_type                              model_id    source  code  vanilla\n", |        "action_type                              model_id    source  code  vanilla\n", | ||||||
|        "0                       Qwen/Qwen2.5-72B-Instruct      GAIA  28.1      6.2\n", |        "0                       Qwen/Qwen2.5-72B-Instruct      GAIA  28.1      6.2\n", | ||||||
|        "1                       Qwen/Qwen2.5-72B-Instruct      MATH  74.0     30.0\n", |        "1                       Qwen/Qwen2.5-72B-Instruct      MATH  76.0     30.0\n", | ||||||
|        "2                       Qwen/Qwen2.5-72B-Instruct  SimpleQA  70.0     10.0\n", |        "2                       Qwen/Qwen2.5-72B-Instruct  SimpleQA  88.0     10.0\n", | ||||||
|        "3                 Qwen/Qwen2.5-Coder-32B-Instruct      GAIA  18.8      3.1\n", |        "3                 Qwen/Qwen2.5-Coder-32B-Instruct      GAIA  25.0      3.1\n", | ||||||
|        "4                 Qwen/Qwen2.5-Coder-32B-Instruct      MATH  76.0     60.0\n", |        "4                 Qwen/Qwen2.5-Coder-32B-Instruct      MATH  86.0     60.0\n", | ||||||
|        "5                 Qwen/Qwen2.5-Coder-32B-Instruct  SimpleQA  86.0      8.0\n", |        "5                 Qwen/Qwen2.5-Coder-32B-Instruct  SimpleQA  86.0      8.0\n", | ||||||
|        "6              anthropic/claude-3-5-sonnet-latest      GAIA  40.6      3.1\n", |        "6              anthropic/claude-3-5-sonnet-latest      GAIA   NaN      3.1\n", | ||||||
|        "7              anthropic/claude-3-5-sonnet-latest      MATH  67.0     50.0\n", |        "7              anthropic/claude-3-5-sonnet-latest      MATH   NaN     50.0\n", | ||||||
|        "8              anthropic/claude-3-5-sonnet-latest  SimpleQA  90.0     34.0\n", |        "8              anthropic/claude-3-5-sonnet-latest  SimpleQA   NaN     34.0\n", | ||||||
|        "9                                          gpt-4o      GAIA  28.1      3.1\n", |        "9                                          gpt-4o      GAIA  25.6      3.1\n", | ||||||
|        "10                                         gpt-4o      MATH  70.0     40.0\n", |        "10                                         gpt-4o      MATH  58.0     40.0\n", | ||||||
|        "11                                         gpt-4o  SimpleQA  88.0      6.0\n", |        "11                                         gpt-4o  SimpleQA  86.0      6.0\n", | ||||||
|        "12               meta-llama/Llama-3.1-8B-Instruct      GAIA   0.0      0.0\n", |        "12               meta-llama/Llama-3.1-8B-Instruct      GAIA   3.1      0.0\n", | ||||||
|        "13               meta-llama/Llama-3.1-8B-Instruct      MATH  42.0     18.0\n", |        "13               meta-llama/Llama-3.1-8B-Instruct      MATH  14.0     18.0\n", | ||||||
|        "14               meta-llama/Llama-3.1-8B-Instruct  SimpleQA  54.0      6.0\n", |        "14               meta-llama/Llama-3.1-8B-Instruct  SimpleQA   2.0      6.0\n", | ||||||
|        "15               meta-llama/Llama-3.2-3B-Instruct      GAIA   3.1      0.0\n", |        "15               meta-llama/Llama-3.2-3B-Instruct      GAIA   3.1      0.0\n", | ||||||
|        "16               meta-llama/Llama-3.2-3B-Instruct      MATH  32.0     12.0\n", |        "16               meta-llama/Llama-3.2-3B-Instruct      MATH  40.0     12.0\n", | ||||||
|        "17               meta-llama/Llama-3.2-3B-Instruct  SimpleQA   4.0      0.0\n", |        "17               meta-llama/Llama-3.2-3B-Instruct  SimpleQA  20.0      0.0\n", | ||||||
|        "18              meta-llama/Llama-3.3-70B-Instruct      GAIA  34.4      3.1\n", |        "18              meta-llama/Llama-3.3-70B-Instruct      GAIA  31.2      3.1\n", | ||||||
|        "19              meta-llama/Llama-3.3-70B-Instruct      MATH  82.0     40.0\n", |        "19              meta-llama/Llama-3.3-70B-Instruct      MATH  72.0     40.0\n", | ||||||
|        "20              meta-llama/Llama-3.3-70B-Instruct  SimpleQA  84.0     12.0\n", |        "20              meta-llama/Llama-3.3-70B-Instruct  SimpleQA  78.0     12.0\n", | ||||||
|        "21           mistralai/Mistral-Nemo-Instruct-2407      GAIA   3.1      0.0\n", |        "21           mistralai/Mistral-Nemo-Instruct-2407      GAIA   0.0      3.1\n", | ||||||
|        "22           mistralai/Mistral-Nemo-Instruct-2407      MATH  20.0     22.0\n", |        "22           mistralai/Mistral-Nemo-Instruct-2407      MATH  30.0     22.0\n", | ||||||
|        "23           mistralai/Mistral-Nemo-Instruct-2407  SimpleQA  30.0      0.0" |        "23           mistralai/Mistral-Nemo-Instruct-2407  SimpleQA  30.0      6.0" | ||||||
|       ] |       ] | ||||||
|      }, |      }, | ||||||
|      "metadata": {}, |      "metadata": {}, | ||||||
|  | @ -1005,6 +875,15 @@ | ||||||
|      }, |      }, | ||||||
|      "metadata": {}, |      "metadata": {}, | ||||||
|      "output_type": "display_data" |      "output_type": "display_data" | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |      "ename": "", | ||||||
|  |      "evalue": "", | ||||||
|  |      "output_type": "error", | ||||||
|  |      "traceback": [ | ||||||
|  |       "\u001b[1;31mnotebook controller is DISPOSED. \n", | ||||||
|  |       "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details." | ||||||
|  |      ] | ||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|  |  | ||||||
|  | @ -15,7 +15,6 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import time | import time | ||||||
| import json |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||||
| 
 | 
 | ||||||
|  | @ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|                 tools_to_call_from=list(self.tools.values()), |                 tools_to_call_from=list(self.tools.values()), | ||||||
|                 stop_sequences=["Observation:"], |                 stop_sequences=["Observation:"], | ||||||
|             ) |             ) | ||||||
| 
 |             tool_call = model_message.tool_calls[0] | ||||||
|             # Extract tool call from model output |             tool_name, tool_call_id = tool_call.function.name, tool_call.id | ||||||
|             if ( |             tool_arguments = tool_call.function.arguments | ||||||
|                 type(model_message.tool_calls) is list |  | ||||||
|                 and len(model_message.tool_calls) > 0 |  | ||||||
|             ): |  | ||||||
|                 tool_calls = model_message.tool_calls[0] |  | ||||||
|                 tool_arguments = tool_calls.function.arguments |  | ||||||
|                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id |  | ||||||
|             else: |  | ||||||
|                 start, end = ( |  | ||||||
|                     model_message.content.find("{"), |  | ||||||
|                     model_message.content.rfind("}") + 1, |  | ||||||
|                 ) |  | ||||||
|                 tool_calls = json.loads(model_message.content[start:end]) |  | ||||||
|                 tool_arguments = tool_calls["tool_arguments"] |  | ||||||
|                 tool_name, tool_call_id = ( |  | ||||||
|                     tool_calls["tool_name"], |  | ||||||
|                     f"call_{len(self.logs)}", |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError( |             raise AgentGenerationError( | ||||||
|  | @ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|                 updated_information = f"Stored '{observation_name}' in memory." |                 updated_information = f"Stored '{observation_name}' in memory." | ||||||
|             else: |             else: | ||||||
|                 updated_information = str(observation).strip() |                 updated_information = str(observation).strip() | ||||||
|             self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO) |             self.logger.log( | ||||||
|  |                 f"Observations: {updated_information.replace('[', '|')}",  # escape potential rich-tag-like components | ||||||
|  |                 level=LogLevel.INFO, | ||||||
|  |             ) | ||||||
|             log_entry.observations = updated_information |             log_entry.observations = updated_information | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -31,6 +31,7 @@ from .local_python_executor import ( | ||||||
| ) | ) | ||||||
| from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | ||||||
| from .types import AgentAudio | from .types import AgentAudio | ||||||
|  | from .utils import truncate_content | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|     from transformers.models.whisper import ( |     from transformers.models.whisper import ( | ||||||
|  | @ -112,18 +113,15 @@ class PythonInterpreterTool(Tool): | ||||||
| 
 | 
 | ||||||
|     def forward(self, code: str) -> str: |     def forward(self, code: str) -> str: | ||||||
|         state = {} |         state = {} | ||||||
|         try: |         output = str( | ||||||
|             output = str( |             self.python_evaluator( | ||||||
|                 self.python_evaluator( |                 code, | ||||||
|                     code, |                 state=state, | ||||||
|                     state=state, |                 static_tools=self.base_python_tools, | ||||||
|                     static_tools=self.base_python_tools, |                 authorized_imports=self.authorized_imports, | ||||||
|                     authorized_imports=self.authorized_imports, |             )[0]  # The second element is boolean is_final_answer | ||||||
|                 )[0]  # The second element is boolean is_final_answer |         ) | ||||||
|             ) |         return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" | ||||||
|             return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" |  | ||||||
|         except Exception as e: |  | ||||||
|             return f"Error: {str(e)}" |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class FinalAnswerTool(Tool): | class FinalAnswerTool(Tool): | ||||||
|  | @ -295,7 +293,7 @@ class VisitWebpageTool(Tool): | ||||||
|             # Remove multiple line breaks |             # Remove multiple line breaks | ||||||
|             markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) |             markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) | ||||||
| 
 | 
 | ||||||
|             return markdown_content |             return truncate_content(markdown_content) | ||||||
| 
 | 
 | ||||||
|         except RequestException as e: |         except RequestException as e: | ||||||
|             return f"Error fetching the webpage: {str(e)}" |             return f"Error fetching the webpage: {str(e)}" | ||||||
|  |  | ||||||
|  | @ -14,20 +14,16 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | from dataclasses import dataclass | ||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import random | import random | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from typing import Dict, List, Optional | from typing import Dict, List, Optional, Union, Any | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import ( | from huggingface_hub import InferenceClient | ||||||
|     InferenceClient, |  | ||||||
|     ChatCompletionOutputMessage, |  | ||||||
|     ChatCompletionOutputToolCall, |  | ||||||
|     ChatCompletionOutputFunctionDefinition, |  | ||||||
| ) |  | ||||||
| 
 | 
 | ||||||
| from transformers import ( | from transformers import ( | ||||||
|     AutoModelForCausalLM, |     AutoModelForCausalLM, | ||||||
|  | @ -58,6 +54,27 @@ if _is_package_available("litellm"): | ||||||
|     import litellm |     import litellm | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @dataclass | ||||||
|  | class ChatMessageToolCallDefinition: | ||||||
|  |     arguments: Any | ||||||
|  |     name: str | ||||||
|  |     description: Optional[str] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class ChatMessageToolCall: | ||||||
|  |     function: ChatMessageToolCallDefinition | ||||||
|  |     id: str | ||||||
|  |     type: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class ChatMessage: | ||||||
|  |     role: str | ||||||
|  |     content: Optional[str] = None | ||||||
|  |     tool_calls: Optional[List[ChatMessageToolCall]] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class MessageRole(str, Enum): | class MessageRole(str, Enum): | ||||||
|     USER = "user" |     USER = "user" | ||||||
|     ASSISTANT = "assistant" |     ASSISTANT = "assistant" | ||||||
|  | @ -140,6 +157,17 @@ def get_clean_message_list( | ||||||
|     return final_message_list |     return final_message_list | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: | ||||||
|  |     try: | ||||||
|  |         start, end = ( | ||||||
|  |             possible_dictionary.find("{"), | ||||||
|  |             possible_dictionary.rfind("}") + 1, | ||||||
|  |         ) | ||||||
|  |         return json.loads(possible_dictionary[start:end]) | ||||||
|  |     except Exception: | ||||||
|  |         return possible_dictionary | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class Model: | class Model: | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.last_input_token_count = None |         self.last_input_token_count = None | ||||||
|  | @ -157,7 +185,7 @@ class Model: | ||||||
|         stop_sequences: Optional[List[str]] = None, |         stop_sequences: Optional[List[str]] = None, | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> ChatCompletionOutputMessage: |     ) -> ChatMessage: | ||||||
|         """Process the input messages and return the model's response. |         """Process the input messages and return the model's response. | ||||||
| 
 | 
 | ||||||
|         Parameters: |         Parameters: | ||||||
|  | @ -228,7 +256,7 @@ class HfApiModel(Model): | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|         tools_to_call_from: Optional[List[Tool]] = None, |         tools_to_call_from: Optional[List[Tool]] = None, | ||||||
|     ) -> ChatCompletionOutputMessage: |     ) -> ChatMessage: | ||||||
|         """ |         """ | ||||||
|         Gets an LLM output message for the given list of input messages. |         Gets an LLM output message for the given list of input messages. | ||||||
|         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. |         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. | ||||||
|  | @ -329,7 +357,7 @@ class TransformersModel(Model): | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|         tools_to_call_from: Optional[List[Tool]] = None, |         tools_to_call_from: Optional[List[Tool]] = None, | ||||||
|     ) -> ChatCompletionOutputMessage: |     ) -> ChatMessage: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|  | @ -365,21 +393,21 @@ class TransformersModel(Model): | ||||||
|         if stop_sequences is not None: |         if stop_sequences is not None: | ||||||
|             output = remove_stop_sequences(output, stop_sequences) |             output = remove_stop_sequences(output, stop_sequences) | ||||||
|         if tools_to_call_from is None: |         if tools_to_call_from is None: | ||||||
|             return ChatCompletionOutputMessage(role="assistant", content=output) |             return ChatMessage(role="assistant", content=output) | ||||||
|         else: |         else: | ||||||
|             if "Action:" in output: |             if "Action:" in output: | ||||||
|                 output = output.split("Action:", 1)[1].strip() |                 output = output.split("Action:", 1)[1].strip() | ||||||
|             parsed_output = json.loads(output) |             parsed_output = json.loads(output) | ||||||
|             tool_name = parsed_output.get("tool_name") |             tool_name = parsed_output.get("tool_name") | ||||||
|             tool_arguments = parsed_output.get("tool_arguments") |             tool_arguments = parsed_output.get("tool_arguments") | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="".join(random.choices("0123456789", k=5)), |                         id="".join(random.choices("0123456789", k=5)), | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name=tool_name, arguments=tool_arguments |                             name=tool_name, arguments=tool_arguments | ||||||
|                         ), |                         ), | ||||||
|                     ) |                     ) | ||||||
|  | @ -414,7 +442,7 @@ class LiteLLMModel(Model): | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|         tools_to_call_from: Optional[List[Tool]] = None, |         tools_to_call_from: Optional[List[Tool]] = None, | ||||||
|     ) -> ChatCompletionOutputMessage: |     ) -> ChatMessage: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|  | @ -485,7 +513,7 @@ class OpenAIServerModel(Model): | ||||||
|         grammar: Optional[str] = None, |         grammar: Optional[str] = None, | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|         tools_to_call_from: Optional[List[Tool]] = None, |         tools_to_call_from: Optional[List[Tool]] = None, | ||||||
|     ) -> ChatCompletionOutputMessage: |     ) -> ChatMessage: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | @ -221,6 +221,16 @@ class Tool: | ||||||
|     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): |     def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): | ||||||
|         if not self.is_initialized: |         if not self.is_initialized: | ||||||
|             self.setup() |             self.setup() | ||||||
|  | 
 | ||||||
|  |         # Handle the arguments might be passed as a single dictionary | ||||||
|  |         if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict): | ||||||
|  |             potential_kwargs = args[0] | ||||||
|  | 
 | ||||||
|  |             # If the dictionary keys match our input parameters, convert it to kwargs | ||||||
|  |             if all(key in self.inputs for key in potential_kwargs): | ||||||
|  |                 args = () | ||||||
|  |                 kwargs = potential_kwargs | ||||||
|  | 
 | ||||||
|         if sanitize_inputs_outputs: |         if sanitize_inputs_outputs: | ||||||
|             args, kwargs = handle_agent_input_types(*args, **kwargs) |             args, kwargs = handle_agent_input_types(*args, **kwargs) | ||||||
|         outputs = self.forward(*args, **kwargs) |         outputs = self.forward(*args, **kwargs) | ||||||
|  |  | ||||||
|  | @ -30,10 +30,10 @@ from smolagents.agents import ( | ||||||
| from smolagents.default_tools import PythonInterpreterTool | from smolagents.default_tools import PythonInterpreterTool | ||||||
| from smolagents.tools import tool | from smolagents.tools import tool | ||||||
| from smolagents.types import AgentImage, AgentText | from smolagents.types import AgentImage, AgentText | ||||||
| from huggingface_hub import ( | from smolagents.models import ( | ||||||
|     ChatCompletionOutputMessage, |     ChatMessage, | ||||||
|     ChatCompletionOutputToolCall, |     ChatMessageToolCall, | ||||||
|     ChatCompletionOutputFunctionDefinition, |     ChatMessageToolCallDefinition, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -47,28 +47,28 @@ class FakeToolCallModel: | ||||||
|         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None |         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None | ||||||
|     ): |     ): | ||||||
|         if len(messages) < 3: |         if len(messages) < 3: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="call_0", |                         id="call_0", | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name="python_interpreter", arguments={"code": "2*3.6452"} |                             name="python_interpreter", arguments={"code": "2*3.6452"} | ||||||
|                         ), |                         ), | ||||||
|                     ) |                     ) | ||||||
|                 ], |                 ], | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="call_1", |                         id="call_1", | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name="final_answer", arguments={"answer": "7.2904"} |                             name="final_answer", arguments={"answer": "7.2904"} | ||||||
|                         ), |                         ), | ||||||
|                     ) |                     ) | ||||||
|  | @ -81,14 +81,14 @@ class FakeToolCallModelImage: | ||||||
|         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None |         self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None | ||||||
|     ): |     ): | ||||||
|         if len(messages) < 3: |         if len(messages) < 3: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="call_0", |                         id="call_0", | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name="fake_image_generation_tool", |                             name="fake_image_generation_tool", | ||||||
|                             arguments={"prompt": "An image of a cat"}, |                             arguments={"prompt": "An image of a cat"}, | ||||||
|                         ), |                         ), | ||||||
|  | @ -96,14 +96,14 @@ class FakeToolCallModelImage: | ||||||
|                 ], |                 ], | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="call_1", |                         id="call_1", | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name="final_answer", arguments="image.png" |                             name="final_answer", arguments="image.png" | ||||||
|                         ), |                         ), | ||||||
|                     ) |                     ) | ||||||
|  | @ -114,7 +114,7 @@ class FakeToolCallModelImage: | ||||||
| def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: | def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I should multiply 2 by 3.6452. special_marker | Thought: I should multiply 2 by 3.6452. special_marker | ||||||
|  | @ -125,7 +125,7 @@ result = 2**3.6452 | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|     else:  # We're at step 2 |     else:  # We're at step 2 | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I can now answer the initial question | Thought: I can now answer the initial question | ||||||
|  | @ -140,7 +140,7 @@ final_answer(7.2904) | ||||||
| def fake_code_model_error(messages, stop_sequences=None) -> str: | def fake_code_model_error(messages, stop_sequences=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I should multiply 2 by 3.6452. special_marker | Thought: I should multiply 2 by 3.6452. special_marker | ||||||
|  | @ -154,7 +154,7 @@ print("Ok, calculation done!") | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|     else:  # We're at step 2 |     else:  # We're at step 2 | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I can now answer the initial question | Thought: I can now answer the initial question | ||||||
|  | @ -169,7 +169,7 @@ final_answer("got an error") | ||||||
| def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: | def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I should multiply 2 by 3.6452. special_marker | Thought: I should multiply 2 by 3.6452. special_marker | ||||||
|  | @ -183,7 +183,7 @@ print("Ok, calculation done!") | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|     else:  # We're at step 2 |     else:  # We're at step 2 | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I can now answer the initial question | Thought: I can now answer the initial question | ||||||
|  | @ -196,7 +196,7 @@ final_answer("got an error") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_code_model_import(messages, stop_sequences=None) -> str: | def fake_code_model_import(messages, stop_sequences=None) -> str: | ||||||
|     return ChatCompletionOutputMessage( |     return ChatMessage( | ||||||
|         role="assistant", |         role="assistant", | ||||||
|         content=""" |         content=""" | ||||||
| Thought: I can answer the question | Thought: I can answer the question | ||||||
|  | @ -212,7 +212,7 @@ final_answer("got an error") | ||||||
| def fake_code_functiondef(messages, stop_sequences=None) -> str: | def fake_code_functiondef(messages, stop_sequences=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: Let's define the function. special_marker | Thought: Let's define the function. special_marker | ||||||
|  | @ -226,7 +226,7 @@ def moving_average(x, w): | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|     else:  # We're at step 2 |     else:  # We're at step 2 | ||||||
|         return ChatCompletionOutputMessage( |         return ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|             content=""" |             content=""" | ||||||
| Thought: I can now answer the initial question | Thought: I can now answer the initial question | ||||||
|  | @ -241,7 +241,7 @@ final_answer(res) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str: | def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     return ChatCompletionOutputMessage( |     return ChatMessage( | ||||||
|         role="assistant", |         role="assistant", | ||||||
|         content=""" |         content=""" | ||||||
| Thought: I should multiply 2 by 3.6452. special_marker | Thought: I should multiply 2 by 3.6452. special_marker | ||||||
|  | @ -255,7 +255,7 @@ final_answer(result) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str: | def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     return ChatCompletionOutputMessage( |     return ChatMessage( | ||||||
|         role="assistant", |         role="assistant", | ||||||
|         content=""" |         content=""" | ||||||
| Thought: I should multiply 2 by 3.6452. special_marker | Thought: I should multiply 2 by 3.6452. special_marker | ||||||
|  | @ -454,14 +454,14 @@ class AgentTests(unittest.TestCase): | ||||||
|             ): |             ): | ||||||
|                 if tools_to_call_from is not None: |                 if tools_to_call_from is not None: | ||||||
|                     if len(messages) < 3: |                     if len(messages) < 3: | ||||||
|                         return ChatCompletionOutputMessage( |                         return ChatMessage( | ||||||
|                             role="assistant", |                             role="assistant", | ||||||
|                             content="", |                             content="", | ||||||
|                             tool_calls=[ |                             tool_calls=[ | ||||||
|                                 ChatCompletionOutputToolCall( |                                 ChatMessageToolCall( | ||||||
|                                     id="call_0", |                                     id="call_0", | ||||||
|                                     type="function", |                                     type="function", | ||||||
|                                     function=ChatCompletionOutputFunctionDefinition( |                                     function=ChatMessageToolCallDefinition( | ||||||
|                                         name="search_agent", |                                         name="search_agent", | ||||||
|                                         arguments="Who is the current US president?", |                                         arguments="Who is the current US president?", | ||||||
|                                     ), |                                     ), | ||||||
|  | @ -470,14 +470,14 @@ class AgentTests(unittest.TestCase): | ||||||
|                         ) |                         ) | ||||||
|                     else: |                     else: | ||||||
|                         assert "Report on the current US president" in str(messages) |                         assert "Report on the current US president" in str(messages) | ||||||
|                         return ChatCompletionOutputMessage( |                         return ChatMessage( | ||||||
|                             role="assistant", |                             role="assistant", | ||||||
|                             content="", |                             content="", | ||||||
|                             tool_calls=[ |                             tool_calls=[ | ||||||
|                                 ChatCompletionOutputToolCall( |                                 ChatMessageToolCall( | ||||||
|                                     id="call_0", |                                     id="call_0", | ||||||
|                                     type="function", |                                     type="function", | ||||||
|                                     function=ChatCompletionOutputFunctionDefinition( |                                     function=ChatMessageToolCallDefinition( | ||||||
|                                         name="final_answer", arguments="Final report." |                                         name="final_answer", arguments="Final report." | ||||||
|                                     ), |                                     ), | ||||||
|                                 ) |                                 ) | ||||||
|  | @ -485,7 +485,7 @@ class AgentTests(unittest.TestCase): | ||||||
|                         ) |                         ) | ||||||
|                 else: |                 else: | ||||||
|                     if len(messages) < 3: |                     if len(messages) < 3: | ||||||
|                         return ChatCompletionOutputMessage( |                         return ChatMessage( | ||||||
|                             role="assistant", |                             role="assistant", | ||||||
|                             content=""" |                             content=""" | ||||||
| Thought: Let's call our search agent. | Thought: Let's call our search agent. | ||||||
|  | @ -497,7 +497,7 @@ result = search_agent("Who is the current US president?") | ||||||
|                         ) |                         ) | ||||||
|                     else: |                     else: | ||||||
|                         assert "Report on the current US president" in str(messages) |                         assert "Report on the current US president" in str(messages) | ||||||
|                         return ChatCompletionOutputMessage( |                         return ChatMessage( | ||||||
|                             role="assistant", |                             role="assistant", | ||||||
|                             content=""" |                             content=""" | ||||||
| Thought: Let's return the report. | Thought: Let's return the report. | ||||||
|  | @ -518,14 +518,14 @@ final_answer("Final report.") | ||||||
|                 stop_sequences=None, |                 stop_sequences=None, | ||||||
|                 grammar=None, |                 grammar=None, | ||||||
|             ): |             ): | ||||||
|                 return ChatCompletionOutputMessage( |                 return ChatMessage( | ||||||
|                     role="assistant", |                     role="assistant", | ||||||
|                     content="", |                     content="", | ||||||
|                     tool_calls=[ |                     tool_calls=[ | ||||||
|                         ChatCompletionOutputToolCall( |                         ChatMessageToolCall( | ||||||
|                             id="call_0", |                             id="call_0", | ||||||
|                             type="function", |                             type="function", | ||||||
|                             function=ChatCompletionOutputFunctionDefinition( |                             function=ChatMessageToolCallDefinition( | ||||||
|                                 name="final_answer", |                                 name="final_answer", | ||||||
|                                 arguments="Report on the current US president", |                                 arguments="Report on the current US president", | ||||||
|                             ), |                             ), | ||||||
|  | @ -568,7 +568,7 @@ final_answer("Final report.") | ||||||
| 
 | 
 | ||||||
|     def test_code_nontrivial_final_answer_works(self): |     def test_code_nontrivial_final_answer_works(self): | ||||||
|         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): |         def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="""Code: |                 content="""Code: | ||||||
| ```py | ```py | ||||||
|  |  | ||||||
|  | @ -0,0 +1,83 @@ | ||||||
|  | # coding=utf-8 | ||||||
|  | # Copyright 2024 HuggingFace Inc. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | import unittest | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool | ||||||
|  | from smolagents.types import AGENT_TYPE_MAPPING | ||||||
|  | 
 | ||||||
|  | from .test_tools import ToolTesterMixin | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DefaultToolTests(unittest.TestCase): | ||||||
|  |     def test_visit_webpage(self): | ||||||
|  |         arguments = { | ||||||
|  |             "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security" | ||||||
|  |         } | ||||||
|  |         result = VisitWebpageTool()(arguments) | ||||||
|  |         assert isinstance(result, str) | ||||||
|  |         assert ( | ||||||
|  |             "* [About Wikipedia](/wiki/Wikipedia:About)" in result | ||||||
|  |         )  # Proper wikipedia pages have an About | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|  |     def setUp(self): | ||||||
|  |         self.tool = PythonInterpreterTool(authorized_imports=["numpy"]) | ||||||
|  |         self.tool.setup() | ||||||
|  | 
 | ||||||
|  |     def test_exact_match_arg(self): | ||||||
|  |         result = self.tool("(2 / 2) * 4") | ||||||
|  |         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||||
|  | 
 | ||||||
|  |     def test_exact_match_kwarg(self): | ||||||
|  |         result = self.tool(code="(2 / 2) * 4") | ||||||
|  |         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") | ||||||
|  | 
 | ||||||
|  |     def test_agent_type_output(self): | ||||||
|  |         inputs = ["2 * 2"] | ||||||
|  |         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||||
|  |         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|  |         self.assertTrue(isinstance(output, output_type)) | ||||||
|  | 
 | ||||||
|  |     def test_agent_types_inputs(self): | ||||||
|  |         inputs = ["2 * 2"] | ||||||
|  |         _inputs = [] | ||||||
|  | 
 | ||||||
|  |         for _input, expected_input in zip(inputs, self.tool.inputs.values()): | ||||||
|  |             input_type = expected_input["type"] | ||||||
|  |             if isinstance(input_type, list): | ||||||
|  |                 _inputs.append( | ||||||
|  |                     [ | ||||||
|  |                         AGENT_TYPE_MAPPING[_input_type](_input) | ||||||
|  |                         for _input_type in input_type | ||||||
|  |                     ] | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) | ||||||
|  | 
 | ||||||
|  |         # Should not raise an error | ||||||
|  |         output = self.tool(*inputs, sanitize_inputs_outputs=True) | ||||||
|  |         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] | ||||||
|  |         self.assertTrue(isinstance(output, output_type)) | ||||||
|  | 
 | ||||||
|  |     def test_imports_work(self): | ||||||
|  |         result = self.tool("import numpy as np") | ||||||
|  |         assert "import from numpy is not allowed" not in result.lower() | ||||||
|  | 
 | ||||||
|  |     def test_unauthorized_imports_fail(self): | ||||||
|  |         with pytest.raises(Exception) as e: | ||||||
|  |             self.tool("import sympy as sp") | ||||||
|  |         assert "sympy" in str(e).lower() | ||||||
|  | @ -23,9 +23,9 @@ from smolagents import ( | ||||||
|     stream_to_gradio, |     stream_to_gradio, | ||||||
| ) | ) | ||||||
| from huggingface_hub import ( | from huggingface_hub import ( | ||||||
|     ChatCompletionOutputMessage, |     ChatMessage, | ||||||
|     ChatCompletionOutputToolCall, |     ChatMessageToolCall, | ||||||
|     ChatCompletionOutputFunctionDefinition, |     ChatMessageToolCallDefinition, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -36,21 +36,21 @@ class FakeLLMModel: | ||||||
| 
 | 
 | ||||||
|     def __call__(self, prompt, tools_to_call_from=None, **kwargs): |     def __call__(self, prompt, tools_to_call_from=None, **kwargs): | ||||||
|         if tools_to_call_from is not None: |         if tools_to_call_from is not None: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content="", |                 content="", | ||||||
|                 tool_calls=[ |                 tool_calls=[ | ||||||
|                     ChatCompletionOutputToolCall( |                     ChatMessageToolCall( | ||||||
|                         id="fake_id", |                         id="fake_id", | ||||||
|                         type="function", |                         type="function", | ||||||
|                         function=ChatCompletionOutputFunctionDefinition( |                         function=ChatMessageToolCallDefinition( | ||||||
|                             name="final_answer", arguments={"answer": "image"} |                             name="final_answer", arguments={"answer": "image"} | ||||||
|                         ), |                         ), | ||||||
|                     ) |                     ) | ||||||
|                 ], |                 ], | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             return ChatCompletionOutputMessage( |             return ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 content=""" |                 content=""" | ||||||
| Code: | Code: | ||||||
|  | @ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase): | ||||||
|                 self.last_output_token_count = 20 |                 self.last_output_token_count = 20 | ||||||
| 
 | 
 | ||||||
|             def __call__(self, prompt, **kwargs): |             def __call__(self, prompt, **kwargs): | ||||||
|                 return ChatCompletionOutputMessage( |                 return ChatMessage(role="assistant", content="Malformed answer") | ||||||
|                     role="assistant", content="Malformed answer" |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|  |  | ||||||
|  | @ -18,15 +18,12 @@ import unittest | ||||||
| import numpy as np | import numpy as np | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool | from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||||
| from smolagents.local_python_executor import ( | from smolagents.local_python_executor import ( | ||||||
|     InterpreterError, |     InterpreterError, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
|     fix_final_answer_code, |     fix_final_answer_code, | ||||||
| ) | ) | ||||||
| from smolagents.types import AGENT_TYPE_MAPPING |  | ||||||
| 
 |  | ||||||
| from .test_tools import ToolTesterMixin |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Fake function we will use as tool | # Fake function we will use as tool | ||||||
|  | @ -34,47 +31,6 @@ def add_two(x): | ||||||
|     return x + 2 |     return x + 2 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): |  | ||||||
|     def setUp(self): |  | ||||||
|         self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"]) |  | ||||||
|         self.tool.setup() |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_arg(self): |  | ||||||
|         result = self.tool("(2 / 2) * 4") |  | ||||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_kwarg(self): |  | ||||||
|         result = self.tool(code="(2 / 2) * 4") |  | ||||||
|         self.assertEqual(result, "Stdout:\n\nOutput: 4.0") |  | ||||||
| 
 |  | ||||||
|     def test_agent_type_output(self): |  | ||||||
|         inputs = ["2 * 2"] |  | ||||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) |  | ||||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] |  | ||||||
|         self.assertTrue(isinstance(output, output_type)) |  | ||||||
| 
 |  | ||||||
|     def test_agent_types_inputs(self): |  | ||||||
|         inputs = ["2 * 2"] |  | ||||||
|         _inputs = [] |  | ||||||
| 
 |  | ||||||
|         for _input, expected_input in zip(inputs, self.tool.inputs.values()): |  | ||||||
|             input_type = expected_input["type"] |  | ||||||
|             if isinstance(input_type, list): |  | ||||||
|                 _inputs.append( |  | ||||||
|                     [ |  | ||||||
|                         AGENT_TYPE_MAPPING[_input_type](_input) |  | ||||||
|                         for _input_type in input_type |  | ||||||
|                     ] |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) |  | ||||||
| 
 |  | ||||||
|         # Should not raise an error |  | ||||||
|         output = self.tool(*inputs, sanitize_inputs_outputs=True) |  | ||||||
|         output_type = AGENT_TYPE_MAPPING[self.tool.output_type] |  | ||||||
|         self.assertTrue(isinstance(output, output_type)) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class PythonInterpreterTester(unittest.TestCase): | class PythonInterpreterTester(unittest.TestCase): | ||||||
|     def test_evaluate_assign(self): |     def test_evaluate_assign(self): | ||||||
|         code = "x = 3" |         code = "x = 3" | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue