From 5f323735511f54168b688cdb0dee10ab5bdcd909 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Tue, 14 Jan 2025 14:57:11 +0100
Subject: [PATCH] Make default tools more robust (#186)
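
Make the default tools and the tool-calling path more robust:

* Define smolagents-owned ChatMessage, ChatMessageToolCall and
  ChatMessageToolCallDefinition dataclasses in smolagents.models, and use
  them everywhere instead of the huggingface_hub ChatCompletionOutput*
  classes.
* Let Tool.__call__ accept a single dict of arguments whose keys match
  the tool's declared inputs, converting it to kwargs.
* Truncate VisitWebpageTool output with truncate_content, and let
  PythonInterpreterTool propagate interpreter errors instead of
  swallowing them into a string.
* Add a tests/test_default_tools.py suite and run it in CI.

A minimal sketch of the new calling convention (the URL is only an
illustration):

    from smolagents.default_tools import VisitWebpageTool

    tool = VisitWebpageTool()
    # Both calls are now equivalent: a single dict whose keys match the
    # tool's declared inputs is unpacked into kwargs before forward().
    page = tool({"url": "https://en.wikipedia.org/wiki/Hugging_Face"})
    page = tool(url="https://en.wikipedia.org/wiki/Hugging_Face")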
---
.github/workflows/tests.yml | 3 +
examples/benchmark.ipynb | 301 +++++++++----------------------
src/smolagents/agents.py | 29 +--
src/smolagents/default_tools.py | 24 ++-
src/smolagents/models.py | 60 ++++--
src/smolagents/tools.py | 10 +
tests/test_agents.py | 78 ++++----
tests/test_default_tools.py | 83 +++++++++
tests/test_monitoring.py | 18 +-
tests/test_python_interpreter.py | 46 +----
10 files changed, 296 insertions(+), 356 deletions(-)
create mode 100644 tests/test_default_tools.py
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a595bed..c720ec0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -36,6 +36,9 @@ jobs:
- name: Agent tests
run: |
uv run pytest -sv ./tests/test_agents.py
+ - name: Default tools tests
+ run: |
+ uv run pytest -sv ./tests/test_default_tools.py
- name: Final answer tests
run: |
uv run pytest -sv ./tests/test_final_answer.py
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index 1009f28..8b49b0a 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -29,8 +29,7 @@
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
+ " from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
@@ -173,7 +172,7 @@
"[132 rows x 4 columns]"
]
},
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,19 +195,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
- "* 'fields' has been removed\n",
- " warnings.warn(message, UserWarning)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import time\n",
"import json\n",
@@ -408,100 +397,9 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@@ -554,42 +452,9 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'gpt-4o'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
@@ -614,7 +479,7 @@
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=LiteLLMModel(model_id),\n",
- " additional_authorized_imports=[\"numpy\"],\n",
+ " additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
@@ -631,34 +496,39 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# import glob\n",
"# import json\n",
+ "\n",
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
"\n",
"# for file_path in jsonl_files:\n",
- "# print(file_path)\n",
- "# # Read all lines and filter out SimpleQA sources\n",
- "# filtered_lines = []\n",
- "# removed = 0\n",
- "# with open(file_path, 'r', encoding='utf-8') as f:\n",
- "# for line in f:\n",
- "# try:\n",
- "# data = json.loads(line.strip())\n",
- "# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
- "# removed +=1\n",
- "# else:\n",
- "# filtered_lines.append(line)\n",
- "# except json.JSONDecodeError:\n",
- "# print(\"Invalid line:\", line)\n",
- "# continue # Skip invalid JSON lines\n",
- "# print(f\"Removed {removed} lines.\")\n",
- "# # Write filtered content back to the same file\n",
- "# with open(file_path, 'w', encoding='utf-8') as f:\n",
- "# f.writelines(filtered_lines)"
+ "# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
+ "# print(file_path)\n",
+ "# # Read all lines and filter out SimpleQA sources\n",
+ "# filtered_lines = []\n",
+ "# removed = 0\n",
+ "# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
+ "# for line in f:\n",
+ "# try:\n",
+ "# data = json.loads(line.strip())\n",
+ "# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
+ "# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
+ "# # removed +=1\n",
+ "# # else:\n",
+ "# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
+ "# except json.JSONDecodeError:\n",
+ "# print(\"Invalid line:\", line)\n",
+ "# continue # Skip invalid JSON lines\n",
+ "# print(f\"Removed {removed} lines.\")\n",
+ "# # Write filtered content back to the same file\n",
+ "# with open(\n",
+ "# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
+ "# ) as f:\n",
+ "# f.writelines(filtered_lines)"
]
},
{
@@ -670,14 +540,14 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
+ "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
@@ -731,7 +601,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -752,7 +622,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -794,28 +664,28 @@
"
1 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
" MATH | \n",
- " 74.0 | \n",
+ " 76.0 | \n",
" 30.0 | \n",
" \n",
" \n",
" 2 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
" SimpleQA | \n",
- " 70.0 | \n",
+ " 88.0 | \n",
" 10.0 | \n",
"
\n",
" \n",
" 3 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" GAIA | \n",
- " 18.8 | \n",
+ " 25.0 | \n",
" 3.1 | \n",
"
\n",
" \n",
" 4 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" MATH | \n",
- " 76.0 | \n",
+ " 86.0 | \n",
" 60.0 | \n",
"
\n",
" \n",
@@ -829,63 +699,63 @@
" 6 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" GAIA | \n",
- " 40.6 | \n",
+ " NaN | \n",
" 3.1 | \n",
"
\n",
" \n",
" 7 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" MATH | \n",
- " 67.0 | \n",
+ " NaN | \n",
" 50.0 | \n",
"
\n",
" \n",
" 8 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" SimpleQA | \n",
- " 90.0 | \n",
+ " NaN | \n",
" 34.0 | \n",
"
\n",
" \n",
" 9 | \n",
" gpt-4o | \n",
" GAIA | \n",
- " 28.1 | \n",
+ " 25.6 | \n",
" 3.1 | \n",
"
\n",
" \n",
" 10 | \n",
" gpt-4o | \n",
" MATH | \n",
- " 70.0 | \n",
+ " 58.0 | \n",
" 40.0 | \n",
"
\n",
" \n",
" 11 | \n",
" gpt-4o | \n",
" SimpleQA | \n",
- " 88.0 | \n",
+ " 86.0 | \n",
" 6.0 | \n",
"
\n",
" \n",
" 12 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" GAIA | \n",
- " 0.0 | \n",
+ " 3.1 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 13 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" MATH | \n",
- " 42.0 | \n",
+ " 14.0 | \n",
" 18.0 | \n",
"
\n",
" \n",
" 14 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" SimpleQA | \n",
- " 54.0 | \n",
+ " 2.0 | \n",
" 6.0 | \n",
"
\n",
" \n",
@@ -899,49 +769,49 @@
" 16 | \n",
" meta-llama/Llama-3.2-3B-Instruct | \n",
" MATH | \n",
- " 32.0 | \n",
+ " 40.0 | \n",
" 12.0 | \n",
"
\n",
" \n",
" 17 | \n",
" meta-llama/Llama-3.2-3B-Instruct | \n",
" SimpleQA | \n",
- " 4.0 | \n",
+ " 20.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 18 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" GAIA | \n",
- " 34.4 | \n",
+ " 31.2 | \n",
" 3.1 | \n",
"
\n",
" \n",
" 19 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" MATH | \n",
- " 82.0 | \n",
+ " 72.0 | \n",
" 40.0 | \n",
"
\n",
" \n",
" 20 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" SimpleQA | \n",
- " 84.0 | \n",
+ " 78.0 | \n",
" 12.0 | \n",
"
\n",
" \n",
" 21 | \n",
" mistralai/Mistral-Nemo-Instruct-2407 | \n",
" GAIA | \n",
- " 3.1 | \n",
" 0.0 | \n",
+ " 3.1 | \n",
"
\n",
" \n",
" 22 | \n",
" mistralai/Mistral-Nemo-Instruct-2407 | \n",
" MATH | \n",
- " 20.0 | \n",
+ " 30.0 | \n",
" 22.0 | \n",
"
\n",
" \n",
@@ -949,7 +819,7 @@
" mistralai/Mistral-Nemo-Instruct-2407 | \n",
" SimpleQA | \n",
" 30.0 | \n",
- " 0.0 | \n",
+ " 6.0 | \n",
"
\n",
" \n",
"\n",
@@ -958,29 +828,29 @@
"text/plain": [
"action_type model_id source code vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
- "1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
- "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
- "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
- "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
+ "1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
+ "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
+ "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
+ "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
- "6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
- "7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
- "8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
- "9 gpt-4o GAIA 28.1 3.1\n",
- "10 gpt-4o MATH 70.0 40.0\n",
- "11 gpt-4o SimpleQA 88.0 6.0\n",
- "12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
- "13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
- "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
+ "6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
+ "7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
+ "8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
+ "9 gpt-4o GAIA 25.6 3.1\n",
+ "10 gpt-4o MATH 58.0 40.0\n",
+ "11 gpt-4o SimpleQA 86.0 6.0\n",
+ "12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
+ "13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
+ "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
- "16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
- "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
- "18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
- "19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
- "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
- "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
- "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
- "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
+ "16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
+ "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
+ "18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
+ "19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
+ "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
+ "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
+ "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
+ "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
]
},
"metadata": {},
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index b3d0c5a..cfa8a6f 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent):
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
-
- # Extract tool call from model output
- if (
- type(model_message.tool_calls) is list
- and len(model_message.tool_calls) > 0
- ):
- tool_calls = model_message.tool_calls[0]
- tool_arguments = tool_calls.function.arguments
- tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
- else:
- start, end = (
- model_message.content.find("{"),
- model_message.content.rfind("}") + 1,
- )
- tool_calls = json.loads(model_message.content[start:end])
- tool_arguments = tool_calls["tool_arguments"]
- tool_name, tool_call_id = (
- tool_calls["tool_name"],
- f"call_{len(self.logs)}",
- )
+ tool_call = model_message.tool_calls[0]
+ tool_name, tool_call_id = tool_call.function.name, tool_call.id
+ tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(
@@ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
- self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
+ self.logger.log(
+            f"Observations: {updated_information.replace('[', '|')}",  # escape brackets that rich would parse as markup tags
+ level=LogLevel.INFO,
+ )
log_entry.observations = updated_information
return None
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 75fe8d0..59f6820 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -31,6 +31,7 @@ from .local_python_executor import (
)
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
+from .utils import truncate_content
if is_torch_available():
from transformers.models.whisper import (
@@ -112,18 +113,15 @@ class PythonInterpreterTool(Tool):
def forward(self, code: str) -> str:
state = {}
- try:
- output = str(
- self.python_evaluator(
- code,
- state=state,
- static_tools=self.base_python_tools,
- authorized_imports=self.authorized_imports,
- )[0] # The second element is boolean is_final_answer
- )
- return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
- except Exception as e:
- return f"Error: {str(e)}"
+ output = str(
+ self.python_evaluator(
+ code,
+ state=state,
+ static_tools=self.base_python_tools,
+ authorized_imports=self.authorized_imports,
+            )[0]  # The second element is the boolean is_final_answer
+ )
+ return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
class FinalAnswerTool(Tool):
@@ -295,7 +293,7 @@ class VisitWebpageTool(Tool):
# Remove multiple line breaks
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
- return markdown_content
+ return truncate_content(markdown_content)
except RequestException as e:
return f"Error fetching the webpage: {str(e)}"
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 70ef5d1..f25ced9 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -14,20 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
import json
import logging
import os
import random
from copy import deepcopy
from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union, Any
-from huggingface_hub import (
- InferenceClient,
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
-)
+from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
@@ -58,6 +54,27 @@ if _is_package_available("litellm"):
import litellm
+@dataclass
+class ChatMessageToolCallDefinition:
+ arguments: Any
+ name: str
+ description: Optional[str] = None
+
+
+@dataclass
+class ChatMessageToolCall:
+ function: ChatMessageToolCallDefinition
+ id: str
+ type: str
+
+
+@dataclass
+class ChatMessage:
+ role: str
+ content: Optional[str] = None
+ tool_calls: Optional[List[ChatMessageToolCall]] = None
+
+
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@@ -140,6 +157,17 @@ def get_clean_message_list(
return final_message_list
+def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
+ try:
+ start, end = (
+ possible_dictionary.find("{"),
+ possible_dictionary.rfind("}") + 1,
+ )
+ return json.loads(possible_dictionary[start:end])
+ except Exception:
+ return possible_dictionary
+
+
class Model:
def __init__(self):
self.last_input_token_count = None
@@ -157,7 +185,7 @@ class Model:
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
"""Process the input messages and return the model's response.
Parameters:
@@ -228,7 +256,7 @@ class HfApiModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
"""
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -329,7 +357,7 @@ class TransformersModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -365,21 +393,21 @@ class TransformersModel(Model):
if stop_sequences is not None:
output = remove_stop_sequences(output, stop_sequences)
if tools_to_call_from is None:
- return ChatCompletionOutputMessage(role="assistant", content=output)
+ return ChatMessage(role="assistant", content=output)
else:
if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
parsed_output = json.loads(output)
tool_name = parsed_output.get("tool_name")
tool_arguments = parsed_output.get("tool_arguments")
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name=tool_name, arguments=tool_arguments
),
)
@@ -414,7 +442,7 @@ class LiteLLMModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -485,7 +513,7 @@ class OpenAIServerModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py
index d5ec6b0..04a203d 100644
--- a/src/smolagents/tools.py
+++ b/src/smolagents/tools.py
@@ -221,6 +221,16 @@ class Tool:
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
if not self.is_initialized:
self.setup()
+
+        # Handle the case where arguments might be passed as a single dictionary
+ if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
+ potential_kwargs = args[0]
+
+ # If the dictionary keys match our input parameters, convert it to kwargs
+ if all(key in self.inputs for key in potential_kwargs):
+ args = ()
+ kwargs = potential_kwargs
+
if sanitize_inputs_outputs:
args, kwargs = handle_agent_input_types(*args, **kwargs)
outputs = self.forward(*args, **kwargs)
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 38538ce..1cd0a67 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -30,10 +30,10 @@ from smolagents.agents import (
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
-from huggingface_hub import (
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
+from smolagents.models import (
+ ChatMessage,
+ ChatMessageToolCall,
+ ChatMessageToolCallDefinition,
)
@@ -47,28 +47,28 @@ class FakeToolCallModel:
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="python_interpreter", arguments={"code": "2*3.6452"}
),
)
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_1",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
)
@@ -81,14 +81,14 @@ class FakeToolCallModelImage:
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="fake_image_generation_tool",
arguments={"prompt": "An image of a cat"},
),
@@ -96,14 +96,14 @@ class FakeToolCallModelImage:
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_1",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments="image.png"
),
)
@@ -114,7 +114,7 @@ class FakeToolCallModelImage:
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -125,7 +125,7 @@ result = 2**3.6452
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -140,7 +140,7 @@ final_answer(7.2904)
def fake_code_model_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -154,7 +154,7 @@ print("Ok, calculation done!")
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -169,7 +169,7 @@ final_answer("got an error")
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -183,7 +183,7 @@ print("Ok, calculation done!")
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -196,7 +196,7 @@ final_answer("got an error")
def fake_code_model_import(messages, stop_sequences=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can answer the question
@@ -212,7 +212,7 @@ final_answer("got an error")
def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's define the function. special_marker
@@ -226,7 +226,7 @@ def moving_average(x, w):
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -241,7 +241,7 @@ final_answer(res)
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -255,7 +255,7 @@ final_answer(result)
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -454,14 +454,14 @@ class AgentTests(unittest.TestCase):
):
if tools_to_call_from is not None:
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
@@ -470,14 +470,14 @@ class AgentTests(unittest.TestCase):
)
else:
assert "Report on the current US president" in str(messages)
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments="Final report."
),
)
@@ -485,7 +485,7 @@ class AgentTests(unittest.TestCase):
)
else:
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
@@ -497,7 +497,7 @@ result = search_agent("Who is the current US president?")
)
else:
assert "Report on the current US president" in str(messages)
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's return the report.
@@ -518,14 +518,14 @@ final_answer("Final report.")
stop_sequences=None,
grammar=None,
):
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer",
arguments="Report on the current US president",
),
@@ -568,7 +568,7 @@ final_answer("Final report.")
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""Code:
```py
diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py
new file mode 100644
index 0000000..d966b84
--- /dev/null
+++ b/tests/test_default_tools.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+import pytest
+
+from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
+from smolagents.types import AGENT_TYPE_MAPPING
+
+from .test_tools import ToolTesterMixin
+
+
+class DefaultToolTests(unittest.TestCase):
+ def test_visit_webpage(self):
+ arguments = {
+ "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
+ }
+ result = VisitWebpageTool()(arguments)
+ assert isinstance(result, str)
+ assert (
+ "* [About Wikipedia](/wiki/Wikipedia:About)" in result
+ ) # Proper wikipedia pages have an About
+
+
+class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ result = self.tool("(2 / 2) * 4")
+ self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(code="(2 / 2) * 4")
+ self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
+
+ def test_agent_type_output(self):
+ inputs = ["2 * 2"]
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_agent_types_inputs(self):
+ inputs = ["2 * 2"]
+ _inputs = []
+
+ for _input, expected_input in zip(inputs, self.tool.inputs.values()):
+ input_type = expected_input["type"]
+ if isinstance(input_type, list):
+ _inputs.append(
+ [
+ AGENT_TYPE_MAPPING[_input_type](_input)
+ for _input_type in input_type
+ ]
+ )
+ else:
+ _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+
+ # Should not raise an error
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_imports_work(self):
+ result = self.tool("import numpy as np")
+ assert "import from numpy is not allowed" not in result.lower()
+
+ def test_unauthorized_imports_fail(self):
+ with pytest.raises(Exception) as e:
+ self.tool("import sympy as sp")
+ assert "sympy" in str(e).lower()
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index 11594e7..e55afb4 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -23,9 +23,9 @@ from smolagents import (
stream_to_gradio,
)
-from huggingface_hub import (
+from smolagents.models import (
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
+ ChatMessage,
+ ChatMessageToolCall,
+ ChatMessageToolCallDefinition,
)
@@ -36,21 +36,21 @@ class FakeLLMModel:
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
if tools_to_call_from is not None:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="fake_id",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "image"}
),
)
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Code:
@@ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase):
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
- return ChatCompletionOutputMessage(
- role="assistant", content="Malformed answer"
- )
+ return ChatMessage(role="assistant", content="Malformed answer")
agent = CodeAgent(
tools=[],
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 8c7aacc..75a146e 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -18,15 +18,12 @@ import unittest
import numpy as np
import pytest
-from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
+from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
fix_final_answer_code,
)
-from smolagents.types import AGENT_TYPE_MAPPING
-
-from .test_tools import ToolTesterMixin
# Fake function we will use as tool
@@ -34,47 +31,6 @@ def add_two(x):
return x + 2
-class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
- def setUp(self):
- self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
- self.tool.setup()
-
- def test_exact_match_arg(self):
- result = self.tool("(2 / 2) * 4")
- self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
-
- def test_exact_match_kwarg(self):
- result = self.tool(code="(2 / 2) * 4")
- self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
-
- def test_agent_type_output(self):
- inputs = ["2 * 2"]
- output = self.tool(*inputs, sanitize_inputs_outputs=True)
- output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
- self.assertTrue(isinstance(output, output_type))
-
- def test_agent_types_inputs(self):
- inputs = ["2 * 2"]
- _inputs = []
-
- for _input, expected_input in zip(inputs, self.tool.inputs.values()):
- input_type = expected_input["type"]
- if isinstance(input_type, list):
- _inputs.append(
- [
- AGENT_TYPE_MAPPING[_input_type](_input)
- for _input_type in input_type
- ]
- )
- else:
- _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
-
- # Should not raise an error
- output = self.tool(*inputs, sanitize_inputs_outputs=True)
- output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
- self.assertTrue(isinstance(output, output_type))
-
-
class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"