add support for MCP Servers tools as `ToolCollection` (#232)
* add support for tool collection from mcp servers * add forgotten documentation * fix link missing in documentation * fix linting in CI, bumpruff to use modern version * mcpadapt added as optional dependencies * use classmethod for from_hub and from_mcp to better reflect the fact that they return a ToolCollection * Update src/smolagents/tools.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> * Update src/smolagents/tools.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> * Test ToolCollection.from_mcp * Rename to mcp extra * Add mcp extra to test extra * add a test for from_mcp * fix typo * fix tests * Test ToolCollection.from_mcp (cherry picked from commit 9284d9ea8cf24d3c934e35a38dfe34f3ce31cef3) * Make all pytest tests --------- Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									7d6599e430
								
							
						
					
					
						commit
						a4d029da88
					
				|  | @ -204,13 +204,17 @@ agent.run( | ||||||
| 
 | 
 | ||||||
| ### Use a collection of tools | ### Use a collection of tools | ||||||
| 
 | 
 | ||||||
| You can leverage tool collections by using the `ToolCollection` object, with the slug of the collection you want to use. | You can leverage tool collections by using the `ToolCollection` object. It supports loading either a collection from the Hub or an MCP server tools. | ||||||
|  | 
 | ||||||
|  | #### Tool Collection from a collection in the Hub | ||||||
|  | 
 | ||||||
|  | You can leverage it with the slug of the collection you want to use. | ||||||
| Then pass them as a list to initialize your agent, and start using them! | Then pass them as a list to initialize your agent, and start using them! | ||||||
| 
 | 
 | ||||||
| ```py | ```py | ||||||
| from smolagents import ToolCollection, CodeAgent | from smolagents import ToolCollection, CodeAgent | ||||||
| 
 | 
 | ||||||
| image_tool_collection = ToolCollection( | image_tool_collection = ToolCollection.from_hub( | ||||||
|     collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f", |     collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f", | ||||||
|     token="<YOUR_HUGGINGFACEHUB_API_TOKEN>" |     token="<YOUR_HUGGINGFACEHUB_API_TOKEN>" | ||||||
| ) | ) | ||||||
|  | @ -220,3 +224,24 @@ agent.run("Please draw me a picture of rivers and lakes.") | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| To speed up the start, tools are loaded only if called by the agent. | To speed up the start, tools are loaded only if called by the agent. | ||||||
|  | 
 | ||||||
|  | #### Tool Collection from any MCP server | ||||||
|  | 
 | ||||||
|  | Leverage tools from the hundreds of MCP servers available on [glama.ai](https://glama.ai/mcp/servers) or [smithery.ai](https://smithery.ai/). | ||||||
|  | 
 | ||||||
|  | The MCP servers tools can be loaded in a `ToolCollection` object as follow: | ||||||
|  | 
 | ||||||
|  | ```py | ||||||
|  | from smolagents import ToolCollection, CodeAgent | ||||||
|  | from mcp import StdioServerParameters | ||||||
|  | 
 | ||||||
|  | server_parameters = StdioServerParameters( | ||||||
|  |     command="uv", | ||||||
|  |     args=["--quiet", "pubmedmcp@0.1.3"], | ||||||
|  |     env={"UV_PYTHON": "3.12", **os.environ}, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | with ToolCollection.from_mcp(server_parameters) as tool_collection: | ||||||
|  |     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True) | ||||||
|  |     agent.run("Please find a remedy for hangover.") | ||||||
|  | ``` | ||||||
|  | @ -209,7 +209,7 @@ agent.run( | ||||||
| ```py | ```py | ||||||
| from smolagents import ToolCollection, CodeAgent | from smolagents import ToolCollection, CodeAgent | ||||||
| 
 | 
 | ||||||
| image_tool_collection = ToolCollection( | image_tool_collection = ToolCollection.from_hub( | ||||||
|     collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f", |     collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f", | ||||||
|     token="<YOUR_HUGGINGFACEHUB_API_TOKEN>" |     token="<YOUR_HUGGINGFACEHUB_API_TOKEN>" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | """An example of loading a ToolCollection directly from an MCP server. | ||||||
|  | 
 | ||||||
|  | Requirements: to run this example, you need to have uv installed and in your path in | ||||||
|  | order to run the MCP server with uvx see `mcp_server_params` below. | ||||||
|  | 
 | ||||||
|  | Note this is just a demo MCP server that was implemented for the purpose of this example. | ||||||
|  | It only provide a single tool to search amongst pubmed papers abstracts. | ||||||
|  | 
 | ||||||
|  | Usage: | ||||||
|  | >>> uv run examples/tool_calling_agent_mcp.py | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | from mcp import StdioServerParameters | ||||||
|  | from smolagents import CodeAgent, HfApiModel, ToolCollection | ||||||
|  | 
 | ||||||
|  | mcp_server_params = StdioServerParameters( | ||||||
|  |     command="uvx", | ||||||
|  |     args=["--quiet", "pubmedmcp@0.1.3"], | ||||||
|  |     env={"UV_PYTHON": "3.12", **os.environ}, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | with ToolCollection.from_mcp(mcp_server_params) as tool_collection: | ||||||
|  |     # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"})) | ||||||
|  |     agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel()) | ||||||
|  |     agent.run("Find studies about hangover?") | ||||||
|  | @ -36,13 +36,18 @@ torch = [ | ||||||
| litellm = [ | litellm = [ | ||||||
|   "litellm>=1.55.10", |   "litellm>=1.55.10", | ||||||
| ] | ] | ||||||
| openai = ["openai>=1.58.1"] | mcp = [ | ||||||
|  |   "mcpadapt>=0.0.6" | ||||||
|  | ] | ||||||
|  | openai = [ | ||||||
|  |   "openai>=1.58.1" | ||||||
|  | ] | ||||||
| quality = [ | quality = [ | ||||||
|   "ruff>=0.9.0", |   "ruff>=0.9.0", | ||||||
| ] | ] | ||||||
| test = [ | test = [ | ||||||
|   "pytest>=8.1.0", |   "pytest>=8.1.0", | ||||||
|   "smolagents[audio,litellm,openai,torch]", |   "smolagents[audio,litellm,mcp,openai,torch]", | ||||||
| ] | ] | ||||||
| dev = [ | dev = [ | ||||||
|   "smolagents[quality,test]", |   "smolagents[quality,test]", | ||||||
|  |  | ||||||
|  | @ -23,9 +23,10 @@ import os | ||||||
| import sys | import sys | ||||||
| import tempfile | import tempfile | ||||||
| import textwrap | import textwrap | ||||||
|  | from contextlib import contextmanager | ||||||
| from functools import lru_cache, wraps | from functools import lru_cache, wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable, Dict, Optional, Union, get_type_hints | from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import ( | from huggingface_hub import ( | ||||||
|     create_repo, |     create_repo, | ||||||
|  | @ -35,6 +36,7 @@ from huggingface_hub import ( | ||||||
|     upload_folder, |     upload_folder, | ||||||
| ) | ) | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError | from huggingface_hub.utils import RepositoryNotFoundError | ||||||
|  | 
 | ||||||
| from packaging import version | from packaging import version | ||||||
| from transformers.dynamic_module_utils import get_imports | from transformers.dynamic_module_utils import get_imports | ||||||
| from transformers.utils import ( | from transformers.utils import ( | ||||||
|  | @ -275,7 +277,8 @@ class Tool: | ||||||
|                 raise (ValueError("\n".join(method_checker.errors))) |                 raise (ValueError("\n".join(method_checker.errors))) | ||||||
| 
 | 
 | ||||||
|             forward_source_code = inspect.getsource(self.forward) |             forward_source_code = inspect.getsource(self.forward) | ||||||
|             tool_code = textwrap.dedent(f""" |             tool_code = textwrap.dedent( | ||||||
|  |                 f""" | ||||||
|             from smolagents import Tool |             from smolagents import Tool | ||||||
|             from typing import Optional |             from typing import Optional | ||||||
| 
 | 
 | ||||||
|  | @ -284,7 +287,8 @@ class Tool: | ||||||
|                 description = "{self.description}" |                 description = "{self.description}" | ||||||
|                 inputs = {json.dumps(self.inputs, separators=(",", ":"))} |                 inputs = {json.dumps(self.inputs, separators=(",", ":"))} | ||||||
|                 output_type = "{self.output_type}" |                 output_type = "{self.output_type}" | ||||||
|             """).strip() |             """ | ||||||
|  |             ).strip() | ||||||
|             import re |             import re | ||||||
| 
 | 
 | ||||||
|             def add_self_argument(source_code: str) -> str: |             def add_self_argument(source_code: str) -> str: | ||||||
|  | @ -325,7 +329,8 @@ class Tool: | ||||||
|         app_file = os.path.join(output_dir, "app.py") |         app_file = os.path.join(output_dir, "app.py") | ||||||
|         with open(app_file, "w", encoding="utf-8") as f: |         with open(app_file, "w", encoding="utf-8") as f: | ||||||
|             f.write( |             f.write( | ||||||
|                 textwrap.dedent(f""" |                 textwrap.dedent( | ||||||
|  |                     f""" | ||||||
|             from smolagents import launch_gradio_demo |             from smolagents import launch_gradio_demo | ||||||
|             from typing import Optional |             from typing import Optional | ||||||
|             from tool import {class_name} |             from tool import {class_name} | ||||||
|  | @ -333,7 +338,8 @@ class Tool: | ||||||
|             tool = {class_name}() |             tool = {class_name}() | ||||||
| 
 | 
 | ||||||
|             launch_gradio_demo(tool) |             launch_gradio_demo(tool) | ||||||
|             """).lstrip() |             """ | ||||||
|  |                 ).lstrip() | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # Save requirements file |         # Save requirements file | ||||||
|  | @ -870,42 +876,105 @@ def add_description(description): | ||||||
| 
 | 
 | ||||||
| class ToolCollection: | class ToolCollection: | ||||||
|     """ |     """ | ||||||
|     Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox. |     Tool collections enable loading a collection of tools in the agent's toolbox. | ||||||
|  | 
 | ||||||
|  |     Collections can be loaded from a collection in the Hub or from an MCP server, see: | ||||||
|  |     - [`ToolCollection.from_hub`] | ||||||
|  |     - [`ToolCollection.from_mcp`] | ||||||
|  | 
 | ||||||
|  |     For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`] | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, tools: List[Tool]): | ||||||
|  |         self.tools = tools | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_hub( | ||||||
|  |         cls, | ||||||
|  |         collection_slug: str, | ||||||
|  |         token: Optional[str] = None, | ||||||
|  |         trust_remote_code: bool = False, | ||||||
|  |     ) -> "ToolCollection": | ||||||
|  |         """Loads a tool collection from the Hub. | ||||||
|  | 
 | ||||||
|  |         it adds a collection of tools from all Spaces in the collection to the agent's toolbox | ||||||
| 
 | 
 | ||||||
|         > [!NOTE] |         > [!NOTE] | ||||||
|         > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd |         > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd | ||||||
|         > like for this collection to showcase them. |         > like for this collection to showcase them. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|         collection_slug (str): |             collection_slug (str): The collection slug referencing the collection. | ||||||
|             The collection slug referencing the collection. |             token (str, *optional*): The authentication token if the collection is private. | ||||||
|         token (str, *optional*): |             trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code. | ||||||
|             The authentication token if the collection is private. | 
 | ||||||
|  |         Returns: | ||||||
|  |             ToolCollection: A tool collection instance loaded with the tools. | ||||||
| 
 | 
 | ||||||
|         Example: |         Example: | ||||||
| 
 |  | ||||||
|         ```py |         ```py | ||||||
|     >>> from transformers import ToolCollection, CodeAgent |         >>> from smolagents import ToolCollection, CodeAgent | ||||||
| 
 | 
 | ||||||
|     >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f") |         >>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f") | ||||||
|         >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True) |         >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True) | ||||||
| 
 | 
 | ||||||
|         >>> agent.run("Please draw me a picture of rivers and lakes.") |         >>> agent.run("Please draw me a picture of rivers and lakes.") | ||||||
|         ``` |         ``` | ||||||
|         """ |         """ | ||||||
|  |         _collection = get_collection(collection_slug, token=token) | ||||||
|  |         _hub_repo_ids = { | ||||||
|  |             item.item_id for item in _collection.items if item.item_type == "space" | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|     def __init__( |         tools = { | ||||||
|         self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False |  | ||||||
|     ): |  | ||||||
|         self._collection = get_collection(collection_slug, token=token) |  | ||||||
|         self._hub_repo_ids = { |  | ||||||
|             item.item_id for item in self._collection.items if item.item_type == "space" |  | ||||||
|         } |  | ||||||
|         self.tools = { |  | ||||||
|             Tool.from_hub(repo_id, token, trust_remote_code) |             Tool.from_hub(repo_id, token, trust_remote_code) | ||||||
|             for repo_id in self._hub_repo_ids |             for repo_id in _hub_repo_ids | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         return cls(tools) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     @contextmanager | ||||||
|  |     def from_mcp(cls, server_parameters) -> "ToolCollection": | ||||||
|  |         """Automatically load a tool collection from an MCP server. | ||||||
|  | 
 | ||||||
|  |         Note: a separate thread will be spawned to run an asyncio event loop handling | ||||||
|  |         the MCP server. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             server_parameters (mcp.StdioServerParameters): The server parameters to use to | ||||||
|  |             connect to the MCP server. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             ToolCollection: A tool collection instance. | ||||||
|  | 
 | ||||||
|  |         Example: | ||||||
|  |         ```py | ||||||
|  |         >>> from smolagents import ToolCollection, CodeAgent | ||||||
|  |         >>> from mcp import StdioServerParameters | ||||||
|  | 
 | ||||||
|  |         >>> server_parameters = StdioServerParameters( | ||||||
|  |         >>>     command="uv", | ||||||
|  |         >>>     args=["--quiet", "pubmedmcp@0.1.3"], | ||||||
|  |         >>>     env={"UV_PYTHON": "3.12", **os.environ}, | ||||||
|  |         >>> ) | ||||||
|  | 
 | ||||||
|  |         >>> with ToolCollection.from_mcp(server_parameters) as tool_collection: | ||||||
|  |         >>>     agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True) | ||||||
|  |         >>>     agent.run("Please find a remedy for hangover.") | ||||||
|  |         ``` | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             from mcpadapt.core import MCPAdapt | ||||||
|  |             from mcpadapt.smolagents_adapter import SmolAgentsAdapter | ||||||
|  |         except ImportError: | ||||||
|  |             raise ImportError( | ||||||
|  |                 """Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`.""" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools: | ||||||
|  |             yield cls(tools) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def tool(tool_function: Callable) -> Tool: | def tool(tool_function: Callable) -> Tool: | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  | @ -14,14 +14,17 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from textwrap import dedent | ||||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||||
|  | from unittest.mock import patch, MagicMock | ||||||
| 
 | 
 | ||||||
|  | import mcp | ||||||
| import numpy as np | import numpy as np | ||||||
| import pytest | import pytest | ||||||
| from transformers import is_torch_available, is_vision_available | from transformers import is_torch_available, is_vision_available | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
| from smolagents.tools import AUTHORIZED_TYPES, Tool, tool | from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool | ||||||
| from smolagents.types import ( | from smolagents.types import ( | ||||||
|     AGENT_TYPE_MAPPING, |     AGENT_TYPE_MAPPING, | ||||||
|     AgentAudio, |     AgentAudio, | ||||||
|  | @ -385,3 +388,61 @@ class ToolTests(unittest.TestCase): | ||||||
| 
 | 
 | ||||||
|             GetWeatherTool3() |             GetWeatherTool3() | ||||||
|         assert "Nullable" in str(e) |         assert "Nullable" in str(e) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def mock_server_parameters(): | ||||||
|  |     return MagicMock() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def mock_mcp_adapt(): | ||||||
|  |     with patch("mcpadapt.core.MCPAdapt") as mock: | ||||||
|  |         mock.return_value.__enter__.return_value = ["tool1", "tool2"] | ||||||
|  |         mock.return_value.__exit__.return_value = None | ||||||
|  |         yield mock | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def mock_smolagents_adapter(): | ||||||
|  |     with patch("mcpadapt.smolagents_adapter.SmolAgentsAdapter") as mock: | ||||||
|  |         yield mock | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TestToolCollection: | ||||||
|  |     def test_from_mcp( | ||||||
|  |         self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter | ||||||
|  |     ): | ||||||
|  |         with ToolCollection.from_mcp(mock_server_parameters) as tool_collection: | ||||||
|  |             assert isinstance(tool_collection, ToolCollection) | ||||||
|  |             assert len(tool_collection.tools) == 2 | ||||||
|  |             assert "tool1" in tool_collection.tools | ||||||
|  |             assert "tool2" in tool_collection.tools | ||||||
|  | 
 | ||||||
|  |     def test_integration_from_mcp(self): | ||||||
|  |         # define the most simple mcp server with one tool that echoes the input text | ||||||
|  |         mcp_server_script = dedent("""\ | ||||||
|  |             from mcp.server.fastmcp import FastMCP | ||||||
|  | 
 | ||||||
|  |             mcp = FastMCP("Echo Server") | ||||||
|  | 
 | ||||||
|  |             @mcp.tool() | ||||||
|  |             def echo_tool(text: str) -> str: | ||||||
|  |                 return text | ||||||
|  | 
 | ||||||
|  |             mcp.run() | ||||||
|  |         """).strip() | ||||||
|  | 
 | ||||||
|  |         mcp_server_params = mcp.StdioServerParameters( | ||||||
|  |             command="python", | ||||||
|  |             args=["-c", mcp_server_script], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         with ToolCollection.from_mcp(mcp_server_params) as tool_collection: | ||||||
|  |             assert len(tool_collection.tools) == 1, "Expected 1 tool" | ||||||
|  |             assert tool_collection.tools[0].name == "echo_tool", ( | ||||||
|  |                 "Expected tool name to be 'echo_tool'" | ||||||
|  |             ) | ||||||
|  |             assert tool_collection.tools[0](text="Hello") == "Hello", ( | ||||||
|  |                 "Expected tool to echo the input text" | ||||||
|  |             ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue