Make default tools more robust (#186)
This commit is contained in:
parent
12a2e6f4b4
commit
5f32373551
|
@ -36,6 +36,9 @@ jobs:
|
||||||
- name: Agent tests
|
- name: Agent tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_agents.py
|
uv run pytest -sv ./tests/test_agents.py
|
||||||
|
- name: Default tools tests
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv ./tests/test_default_tools.py
|
||||||
- name: Final answer tests
|
- name: Final answer tests
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -sv ./tests/test_final_answer.py
|
uv run pytest -sv ./tests/test_final_answer.py
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -29,8 +29,7 @@
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||||
"Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -173,7 +172,7 @@
|
||||||
"[132 rows x 4 columns]"
|
"[132 rows x 4 columns]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
|
@ -196,19 +195,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
|
|
||||||
"* 'fields' has been removed\n",
|
|
||||||
" warnings.warn(message, UserWarning)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import time\n",
|
"import time\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
@ -408,100 +397,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"open_model_ids = [\n",
|
"open_model_ids = [\n",
|
||||||
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
|
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
|
||||||
|
@ -554,42 +452,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'gpt-4o'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
|
|
||||||
"100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from smolagents import LiteLLMModel\n",
|
"from smolagents import LiteLLMModel\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -614,7 +479,7 @@
|
||||||
" agent = CodeAgent(\n",
|
" agent = CodeAgent(\n",
|
||||||
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
|
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
|
||||||
" model=LiteLLMModel(model_id),\n",
|
" model=LiteLLMModel(model_id),\n",
|
||||||
" additional_authorized_imports=[\"numpy\"],\n",
|
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
||||||
" max_steps=10,\n",
|
" max_steps=10,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
||||||
|
@ -631,34 +496,39 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# import glob\n",
|
"# import glob\n",
|
||||||
"# import json\n",
|
"# import json\n",
|
||||||
|
"\n",
|
||||||
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
|
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# for file_path in jsonl_files:\n",
|
"# for file_path in jsonl_files:\n",
|
||||||
"# print(file_path)\n",
|
"# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
|
||||||
"# # Read all lines and filter out SimpleQA sources\n",
|
"# print(file_path)\n",
|
||||||
"# filtered_lines = []\n",
|
"# # Read all lines and filter out SimpleQA sources\n",
|
||||||
"# removed = 0\n",
|
"# filtered_lines = []\n",
|
||||||
"# with open(file_path, 'r', encoding='utf-8') as f:\n",
|
"# removed = 0\n",
|
||||||
"# for line in f:\n",
|
"# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
|
||||||
"# try:\n",
|
"# for line in f:\n",
|
||||||
"# data = json.loads(line.strip())\n",
|
"# try:\n",
|
||||||
"# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
|
"# data = json.loads(line.strip())\n",
|
||||||
"# removed +=1\n",
|
"# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
|
||||||
"# else:\n",
|
"# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
|
||||||
"# filtered_lines.append(line)\n",
|
"# # removed +=1\n",
|
||||||
"# except json.JSONDecodeError:\n",
|
"# # else:\n",
|
||||||
"# print(\"Invalid line:\", line)\n",
|
"# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
|
||||||
"# continue # Skip invalid JSON lines\n",
|
"# except json.JSONDecodeError:\n",
|
||||||
"# print(f\"Removed {removed} lines.\")\n",
|
"# print(\"Invalid line:\", line)\n",
|
||||||
"# # Write filtered content back to the same file\n",
|
"# continue # Skip invalid JSON lines\n",
|
||||||
"# with open(file_path, 'w', encoding='utf-8') as f:\n",
|
"# print(f\"Removed {removed} lines.\")\n",
|
||||||
"# f.writelines(filtered_lines)"
|
"# # Write filtered content back to the same file\n",
|
||||||
|
"# with open(\n",
|
||||||
|
"# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
|
||||||
|
"# ) as f:\n",
|
||||||
|
"# f.writelines(filtered_lines)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -670,14 +540,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
|
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
" warnings.warn(\n"
|
" warnings.warn(\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -731,7 +601,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -752,7 +622,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -794,28 +664,28 @@
|
||||||
" <th>1</th>\n",
|
" <th>1</th>\n",
|
||||||
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>74.0</td>\n",
|
" <td>76.0</td>\n",
|
||||||
" <td>30.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>2</th>\n",
|
" <th>2</th>\n",
|
||||||
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>70.0</td>\n",
|
" <td>88.0</td>\n",
|
||||||
" <td>10.0</td>\n",
|
" <td>10.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>3</th>\n",
|
" <th>3</th>\n",
|
||||||
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>18.8</td>\n",
|
" <td>25.0</td>\n",
|
||||||
" <td>3.1</td>\n",
|
" <td>3.1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>4</th>\n",
|
" <th>4</th>\n",
|
||||||
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>76.0</td>\n",
|
" <td>86.0</td>\n",
|
||||||
" <td>60.0</td>\n",
|
" <td>60.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
|
@ -829,63 +699,63 @@
|
||||||
" <th>6</th>\n",
|
" <th>6</th>\n",
|
||||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>40.6</td>\n",
|
" <td>NaN</td>\n",
|
||||||
" <td>3.1</td>\n",
|
" <td>3.1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>7</th>\n",
|
" <th>7</th>\n",
|
||||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>67.0</td>\n",
|
" <td>NaN</td>\n",
|
||||||
" <td>50.0</td>\n",
|
" <td>50.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>8</th>\n",
|
" <th>8</th>\n",
|
||||||
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>90.0</td>\n",
|
" <td>NaN</td>\n",
|
||||||
" <td>34.0</td>\n",
|
" <td>34.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>9</th>\n",
|
" <th>9</th>\n",
|
||||||
" <td>gpt-4o</td>\n",
|
" <td>gpt-4o</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>28.1</td>\n",
|
" <td>25.6</td>\n",
|
||||||
" <td>3.1</td>\n",
|
" <td>3.1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>10</th>\n",
|
" <th>10</th>\n",
|
||||||
" <td>gpt-4o</td>\n",
|
" <td>gpt-4o</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>70.0</td>\n",
|
" <td>58.0</td>\n",
|
||||||
" <td>40.0</td>\n",
|
" <td>40.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>11</th>\n",
|
" <th>11</th>\n",
|
||||||
" <td>gpt-4o</td>\n",
|
" <td>gpt-4o</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>88.0</td>\n",
|
" <td>86.0</td>\n",
|
||||||
" <td>6.0</td>\n",
|
" <td>6.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>12</th>\n",
|
" <th>12</th>\n",
|
||||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>3.1</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>13</th>\n",
|
" <th>13</th>\n",
|
||||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>42.0</td>\n",
|
" <td>14.0</td>\n",
|
||||||
" <td>18.0</td>\n",
|
" <td>18.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>14</th>\n",
|
" <th>14</th>\n",
|
||||||
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>54.0</td>\n",
|
" <td>2.0</td>\n",
|
||||||
" <td>6.0</td>\n",
|
" <td>6.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
|
@ -899,49 +769,49 @@
|
||||||
" <th>16</th>\n",
|
" <th>16</th>\n",
|
||||||
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>32.0</td>\n",
|
" <td>40.0</td>\n",
|
||||||
" <td>12.0</td>\n",
|
" <td>12.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>17</th>\n",
|
" <th>17</th>\n",
|
||||||
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>4.0</td>\n",
|
" <td>20.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>18</th>\n",
|
" <th>18</th>\n",
|
||||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>34.4</td>\n",
|
" <td>31.2</td>\n",
|
||||||
" <td>3.1</td>\n",
|
" <td>3.1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>19</th>\n",
|
" <th>19</th>\n",
|
||||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>82.0</td>\n",
|
" <td>72.0</td>\n",
|
||||||
" <td>40.0</td>\n",
|
" <td>40.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>20</th>\n",
|
" <th>20</th>\n",
|
||||||
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>84.0</td>\n",
|
" <td>78.0</td>\n",
|
||||||
" <td>12.0</td>\n",
|
" <td>12.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>21</th>\n",
|
" <th>21</th>\n",
|
||||||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||||
" <td>GAIA</td>\n",
|
" <td>GAIA</td>\n",
|
||||||
" <td>3.1</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
|
" <td>3.1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>22</th>\n",
|
" <th>22</th>\n",
|
||||||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||||
" <td>MATH</td>\n",
|
" <td>MATH</td>\n",
|
||||||
" <td>20.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" <td>22.0</td>\n",
|
" <td>22.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
|
@ -949,7 +819,7 @@
|
||||||
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
|
||||||
" <td>SimpleQA</td>\n",
|
" <td>SimpleQA</td>\n",
|
||||||
" <td>30.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>6.0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" </tbody>\n",
|
" </tbody>\n",
|
||||||
"</table>\n",
|
"</table>\n",
|
||||||
|
@ -958,29 +828,29 @@
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"action_type model_id source code vanilla\n",
|
"action_type model_id source code vanilla\n",
|
||||||
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
|
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
|
||||||
"1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
|
"1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
|
||||||
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
|
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
|
||||||
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
|
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
|
||||||
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
|
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
|
||||||
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
|
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
|
||||||
"6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
|
"6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
|
||||||
"7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
|
"7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
|
||||||
"8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
|
"8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
|
||||||
"9 gpt-4o GAIA 28.1 3.1\n",
|
"9 gpt-4o GAIA 25.6 3.1\n",
|
||||||
"10 gpt-4o MATH 70.0 40.0\n",
|
"10 gpt-4o MATH 58.0 40.0\n",
|
||||||
"11 gpt-4o SimpleQA 88.0 6.0\n",
|
"11 gpt-4o SimpleQA 86.0 6.0\n",
|
||||||
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
|
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
|
||||||
"13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
|
"13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
|
||||||
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
|
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
|
||||||
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
|
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
|
||||||
"16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
|
"16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
|
||||||
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
|
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
|
||||||
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
|
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
|
||||||
"19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
|
"19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
|
||||||
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
|
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
|
||||||
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
|
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
|
||||||
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
|
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
|
||||||
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
|
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -1005,6 +875,15 @@
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "display_data"
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "",
|
||||||
|
"evalue": "",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
||||||
|
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tools_to_call_from=list(self.tools.values()),
|
tools_to_call_from=list(self.tools.values()),
|
||||||
stop_sequences=["Observation:"],
|
stop_sequences=["Observation:"],
|
||||||
)
|
)
|
||||||
|
tool_call = model_message.tool_calls[0]
|
||||||
# Extract tool call from model output
|
tool_name, tool_call_id = tool_call.function.name, tool_call.id
|
||||||
if (
|
tool_arguments = tool_call.function.arguments
|
||||||
type(model_message.tool_calls) is list
|
|
||||||
and len(model_message.tool_calls) > 0
|
|
||||||
):
|
|
||||||
tool_calls = model_message.tool_calls[0]
|
|
||||||
tool_arguments = tool_calls.function.arguments
|
|
||||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
|
||||||
else:
|
|
||||||
start, end = (
|
|
||||||
model_message.content.find("{"),
|
|
||||||
model_message.content.rfind("}") + 1,
|
|
||||||
)
|
|
||||||
tool_calls = json.loads(model_message.content[start:end])
|
|
||||||
tool_arguments = tool_calls["tool_arguments"]
|
|
||||||
tool_name, tool_call_id = (
|
|
||||||
tool_calls["tool_name"],
|
|
||||||
f"call_{len(self.logs)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(
|
raise AgentGenerationError(
|
||||||
|
@ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
updated_information = f"Stored '{observation_name}' in memory."
|
updated_information = f"Stored '{observation_name}' in memory."
|
||||||
else:
|
else:
|
||||||
updated_information = str(observation).strip()
|
updated_information = str(observation).strip()
|
||||||
self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
|
self.logger.log(
|
||||||
|
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
|
||||||
|
level=LogLevel.INFO,
|
||||||
|
)
|
||||||
log_entry.observations = updated_information
|
log_entry.observations = updated_information
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ from .local_python_executor import (
|
||||||
)
|
)
|
||||||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||||
from .types import AgentAudio
|
from .types import AgentAudio
|
||||||
|
from .utils import truncate_content
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.models.whisper import (
|
from transformers.models.whisper import (
|
||||||
|
@ -112,18 +113,15 @@ class PythonInterpreterTool(Tool):
|
||||||
|
|
||||||
def forward(self, code: str) -> str:
|
def forward(self, code: str) -> str:
|
||||||
state = {}
|
state = {}
|
||||||
try:
|
output = str(
|
||||||
output = str(
|
self.python_evaluator(
|
||||||
self.python_evaluator(
|
code,
|
||||||
code,
|
state=state,
|
||||||
state=state,
|
static_tools=self.base_python_tools,
|
||||||
static_tools=self.base_python_tools,
|
authorized_imports=self.authorized_imports,
|
||||||
authorized_imports=self.authorized_imports,
|
)[0] # The second element is boolean is_final_answer
|
||||||
)[0] # The second element is boolean is_final_answer
|
)
|
||||||
)
|
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
||||||
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
class FinalAnswerTool(Tool):
|
class FinalAnswerTool(Tool):
|
||||||
|
@ -295,7 +293,7 @@ class VisitWebpageTool(Tool):
|
||||||
# Remove multiple line breaks
|
# Remove multiple line breaks
|
||||||
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
|
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
|
||||||
|
|
||||||
return markdown_content
|
return truncate_content(markdown_content)
|
||||||
|
|
||||||
except RequestException as e:
|
except RequestException as e:
|
||||||
return f"Error fetching the webpage: {str(e)}"
|
return f"Error fetching the webpage: {str(e)}"
|
||||||
|
|
|
@ -14,20 +14,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Union, Any
|
||||||
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import InferenceClient
|
||||||
InferenceClient,
|
|
||||||
ChatCompletionOutputMessage,
|
|
||||||
ChatCompletionOutputToolCall,
|
|
||||||
ChatCompletionOutputFunctionDefinition,
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
@ -58,6 +54,27 @@ if _is_package_available("litellm"):
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessageToolCallDefinition:
|
||||||
|
arguments: Any
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessageToolCall:
|
||||||
|
function: ChatMessageToolCallDefinition
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
role: str
|
||||||
|
content: Optional[str] = None
|
||||||
|
tool_calls: Optional[List[ChatMessageToolCall]] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
class MessageRole(str, Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
|
@ -140,6 +157,17 @@ def get_clean_message_list(
|
||||||
return final_message_list
|
return final_message_list
|
||||||
|
|
||||||
|
|
||||||
|
def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
|
||||||
|
try:
|
||||||
|
start, end = (
|
||||||
|
possible_dictionary.find("{"),
|
||||||
|
possible_dictionary.rfind("}") + 1,
|
||||||
|
)
|
||||||
|
return json.loads(possible_dictionary[start:end])
|
||||||
|
except Exception:
|
||||||
|
return possible_dictionary
|
||||||
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.last_input_token_count = None
|
self.last_input_token_count = None
|
||||||
|
@ -157,7 +185,7 @@ class Model:
|
||||||
stop_sequences: Optional[List[str]] = None,
|
stop_sequences: Optional[List[str]] = None,
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
) -> ChatCompletionOutputMessage:
|
) -> ChatMessage:
|
||||||
"""Process the input messages and return the model's response.
|
"""Process the input messages and return the model's response.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -228,7 +256,7 @@ class HfApiModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatCompletionOutputMessage:
|
) -> ChatMessage:
|
||||||
"""
|
"""
|
||||||
Gets an LLM output message for the given list of input messages.
|
Gets an LLM output message for the given list of input messages.
|
||||||
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
|
||||||
|
@ -329,7 +357,7 @@ class TransformersModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatCompletionOutputMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
@ -365,21 +393,21 @@ class TransformersModel(Model):
|
||||||
if stop_sequences is not None:
|
if stop_sequences is not None:
|
||||||
output = remove_stop_sequences(output, stop_sequences)
|
output = remove_stop_sequences(output, stop_sequences)
|
||||||
if tools_to_call_from is None:
|
if tools_to_call_from is None:
|
||||||
return ChatCompletionOutputMessage(role="assistant", content=output)
|
return ChatMessage(role="assistant", content=output)
|
||||||
else:
|
else:
|
||||||
if "Action:" in output:
|
if "Action:" in output:
|
||||||
output = output.split("Action:", 1)[1].strip()
|
output = output.split("Action:", 1)[1].strip()
|
||||||
parsed_output = json.loads(output)
|
parsed_output = json.loads(output)
|
||||||
tool_name = parsed_output.get("tool_name")
|
tool_name = parsed_output.get("tool_name")
|
||||||
tool_arguments = parsed_output.get("tool_arguments")
|
tool_arguments = parsed_output.get("tool_arguments")
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="".join(random.choices("0123456789", k=5)),
|
id="".join(random.choices("0123456789", k=5)),
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name=tool_name, arguments=tool_arguments
|
name=tool_name, arguments=tool_arguments
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -414,7 +442,7 @@ class LiteLLMModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatCompletionOutputMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
@ -485,7 +513,7 @@ class OpenAIServerModel(Model):
|
||||||
grammar: Optional[str] = None,
|
grammar: Optional[str] = None,
|
||||||
max_tokens: int = 1500,
|
max_tokens: int = 1500,
|
||||||
tools_to_call_from: Optional[List[Tool]] = None,
|
tools_to_call_from: Optional[List[Tool]] = None,
|
||||||
) -> ChatCompletionOutputMessage:
|
) -> ChatMessage:
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
|
|
@ -221,6 +221,16 @@ class Tool:
|
||||||
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
|
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
|
||||||
if not self.is_initialized:
|
if not self.is_initialized:
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|
||||||
|
# Handle the case where the arguments are passed as a single dictionary
|
||||||
|
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
|
||||||
|
potential_kwargs = args[0]
|
||||||
|
|
||||||
|
# If the dictionary keys match our input parameters, convert it to kwargs
|
||||||
|
if all(key in self.inputs for key in potential_kwargs):
|
||||||
|
args = ()
|
||||||
|
kwargs = potential_kwargs
|
||||||
|
|
||||||
if sanitize_inputs_outputs:
|
if sanitize_inputs_outputs:
|
||||||
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
||||||
outputs = self.forward(*args, **kwargs)
|
outputs = self.forward(*args, **kwargs)
|
||||||
|
|
|
@ -30,10 +30,10 @@ from smolagents.agents import (
|
||||||
from smolagents.default_tools import PythonInterpreterTool
|
from smolagents.default_tools import PythonInterpreterTool
|
||||||
from smolagents.tools import tool
|
from smolagents.tools import tool
|
||||||
from smolagents.types import AgentImage, AgentText
|
from smolagents.types import AgentImage, AgentText
|
||||||
from huggingface_hub import (
|
from smolagents.models import (
|
||||||
ChatCompletionOutputMessage,
|
ChatMessage,
|
||||||
ChatCompletionOutputToolCall,
|
ChatMessageToolCall,
|
||||||
ChatCompletionOutputFunctionDefinition,
|
ChatMessageToolCallDefinition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,28 +47,28 @@ class FakeToolCallModel:
|
||||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||||
):
|
):
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_0",
|
id="call_0",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="python_interpreter", arguments={"code": "2*3.6452"}
|
name="python_interpreter", arguments={"code": "2*3.6452"}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_1",
|
id="call_1",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="final_answer", arguments={"answer": "7.2904"}
|
name="final_answer", arguments={"answer": "7.2904"}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -81,14 +81,14 @@ class FakeToolCallModelImage:
|
||||||
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
|
||||||
):
|
):
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_0",
|
id="call_0",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="fake_image_generation_tool",
|
name="fake_image_generation_tool",
|
||||||
arguments={"prompt": "An image of a cat"},
|
arguments={"prompt": "An image of a cat"},
|
||||||
),
|
),
|
||||||
|
@ -96,14 +96,14 @@ class FakeToolCallModelImage:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_1",
|
id="call_1",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="final_answer", arguments="image.png"
|
name="final_answer", arguments="image.png"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -114,7 +114,7 @@ class FakeToolCallModelImage:
|
||||||
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
|
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
@ -125,7 +125,7 @@ result = 2**3.6452
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
else: # We're at step 2
|
else: # We're at step 2
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
|
@ -140,7 +140,7 @@ final_answer(7.2904)
|
||||||
def fake_code_model_error(messages, stop_sequences=None) -> str:
|
def fake_code_model_error(messages, stop_sequences=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
@ -154,7 +154,7 @@ print("Ok, calculation done!")
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
else: # We're at step 2
|
else: # We're at step 2
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
|
@ -169,7 +169,7 @@ final_answer("got an error")
|
||||||
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
|
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
@ -183,7 +183,7 @@ print("Ok, calculation done!")
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
else: # We're at step 2
|
else: # We're at step 2
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
|
@ -196,7 +196,7 @@ final_answer("got an error")
|
||||||
|
|
||||||
|
|
||||||
def fake_code_model_import(messages, stop_sequences=None) -> str:
|
def fake_code_model_import(messages, stop_sequences=None) -> str:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I can answer the question
|
Thought: I can answer the question
|
||||||
|
@ -212,7 +212,7 @@ final_answer("got an error")
|
||||||
def fake_code_functiondef(messages, stop_sequences=None) -> str:
|
def fake_code_functiondef(messages, stop_sequences=None) -> str:
|
||||||
prompt = str(messages)
|
prompt = str(messages)
|
||||||
if "special_marker" not in prompt:
|
if "special_marker" not in prompt:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: Let's define the function. special_marker
|
Thought: Let's define the function. special_marker
|
||||||
|
@ -226,7 +226,7 @@ def moving_average(x, w):
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
else: # We're at step 2
|
else: # We're at step 2
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I can now answer the initial question
|
Thought: I can now answer the initial question
|
||||||
|
@ -241,7 +241,7 @@ final_answer(res)
|
||||||
|
|
||||||
|
|
||||||
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
|
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
@ -255,7 +255,7 @@ final_answer(result)
|
||||||
|
|
||||||
|
|
||||||
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
@ -454,14 +454,14 @@ class AgentTests(unittest.TestCase):
|
||||||
):
|
):
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_0",
|
id="call_0",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="search_agent",
|
name="search_agent",
|
||||||
arguments="Who is the current US president?",
|
arguments="Who is the current US president?",
|
||||||
),
|
),
|
||||||
|
@ -470,14 +470,14 @@ class AgentTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert "Report on the current US president" in str(messages)
|
assert "Report on the current US president" in str(messages)
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_0",
|
id="call_0",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="final_answer", arguments="Final report."
|
name="final_answer", arguments="Final report."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -485,7 +485,7 @@ class AgentTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if len(messages) < 3:
|
if len(messages) < 3:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: Let's call our search agent.
|
Thought: Let's call our search agent.
|
||||||
|
@ -497,7 +497,7 @@ result = search_agent("Who is the current US president?")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert "Report on the current US president" in str(messages)
|
assert "Report on the current US president" in str(messages)
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Thought: Let's return the report.
|
Thought: Let's return the report.
|
||||||
|
@ -518,14 +518,14 @@ final_answer("Final report.")
|
||||||
stop_sequences=None,
|
stop_sequences=None,
|
||||||
grammar=None,
|
grammar=None,
|
||||||
):
|
):
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="call_0",
|
id="call_0",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="final_answer",
|
name="final_answer",
|
||||||
arguments="Report on the current US president",
|
arguments="Report on the current US president",
|
||||||
),
|
),
|
||||||
|
@ -568,7 +568,7 @@ final_answer("Final report.")
|
||||||
|
|
||||||
def test_code_nontrivial_final_answer_works(self):
|
def test_code_nontrivial_final_answer_works(self):
|
||||||
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""Code:
|
content="""Code:
|
||||||
```py
|
```py
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
|
||||||
|
from smolagents.types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultToolTests(unittest.TestCase):
    """Checks for the default tools exposed by smolagents.default_tools."""

    def test_visit_webpage(self):
        """Call VisitWebpageTool with its arguments given as a single dict."""
        url = "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
        page_content = VisitWebpageTool()({"url": url})

        assert isinstance(page_content, str)
        # Proper wikipedia pages have an About
        about_link = "* [About Wikipedia](/wiki/Wikipedia:About)"
        assert about_link in page_content
|
||||||
|
|
||||||
|
|
||||||
|
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
    """Tests for PythonInterpreterTool with ``numpy`` as the only authorized import."""

    def setUp(self):
        # Build and set up a fresh tool before every test.
        self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
        self.tool.setup()

    def test_exact_match_arg(self):
        """The code can be passed as a positional argument."""
        result = self.tool("(2 / 2) * 4")
        self.assertEqual(result, "Stdout:\n\nOutput: 4.0")

    def test_exact_match_kwarg(self):
        """The code can be passed through the ``code`` keyword."""
        result = self.tool(code="(2 / 2) * 4")
        self.assertEqual(result, "Stdout:\n\nOutput: 4.0")

    def test_agent_type_output(self):
        """With sanitizing enabled, the output is an instance of the mapped agent type."""
        inputs = ["2 * 2"]
        output = self.tool(*inputs, sanitize_inputs_outputs=True)
        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
        self.assertIsInstance(output, output_type)

    def test_agent_types_inputs(self):
        """The tool accepts inputs already wrapped in their agent types."""
        inputs = ["2 * 2"]
        _inputs = []

        # Wrap every raw input in the agent type(s) declared for that parameter.
        for _input, expected_input in zip(inputs, self.tool.inputs.values()):
            input_type = expected_input["type"]
            if isinstance(input_type, list):
                _inputs.append(
                    [
                        AGENT_TYPE_MAPPING[_input_type](_input)
                        for _input_type in input_type
                    ]
                )
            else:
                _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))

        # Should not raise an error. The wrapped values are passed on purpose:
        # calling with the raw `inputs` would leave the conversion above unused.
        output = self.tool(*_inputs, sanitize_inputs_outputs=True)
        output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
        self.assertIsInstance(output, output_type)

    def test_imports_work(self):
        """Importing an authorized module does not produce the 'not allowed' message."""
        result = self.tool("import numpy as np")
        self.assertNotIn("import from numpy is not allowed", result.lower())

    def test_unauthorized_imports_fail(self):
        """Importing a module outside the authorized list raises, naming the module."""
        with pytest.raises(Exception) as excinfo:
            self.tool("import sympy as sp")
        # Checked after the `with` block, since statements following the
        # raising call inside it would never run. `excinfo.value` is the
        # raised exception; `excinfo` itself is only pytest's wrapper.
        self.assertIn("sympy", str(excinfo.value).lower())
|
|
@ -23,9 +23,9 @@ from smolagents import (
|
||||||
stream_to_gradio,
|
stream_to_gradio,
|
||||||
)
|
)
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
ChatCompletionOutputMessage,
|
ChatMessage,
|
||||||
ChatCompletionOutputToolCall,
|
ChatMessageToolCall,
|
||||||
ChatCompletionOutputFunctionDefinition,
|
ChatMessageToolCallDefinition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,21 +36,21 @@ class FakeLLMModel:
|
||||||
|
|
||||||
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
|
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ChatCompletionOutputToolCall(
|
ChatMessageToolCall(
|
||||||
id="fake_id",
|
id="fake_id",
|
||||||
type="function",
|
type="function",
|
||||||
function=ChatCompletionOutputFunctionDefinition(
|
function=ChatMessageToolCallDefinition(
|
||||||
name="final_answer", arguments={"answer": "image"}
|
name="final_answer", arguments={"answer": "image"}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="""
|
content="""
|
||||||
Code:
|
Code:
|
||||||
|
@ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase):
|
||||||
self.last_output_token_count = 20
|
self.last_output_token_count = 20
|
||||||
|
|
||||||
def __call__(self, prompt, **kwargs):
|
def __call__(self, prompt, **kwargs):
|
||||||
return ChatCompletionOutputMessage(
|
return ChatMessage(role="assistant", content="Malformed answer")
|
||||||
role="assistant", content="Malformed answer"
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
|
|
@ -18,15 +18,12 @@ import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
|
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from smolagents.local_python_executor import (
|
from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
fix_final_answer_code,
|
fix_final_answer_code,
|
||||||
)
|
)
|
||||||
from smolagents.types import AGENT_TYPE_MAPPING
|
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
# Fake function we will use as tool
|
# Fake function we will use as tool
|
||||||
|
@ -34,47 +31,6 @@ def add_two(x):
|
||||||
return x + 2
|
return x + 2
|
||||||
|
|
||||||
|
|
||||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
|
|
||||||
self.tool.setup()
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
result = self.tool("(2 / 2) * 4")
|
|
||||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
result = self.tool(code="(2 / 2) * 4")
|
|
||||||
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
|
|
||||||
|
|
||||||
def test_agent_type_output(self):
|
|
||||||
inputs = ["2 * 2"]
|
|
||||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
|
||||||
|
|
||||||
def test_agent_types_inputs(self):
|
|
||||||
inputs = ["2 * 2"]
|
|
||||||
_inputs = []
|
|
||||||
|
|
||||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
|
||||||
input_type = expected_input["type"]
|
|
||||||
if isinstance(input_type, list):
|
|
||||||
_inputs.append(
|
|
||||||
[
|
|
||||||
AGENT_TYPE_MAPPING[_input_type](_input)
|
|
||||||
for _input_type in input_type
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
|
||||||
|
|
||||||
# Should not raise an error
|
|
||||||
output = self.tool(*inputs, sanitize_inputs_outputs=True)
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
|
||||||
|
|
||||||
|
|
||||||
class PythonInterpreterTester(unittest.TestCase):
|
class PythonInterpreterTester(unittest.TestCase):
|
||||||
def test_evaluate_assign(self):
|
def test_evaluate_assign(self):
|
||||||
code = "x = 3"
|
code = "x = 3"
|
||||||
|
|
Loading…
Reference in New Issue