From 5f323735511f54168b688cdb0dee10ab5bdcd909 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:57:11 +0100 Subject: [PATCH] Make default tools more robust (#186) --- .github/workflows/tests.yml | 3 + examples/benchmark.ipynb | 301 +++++++++---------------------- src/smolagents/agents.py | 29 +-- src/smolagents/default_tools.py | 24 ++- src/smolagents/models.py | 60 ++++-- src/smolagents/tools.py | 10 + tests/test_agents.py | 78 ++++---- tests/test_default_tools.py | 83 +++++++++ tests/test_monitoring.py | 18 +- tests/test_python_interpreter.py | 46 +---- 10 files changed, 296 insertions(+), 356 deletions(-) create mode 100644 tests/test_default_tools.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a595bed..c720ec0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,6 +36,9 @@ jobs: - name: Agent tests run: | uv run pytest -sv ./tests/test_agents.py + - name: Default tools tests + run: | + uv run pytest -sv ./tests/test_default_tools.py - name: Final answer tests run: | uv run pytest -sv ./tests/test_final_answer.py diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index 1009f28..8b49b0a 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -29,8 +29,7 @@ "output_type": "stream", "text": [ "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n" + " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { @@ -173,7 +172,7 @@ "[132 rows x 4 columns]" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -196,19 +195,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n", - "* 'fields' has been removed\n", - " warnings.warn(message, UserWarning)\n" - ] - } - ], + "outputs": [], "source": [ "import time\n", "import json\n", @@ -408,100 +397,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n" - ] - } - ], + "outputs": [], "source": [ "open_model_ids = [\n", " \"meta-llama/Llama-3.3-70B-Instruct\",\n", @@ -554,42 +452,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'gpt-4o'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n", - "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n" - ] - } - ], + "outputs": [], "source": [ "from smolagents import LiteLLMModel\n", "\n", @@ -614,7 +479,7 @@ " agent = CodeAgent(\n", " tools=[GoogleSearchTool(), VisitWebpageTool()],\n", " model=LiteLLMModel(model_id),\n", - " additional_authorized_imports=[\"numpy\"],\n", + " additional_authorized_imports=[\"numpy\", \"sympy\"],\n", " max_steps=10,\n", " )\n", " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", @@ -631,34 +496,39 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# import glob\n", "# import json\n", + "\n", "# jsonl_files = glob.glob(f\"output/*.jsonl\")\n", "\n", "# for file_path in jsonl_files:\n", - "# print(file_path)\n", - "# # Read all lines and filter out SimpleQA sources\n", - "# filtered_lines = []\n", - "# removed = 0\n", - "# with open(file_path, 'r', encoding='utf-8') as f:\n", - "# for line in f:\n", - "# try:\n", - "# data = json.loads(line.strip())\n", - "# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", - "# removed +=1\n", - "# else:\n", - "# filtered_lines.append(line)\n", - "# except json.JSONDecodeError:\n", - "# print(\"Invalid line:\", line)\n", - "# continue # Skip invalid JSON lines\n", - "# print(f\"Removed {removed} lines.\")\n", - "# # Write filtered content back to the same file\n", - "# with open(file_path, 'w', encoding='utf-8') as f:\n", - "# f.writelines(filtered_lines)" + "# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n", + "# print(file_path)\n", + "# # Read all lines and filter out SimpleQA sources\n", + "# filtered_lines = []\n", + "# removed = 0\n", + "# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", + "# for line in f:\n", + "# try:\n", + "# data = json.loads(line.strip())\n", + "# data[\"answer\"] = data[\"answer\"][\"content\"]\n", + "# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n", + "# # removed +=1\n", + "# # else:\n", + "# filtered_lines.append(json.dumps(data) + \"\\n\")\n", + "# except json.JSONDecodeError:\n", + "# print(\"Invalid line:\", line)\n", + "# continue # Skip invalid JSON lines\n", + "# print(f\"Removed {removed} lines.\")\n", + "# # Write filtered content back to the same file\n", + "# with open(\n", + "# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n", + "# ) as f:\n", + "# f.writelines(filtered_lines)" ] }, { @@ -670,14 +540,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", + "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", " warnings.warn(\n" ] } @@ -731,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -752,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -794,28 +664,28 @@ " 1\n", " Qwen/Qwen2.5-72B-Instruct\n", " MATH\n", - " 74.0\n", + " 76.0\n", " 30.0\n", " \n", " \n", " 2\n", " Qwen/Qwen2.5-72B-Instruct\n", " SimpleQA\n", - " 70.0\n", + " 88.0\n", " 10.0\n", " \n", " \n", " 3\n", " Qwen/Qwen2.5-Coder-32B-Instruct\n", " GAIA\n", - " 18.8\n", + " 25.0\n", " 3.1\n", " \n", " \n", " 4\n", " Qwen/Qwen2.5-Coder-32B-Instruct\n", " MATH\n", - " 76.0\n", + " 86.0\n", " 60.0\n", " \n", " \n", @@ -829,63 +699,63 @@ " 6\n", " anthropic/claude-3-5-sonnet-latest\n", " GAIA\n", - " 40.6\n", + " NaN\n", " 3.1\n", " \n", " \n", " 7\n", " anthropic/claude-3-5-sonnet-latest\n", " MATH\n", - " 67.0\n", + " NaN\n", " 50.0\n", " \n", " \n", " 8\n", " anthropic/claude-3-5-sonnet-latest\n", " SimpleQA\n", - " 90.0\n", + " NaN\n", " 34.0\n", " \n", " \n", " 9\n", " gpt-4o\n", " GAIA\n", - " 28.1\n", + " 25.6\n", " 3.1\n", " \n", " \n", " 10\n", " gpt-4o\n", " MATH\n", - " 70.0\n", + " 58.0\n", " 40.0\n", " \n", " \n", " 11\n", " gpt-4o\n", " SimpleQA\n", - " 88.0\n", + " 86.0\n", " 6.0\n", " \n", " \n", " 12\n", " meta-llama/Llama-3.1-8B-Instruct\n", " GAIA\n", - " 0.0\n", + " 3.1\n", " 0.0\n", " \n", " \n", " 13\n", " meta-llama/Llama-3.1-8B-Instruct\n", " MATH\n", - " 42.0\n", + " 14.0\n", " 18.0\n", " \n", " \n", " 14\n", " meta-llama/Llama-3.1-8B-Instruct\n", " SimpleQA\n", - " 54.0\n", + " 2.0\n", " 6.0\n", " \n", " \n", @@ -899,49 +769,49 @@ " 16\n", " meta-llama/Llama-3.2-3B-Instruct\n", " MATH\n", - " 32.0\n", + " 40.0\n", " 12.0\n", " \n", " \n", " 17\n", " meta-llama/Llama-3.2-3B-Instruct\n", " SimpleQA\n", - " 4.0\n", + " 20.0\n", " 0.0\n", " \n", " \n", " 18\n", " meta-llama/Llama-3.3-70B-Instruct\n", " GAIA\n", - " 34.4\n", + " 31.2\n", " 3.1\n", " \n", " \n", " 19\n", " meta-llama/Llama-3.3-70B-Instruct\n", " MATH\n", - " 82.0\n", + " 72.0\n", " 40.0\n", " \n", " \n", " 20\n", " meta-llama/Llama-3.3-70B-Instruct\n", " SimpleQA\n", - " 84.0\n", + " 78.0\n", " 12.0\n", " \n", " \n", " 21\n", " mistralai/Mistral-Nemo-Instruct-2407\n", " GAIA\n", - " 3.1\n", " 0.0\n", + " 3.1\n", " \n", " \n", " 22\n", " mistralai/Mistral-Nemo-Instruct-2407\n", " MATH\n", - " 20.0\n", + " 30.0\n", " 22.0\n", " \n", " \n", @@ -949,7 +819,7 @@ " mistralai/Mistral-Nemo-Instruct-2407\n", " SimpleQA\n", " 30.0\n", - " 0.0\n", + " 6.0\n", " \n", " \n", "\n", @@ -958,29 +828,29 @@ "text/plain": [ "action_type model_id source code vanilla\n", "0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n", - "1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n", - "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n", - "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n", - "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n", + "1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n", + "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n", + "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n", + "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n", "5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n", - "6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n", - "7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n", - "8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n", - "9 gpt-4o GAIA 28.1 3.1\n", - "10 gpt-4o MATH 70.0 40.0\n", - "11 gpt-4o SimpleQA 88.0 6.0\n", - "12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n", - "13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n", - "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n", + "6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n", + "7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n", + "8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n", + "9 gpt-4o GAIA 25.6 3.1\n", + "10 gpt-4o MATH 58.0 40.0\n", + "11 gpt-4o SimpleQA 86.0 6.0\n", + "12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n", + "13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n", + "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n", "15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n", - "16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n", - "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n", - "18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n", - "19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n", - "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n", - "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n", - "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n", - "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0" + "16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n", + "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n", + "18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n", + "19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n", + "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n", + "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n", + "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n", + "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0" ] }, "metadata": {}, @@ -1005,6 +875,15 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mnotebook controller is DISPOSED. \n", + "\u001b[1;31mView Jupyter log for further details." + ] } ], "source": [ diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index b3d0c5a..cfa8a6f 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -import json from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent): tools_to_call_from=list(self.tools.values()), stop_sequences=["Observation:"], ) - - # Extract tool call from model output - if ( - type(model_message.tool_calls) is list - and len(model_message.tool_calls) > 0 - ): - tool_calls = model_message.tool_calls[0] - tool_arguments = tool_calls.function.arguments - tool_name, tool_call_id = tool_calls.function.name, tool_calls.id - else: - start, end = ( - model_message.content.find("{"), - model_message.content.rfind("}") + 1, - ) - tool_calls = json.loads(model_message.content[start:end]) - tool_arguments = tool_calls["tool_arguments"] - tool_name, tool_call_id = ( - tool_calls["tool_name"], - f"call_{len(self.logs)}", - ) + tool_call = model_message.tool_calls[0] + tool_name, tool_call_id = tool_call.function.name, tool_call.id + tool_arguments = tool_call.function.arguments except Exception as e: raise AgentGenerationError( @@ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent): updated_information = f"Stored '{observation_name}' in memory." else: updated_information = str(observation).strip() - self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO) + self.logger.log( + f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components + level=LogLevel.INFO, + ) log_entry.observations = updated_information return None diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 75fe8d0..59f6820 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -31,6 +31,7 @@ from .local_python_executor import ( ) from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .types import AgentAudio +from .utils import truncate_content if is_torch_available(): from transformers.models.whisper import ( @@ -112,18 +113,15 @@ class PythonInterpreterTool(Tool): def forward(self, code: str) -> str: state = {} - try: - output = str( - self.python_evaluator( - code, - state=state, - static_tools=self.base_python_tools, - authorized_imports=self.authorized_imports, - )[0] # The second element is boolean is_final_answer - ) - return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" - except Exception as e: - return f"Error: {str(e)}" + output = str( + self.python_evaluator( + code, + state=state, + static_tools=self.base_python_tools, + authorized_imports=self.authorized_imports, + )[0] # The second element is boolean is_final_answer + ) + return f"Stdout:\n{state['print_outputs']}\nOutput: {output}" class FinalAnswerTool(Tool): @@ -295,7 +293,7 @@ class VisitWebpageTool(Tool): # Remove multiple line breaks markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) - return markdown_content + return truncate_content(markdown_content) except RequestException as e: return f"Error fetching the webpage: {str(e)}" diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 70ef5d1..f25ced9 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -14,20 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import json import logging import os import random from copy import deepcopy from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union, Any -from huggingface_hub import ( - InferenceClient, - ChatCompletionOutputMessage, - ChatCompletionOutputToolCall, - ChatCompletionOutputFunctionDefinition, -) +from huggingface_hub import InferenceClient from transformers import ( AutoModelForCausalLM, @@ -58,6 +54,27 @@ if _is_package_available("litellm"): import litellm +@dataclass +class ChatMessageToolCallDefinition: + arguments: Any + name: str + description: Optional[str] = None + + +@dataclass +class ChatMessageToolCall: + function: ChatMessageToolCallDefinition + id: str + type: str + + +@dataclass +class ChatMessage: + role: str + content: Optional[str] = None + tool_calls: Optional[List[ChatMessageToolCall]] = None + + class MessageRole(str, Enum): USER = "user" ASSISTANT = "assistant" @@ -140,6 +157,17 @@ def get_clean_message_list( return final_message_list +def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: + try: + start, end = ( + possible_dictionary.find("{"), + possible_dictionary.rfind("}") + 1, + ) + return json.loads(possible_dictionary[start:end]) + except Exception: + return possible_dictionary + + class Model: def __init__(self): self.last_input_token_count = None @@ -157,7 +185,7 @@ class Model: stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None, max_tokens: int = 1500, - ) -> ChatCompletionOutputMessage: + ) -> ChatMessage: """Process the input messages and return the model's response. Parameters: @@ -228,7 +256,7 @@ class HfApiModel(Model): grammar: Optional[str] = None, max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, - ) -> ChatCompletionOutputMessage: + ) -> ChatMessage: """ Gets an LLM output message for the given list of input messages. If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. @@ -329,7 +357,7 @@ class TransformersModel(Model): grammar: Optional[str] = None, max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, - ) -> ChatCompletionOutputMessage: + ) -> ChatMessage: messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) @@ -365,21 +393,21 @@ class TransformersModel(Model): if stop_sequences is not None: output = remove_stop_sequences(output, stop_sequences) if tools_to_call_from is None: - return ChatCompletionOutputMessage(role="assistant", content=output) + return ChatMessage(role="assistant", content=output) else: if "Action:" in output: output = output.split("Action:", 1)[1].strip() parsed_output = json.loads(output) tool_name = parsed_output.get("tool_name") tool_arguments = parsed_output.get("tool_arguments") - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="".join(random.choices("0123456789", k=5)), type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name=tool_name, arguments=tool_arguments ), ) @@ -414,7 +442,7 @@ class LiteLLMModel(Model): grammar: Optional[str] = None, max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, - ) -> ChatCompletionOutputMessage: + ) -> ChatMessage: messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) @@ -485,7 +513,7 @@ class OpenAIServerModel(Model): grammar: Optional[str] = None, max_tokens: int = 1500, tools_to_call_from: Optional[List[Tool]] = None, - ) -> ChatCompletionOutputMessage: + ) -> ChatMessage: messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index d5ec6b0..04a203d 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -221,6 +221,16 @@ class Tool: def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): if not self.is_initialized: self.setup() + + # Handle the arguments might be passed as a single dictionary + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict): + potential_kwargs = args[0] + + # If the dictionary keys match our input parameters, convert it to kwargs + if all(key in self.inputs for key in potential_kwargs): + args = () + kwargs = potential_kwargs + if sanitize_inputs_outputs: args, kwargs = handle_agent_input_types(*args, **kwargs) outputs = self.forward(*args, **kwargs) diff --git a/tests/test_agents.py b/tests/test_agents.py index 38538ce..1cd0a67 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -30,10 +30,10 @@ from smolagents.agents import ( from smolagents.default_tools import PythonInterpreterTool from smolagents.tools import tool from smolagents.types import AgentImage, AgentText -from huggingface_hub import ( - ChatCompletionOutputMessage, - ChatCompletionOutputToolCall, - ChatCompletionOutputFunctionDefinition, +from smolagents.models import ( + ChatMessage, + ChatMessageToolCall, + ChatMessageToolCallDefinition, ) @@ -47,28 +47,28 @@ class FakeToolCallModel: self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None ): if len(messages) < 3: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_0", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="python_interpreter", arguments={"code": "2*3.6452"} ), ) ], ) else: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_1", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="final_answer", arguments={"answer": "7.2904"} ), ) @@ -81,14 +81,14 @@ class FakeToolCallModelImage: self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None ): if len(messages) < 3: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_0", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="fake_image_generation_tool", arguments={"prompt": "An image of a cat"}, ), @@ -96,14 +96,14 @@ class FakeToolCallModelImage: ], ) else: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_1", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="final_answer", arguments="image.png" ), ) @@ -114,7 +114,7 @@ class FakeToolCallModelImage: def fake_code_model(messages, stop_sequences=None, grammar=None) -> str: prompt = str(messages) if "special_marker" not in prompt: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I should multiply 2 by 3.6452. special_marker @@ -125,7 +125,7 @@ result = 2**3.6452 """, ) else: # We're at step 2 - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I can now answer the initial question @@ -140,7 +140,7 @@ final_answer(7.2904) def fake_code_model_error(messages, stop_sequences=None) -> str: prompt = str(messages) if "special_marker" not in prompt: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I should multiply 2 by 3.6452. special_marker @@ -154,7 +154,7 @@ print("Ok, calculation done!") """, ) else: # We're at step 2 - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I can now answer the initial question @@ -169,7 +169,7 @@ final_answer("got an error") def fake_code_model_syntax_error(messages, stop_sequences=None) -> str: prompt = str(messages) if "special_marker" not in prompt: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I should multiply 2 by 3.6452. special_marker @@ -183,7 +183,7 @@ print("Ok, calculation done!") """, ) else: # We're at step 2 - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I can now answer the initial question @@ -196,7 +196,7 @@ final_answer("got an error") def fake_code_model_import(messages, stop_sequences=None) -> str: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I can answer the question @@ -212,7 +212,7 @@ final_answer("got an error") def fake_code_functiondef(messages, stop_sequences=None) -> str: prompt = str(messages) if "special_marker" not in prompt: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: Let's define the function. special_marker @@ -226,7 +226,7 @@ def moving_average(x, w): """, ) else: # We're at step 2 - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I can now answer the initial question @@ -241,7 +241,7 @@ final_answer(res) def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I should multiply 2 by 3.6452. special_marker @@ -255,7 +255,7 @@ final_answer(result) def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: I should multiply 2 by 3.6452. special_marker @@ -454,14 +454,14 @@ class AgentTests(unittest.TestCase): ): if tools_to_call_from is not None: if len(messages) < 3: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_0", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="search_agent", arguments="Who is the current US president?", ), @@ -470,14 +470,14 @@ class AgentTests(unittest.TestCase): ) else: assert "Report on the current US president" in str(messages) - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_0", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="final_answer", arguments="Final report." ), ) @@ -485,7 +485,7 @@ class AgentTests(unittest.TestCase): ) else: if len(messages) < 3: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: Let's call our search agent. @@ -497,7 +497,7 @@ result = search_agent("Who is the current US president?") ) else: assert "Report on the current US president" in str(messages) - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Thought: Let's return the report. @@ -518,14 +518,14 @@ final_answer("Final report.") stop_sequences=None, grammar=None, ): - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="call_0", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="final_answer", arguments="Report on the current US president", ), @@ -568,7 +568,7 @@ final_answer("Final report.") def test_code_nontrivial_final_answer_works(self): def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="""Code: ```py diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py new file mode 100644 index 0000000..d966b84 --- /dev/null +++ b/tests/test_default_tools.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import pytest + +from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool +from smolagents.types import AGENT_TYPE_MAPPING + +from .test_tools import ToolTesterMixin + + +class DefaultToolTests(unittest.TestCase): + def test_visit_webpage(self): + arguments = { + "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security" + } + result = VisitWebpageTool()(arguments) + assert isinstance(result, str) + assert ( + "* [About Wikipedia](/wiki/Wikipedia:About)" in result + ) # Proper wikipedia pages have an About + + +class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): + def setUp(self): + self.tool = PythonInterpreterTool(authorized_imports=["numpy"]) + self.tool.setup() + + def test_exact_match_arg(self): + result = self.tool("(2 / 2) * 4") + self.assertEqual(result, "Stdout:\n\nOutput: 4.0") + + def test_exact_match_kwarg(self): + result = self.tool(code="(2 / 2) * 4") + self.assertEqual(result, "Stdout:\n\nOutput: 4.0") + + def test_agent_type_output(self): + inputs = ["2 * 2"] + output = self.tool(*inputs, sanitize_inputs_outputs=True) + output_type = AGENT_TYPE_MAPPING[self.tool.output_type] + self.assertTrue(isinstance(output, output_type)) + + def test_agent_types_inputs(self): + inputs = ["2 * 2"] + _inputs = [] + + for _input, expected_input in zip(inputs, self.tool.inputs.values()): + input_type = expected_input["type"] + if isinstance(input_type, list): + _inputs.append( + [ + AGENT_TYPE_MAPPING[_input_type](_input) + for _input_type in input_type + ] + ) + else: + _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) + + # Should not raise an error + output = self.tool(*inputs, sanitize_inputs_outputs=True) + output_type = AGENT_TYPE_MAPPING[self.tool.output_type] + self.assertTrue(isinstance(output, output_type)) + + def test_imports_work(self): + result = self.tool("import numpy as np") + assert "import from numpy is not allowed" not in result.lower() + + def test_unauthorized_imports_fail(self): + with pytest.raises(Exception) as e: + self.tool("import sympy as sp") + assert "sympy" in str(e).lower() diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index 11594e7..e55afb4 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -23,9 +23,9 @@ from smolagents import ( stream_to_gradio, ) from huggingface_hub import ( - ChatCompletionOutputMessage, - ChatCompletionOutputToolCall, - ChatCompletionOutputFunctionDefinition, + ChatMessage, + ChatMessageToolCall, + ChatMessageToolCallDefinition, ) @@ -36,21 +36,21 @@ class FakeLLMModel: def __call__(self, prompt, tools_to_call_from=None, **kwargs): if tools_to_call_from is not None: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content="", tool_calls=[ - ChatCompletionOutputToolCall( + ChatMessageToolCall( id="fake_id", type="function", - function=ChatCompletionOutputFunctionDefinition( + function=ChatMessageToolCallDefinition( name="final_answer", arguments={"answer": "image"} ), ) ], ) else: - return ChatCompletionOutputMessage( + return ChatMessage( role="assistant", content=""" Code: @@ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase): self.last_output_token_count = 20 def __call__(self, prompt, **kwargs): - return ChatCompletionOutputMessage( - role="assistant", content="Malformed answer" - ) + return ChatMessage(role="assistant", content="Malformed answer") agent = CodeAgent( tools=[], diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 8c7aacc..75a146e 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -18,15 +18,12 @@ import unittest import numpy as np import pytest -from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool +from smolagents.default_tools import BASE_PYTHON_TOOLS from smolagents.local_python_executor import ( InterpreterError, evaluate_python_code, fix_final_answer_code, ) -from smolagents.types import AGENT_TYPE_MAPPING - -from .test_tools import ToolTesterMixin # Fake function we will use as tool @@ -34,47 +31,6 @@ def add_two(x): return x + 2 -class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): - def setUp(self): - self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"]) - self.tool.setup() - - def test_exact_match_arg(self): - result = self.tool("(2 / 2) * 4") - self.assertEqual(result, "Stdout:\n\nOutput: 4.0") - - def test_exact_match_kwarg(self): - result = self.tool(code="(2 / 2) * 4") - self.assertEqual(result, "Stdout:\n\nOutput: 4.0") - - def test_agent_type_output(self): - inputs = ["2 * 2"] - output = self.tool(*inputs, sanitize_inputs_outputs=True) - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - self.assertTrue(isinstance(output, output_type)) - - def test_agent_types_inputs(self): - inputs = ["2 * 2"] - _inputs = [] - - for _input, expected_input in zip(inputs, self.tool.inputs.values()): - input_type = expected_input["type"] - if isinstance(input_type, list): - _inputs.append( - [ - AGENT_TYPE_MAPPING[_input_type](_input) - for _input_type in input_type - ] - ) - else: - _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) - - # Should not raise an error - output = self.tool(*inputs, sanitize_inputs_outputs=True) - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - self.assertTrue(isinstance(output, output_type)) - - class PythonInterpreterTester(unittest.TestCase): def test_evaluate_assign(self): code = "x = 3"