add support for MCP Servers tools as `ToolCollection` (#232)
* add support for tool collection from mcp servers * add forgotten documentation * fix link missing in documentation * fix linting in CI, bumpruff to use modern version * mcpadapt added as optional dependencies * use classmethod for from_hub and from_mcp to better reflect the fact that they return a ToolCollection * Update src/smolagents/tools.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> * Update src/smolagents/tools.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> * Test ToolCollection.from_mcp * Rename to mcp extra * Add mcp extra to test extra * add a test for from_mcp * fix typo * fix tests * Test ToolCollection.from_mcp (cherry picked from commit 9284d9ea8cf24d3c934e35a38dfe34f3ce31cef3) * Make all pytest tests --------- Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
This commit is contained in:
parent
7d6599e430
commit
a4d029da88
|
@ -204,13 +204,17 @@ agent.run(
|
||||||
|
|
||||||
### Use a collection of tools
|
### Use a collection of tools
|
||||||
|
|
||||||
You can leverage tool collections by using the `ToolCollection` object, with the slug of the collection you want to use.
|
You can leverage tool collections by using the `ToolCollection` object. It supports loading either a collection from the Hub or an MCP server tools.
|
||||||
|
|
||||||
|
#### Tool Collection from a collection in the Hub
|
||||||
|
|
||||||
|
You can leverage it with the slug of the collection you want to use.
|
||||||
Then pass them as a list to initialize your agent, and start using them!
|
Then pass them as a list to initialize your agent, and start using them!
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from smolagents import ToolCollection, CodeAgent
|
from smolagents import ToolCollection, CodeAgent
|
||||||
|
|
||||||
image_tool_collection = ToolCollection(
|
image_tool_collection = ToolCollection.from_hub(
|
||||||
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
|
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
|
||||||
token="<YOUR_HUGGINGFACEHUB_API_TOKEN>"
|
token="<YOUR_HUGGINGFACEHUB_API_TOKEN>"
|
||||||
)
|
)
|
||||||
|
@ -220,3 +224,24 @@ agent.run("Please draw me a picture of rivers and lakes.")
|
||||||
```
|
```
|
||||||
|
|
||||||
To speed up the start, tools are loaded only if called by the agent.
|
To speed up the start, tools are loaded only if called by the agent.
|
||||||
|
|
||||||
|
#### Tool Collection from any MCP server
|
||||||
|
|
||||||
|
Leverage tools from the hundreds of MCP servers available on [glama.ai](https://glama.ai/mcp/servers) or [smithery.ai](https://smithery.ai/).
|
||||||
|
|
||||||
|
The MCP servers tools can be loaded in a `ToolCollection` object as follow:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from smolagents import ToolCollection, CodeAgent
|
||||||
|
from mcp import StdioServerParameters
|
||||||
|
|
||||||
|
server_parameters = StdioServerParameters(
|
||||||
|
command="uv",
|
||||||
|
args=["--quiet", "pubmedmcp@0.1.3"],
|
||||||
|
env={"UV_PYTHON": "3.12", **os.environ},
|
||||||
|
)
|
||||||
|
|
||||||
|
with ToolCollection.from_mcp(server_parameters) as tool_collection:
|
||||||
|
agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True)
|
||||||
|
agent.run("Please find a remedy for hangover.")
|
||||||
|
```
|
|
@ -209,7 +209,7 @@ agent.run(
|
||||||
```py
|
```py
|
||||||
from smolagents import ToolCollection, CodeAgent
|
from smolagents import ToolCollection, CodeAgent
|
||||||
|
|
||||||
image_tool_collection = ToolCollection(
|
image_tool_collection = ToolCollection.from_hub(
|
||||||
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
|
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
|
||||||
token="<YOUR_HUGGINGFACEHUB_API_TOKEN>"
|
token="<YOUR_HUGGINGFACEHUB_API_TOKEN>"
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""An example of loading a ToolCollection directly from an MCP server.
|
||||||
|
|
||||||
|
Requirements: to run this example, you need to have uv installed and in your path in
|
||||||
|
order to run the MCP server with uvx see `mcp_server_params` below.
|
||||||
|
|
||||||
|
Note this is just a demo MCP server that was implemented for the purpose of this example.
|
||||||
|
It only provide a single tool to search amongst pubmed papers abstracts.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
>>> uv run examples/tool_calling_agent_mcp.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mcp import StdioServerParameters
|
||||||
|
from smolagents import CodeAgent, HfApiModel, ToolCollection
|
||||||
|
|
||||||
|
mcp_server_params = StdioServerParameters(
|
||||||
|
command="uvx",
|
||||||
|
args=["--quiet", "pubmedmcp@0.1.3"],
|
||||||
|
env={"UV_PYTHON": "3.12", **os.environ},
|
||||||
|
)
|
||||||
|
|
||||||
|
with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
|
||||||
|
# print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
|
||||||
|
agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel())
|
||||||
|
agent.run("Find studies about hangover?")
|
|
@ -36,13 +36,18 @@ torch = [
|
||||||
litellm = [
|
litellm = [
|
||||||
"litellm>=1.55.10",
|
"litellm>=1.55.10",
|
||||||
]
|
]
|
||||||
openai = ["openai>=1.58.1"]
|
mcp = [
|
||||||
|
"mcpadapt>=0.0.6"
|
||||||
|
]
|
||||||
|
openai = [
|
||||||
|
"openai>=1.58.1"
|
||||||
|
]
|
||||||
quality = [
|
quality = [
|
||||||
"ruff>=0.9.0",
|
"ruff>=0.9.0",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"pytest>=8.1.0",
|
"pytest>=8.1.0",
|
||||||
"smolagents[audio,litellm,openai,torch]",
|
"smolagents[audio,litellm,mcp,openai,torch]",
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"smolagents[quality,test]",
|
"smolagents[quality,test]",
|
||||||
|
|
|
@ -23,9 +23,10 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from contextlib import contextmanager
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Optional, Union, get_type_hints
|
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||||
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
create_repo,
|
create_repo,
|
||||||
|
@ -35,6 +36,7 @@ from huggingface_hub import (
|
||||||
upload_folder,
|
upload_folder,
|
||||||
)
|
)
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from transformers.dynamic_module_utils import get_imports
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
|
@ -275,7 +277,8 @@ class Tool:
|
||||||
raise (ValueError("\n".join(method_checker.errors)))
|
raise (ValueError("\n".join(method_checker.errors)))
|
||||||
|
|
||||||
forward_source_code = inspect.getsource(self.forward)
|
forward_source_code = inspect.getsource(self.forward)
|
||||||
tool_code = textwrap.dedent(f"""
|
tool_code = textwrap.dedent(
|
||||||
|
f"""
|
||||||
from smolagents import Tool
|
from smolagents import Tool
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -284,7 +287,8 @@ class Tool:
|
||||||
description = "{self.description}"
|
description = "{self.description}"
|
||||||
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
|
||||||
output_type = "{self.output_type}"
|
output_type = "{self.output_type}"
|
||||||
""").strip()
|
"""
|
||||||
|
).strip()
|
||||||
import re
|
import re
|
||||||
|
|
||||||
def add_self_argument(source_code: str) -> str:
|
def add_self_argument(source_code: str) -> str:
|
||||||
|
@ -325,7 +329,8 @@ class Tool:
|
||||||
app_file = os.path.join(output_dir, "app.py")
|
app_file = os.path.join(output_dir, "app.py")
|
||||||
with open(app_file, "w", encoding="utf-8") as f:
|
with open(app_file, "w", encoding="utf-8") as f:
|
||||||
f.write(
|
f.write(
|
||||||
textwrap.dedent(f"""
|
textwrap.dedent(
|
||||||
|
f"""
|
||||||
from smolagents import launch_gradio_demo
|
from smolagents import launch_gradio_demo
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from tool import {class_name}
|
from tool import {class_name}
|
||||||
|
@ -333,7 +338,8 @@ class Tool:
|
||||||
tool = {class_name}()
|
tool = {class_name}()
|
||||||
|
|
||||||
launch_gradio_demo(tool)
|
launch_gradio_demo(tool)
|
||||||
""").lstrip()
|
"""
|
||||||
|
).lstrip()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save requirements file
|
# Save requirements file
|
||||||
|
@ -870,42 +876,105 @@ def add_description(description):
|
||||||
|
|
||||||
class ToolCollection:
|
class ToolCollection:
|
||||||
"""
|
"""
|
||||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
Tool collections enable loading a collection of tools in the agent's toolbox.
|
||||||
|
|
||||||
> [!NOTE]
|
Collections can be loaded from a collection in the Hub or from an MCP server, see:
|
||||||
> Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
|
- [`ToolCollection.from_hub`]
|
||||||
> like for this collection to showcase them.
|
- [`ToolCollection.from_mcp`]
|
||||||
|
|
||||||
Args:
|
For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`]
|
||||||
collection_slug (str):
|
|
||||||
The collection slug referencing the collection.
|
|
||||||
token (str, *optional*):
|
|
||||||
The authentication token if the collection is private.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```py
|
|
||||||
>>> from transformers import ToolCollection, CodeAgent
|
|
||||||
|
|
||||||
>>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
|
|
||||||
>>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
|
|
||||||
|
|
||||||
>>> agent.run("Please draw me a picture of rivers and lakes.")
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, tools: List[Tool]):
|
||||||
self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False
|
self.tools = tools
|
||||||
):
|
|
||||||
self._collection = get_collection(collection_slug, token=token)
|
@classmethod
|
||||||
self._hub_repo_ids = {
|
def from_hub(
|
||||||
item.item_id for item in self._collection.items if item.item_type == "space"
|
cls,
|
||||||
|
collection_slug: str,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
) -> "ToolCollection":
|
||||||
|
"""Loads a tool collection from the Hub.
|
||||||
|
|
||||||
|
it adds a collection of tools from all Spaces in the collection to the agent's toolbox
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
|
||||||
|
> like for this collection to showcase them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_slug (str): The collection slug referencing the collection.
|
||||||
|
token (str, *optional*): The authentication token if the collection is private.
|
||||||
|
trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolCollection: A tool collection instance loaded with the tools.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
>>> from smolagents import ToolCollection, CodeAgent
|
||||||
|
|
||||||
|
>>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
|
||||||
|
>>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
|
||||||
|
|
||||||
|
>>> agent.run("Please draw me a picture of rivers and lakes.")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
_collection = get_collection(collection_slug, token=token)
|
||||||
|
_hub_repo_ids = {
|
||||||
|
item.item_id for item in _collection.items if item.item_type == "space"
|
||||||
}
|
}
|
||||||
self.tools = {
|
|
||||||
|
tools = {
|
||||||
Tool.from_hub(repo_id, token, trust_remote_code)
|
Tool.from_hub(repo_id, token, trust_remote_code)
|
||||||
for repo_id in self._hub_repo_ids
|
for repo_id in _hub_repo_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cls(tools)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def from_mcp(cls, server_parameters) -> "ToolCollection":
|
||||||
|
"""Automatically load a tool collection from an MCP server.
|
||||||
|
|
||||||
|
Note: a separate thread will be spawned to run an asyncio event loop handling
|
||||||
|
the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_parameters (mcp.StdioServerParameters): The server parameters to use to
|
||||||
|
connect to the MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolCollection: A tool collection instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
>>> from smolagents import ToolCollection, CodeAgent
|
||||||
|
>>> from mcp import StdioServerParameters
|
||||||
|
|
||||||
|
>>> server_parameters = StdioServerParameters(
|
||||||
|
>>> command="uv",
|
||||||
|
>>> args=["--quiet", "pubmedmcp@0.1.3"],
|
||||||
|
>>> env={"UV_PYTHON": "3.12", **os.environ},
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
>>> with ToolCollection.from_mcp(server_parameters) as tool_collection:
|
||||||
|
>>> agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True)
|
||||||
|
>>> agent.run("Please find a remedy for hangover.")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from mcpadapt.core import MCPAdapt
|
||||||
|
from mcpadapt.smolagents_adapter import SmolAgentsAdapter
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"""Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`."""
|
||||||
|
)
|
||||||
|
|
||||||
|
with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools:
|
||||||
|
yield cls(tools)
|
||||||
|
|
||||||
|
|
||||||
def tool(tool_function: Callable) -> Tool:
|
def tool(tool_function: Callable) -> Tool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -14,14 +14,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from textwrap import dedent
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import mcp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
|
from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
|
||||||
from smolagents.types import (
|
from smolagents.types import (
|
||||||
AGENT_TYPE_MAPPING,
|
AGENT_TYPE_MAPPING,
|
||||||
AgentAudio,
|
AgentAudio,
|
||||||
|
@ -385,3 +388,61 @@ class ToolTests(unittest.TestCase):
|
||||||
|
|
||||||
GetWeatherTool3()
|
GetWeatherTool3()
|
||||||
assert "Nullable" in str(e)
|
assert "Nullable" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_server_parameters():
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mcp_adapt():
|
||||||
|
with patch("mcpadapt.core.MCPAdapt") as mock:
|
||||||
|
mock.return_value.__enter__.return_value = ["tool1", "tool2"]
|
||||||
|
mock.return_value.__exit__.return_value = None
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_smolagents_adapter():
|
||||||
|
with patch("mcpadapt.smolagents_adapter.SmolAgentsAdapter") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCollection:
|
||||||
|
def test_from_mcp(
|
||||||
|
self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
|
||||||
|
):
|
||||||
|
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
|
||||||
|
assert isinstance(tool_collection, ToolCollection)
|
||||||
|
assert len(tool_collection.tools) == 2
|
||||||
|
assert "tool1" in tool_collection.tools
|
||||||
|
assert "tool2" in tool_collection.tools
|
||||||
|
|
||||||
|
def test_integration_from_mcp(self):
|
||||||
|
# define the most simple mcp server with one tool that echoes the input text
|
||||||
|
mcp_server_script = dedent("""\
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
mcp = FastMCP("Echo Server")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def echo_tool(text: str) -> str:
|
||||||
|
return text
|
||||||
|
|
||||||
|
mcp.run()
|
||||||
|
""").strip()
|
||||||
|
|
||||||
|
mcp_server_params = mcp.StdioServerParameters(
|
||||||
|
command="python",
|
||||||
|
args=["-c", mcp_server_script],
|
||||||
|
)
|
||||||
|
|
||||||
|
with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
|
||||||
|
assert len(tool_collection.tools) == 1, "Expected 1 tool"
|
||||||
|
assert tool_collection.tools[0].name == "echo_tool", (
|
||||||
|
"Expected tool name to be 'echo_tool'"
|
||||||
|
)
|
||||||
|
assert tool_collection.tools[0](text="Hello") == "Hello", (
|
||||||
|
"Expected tool to echo the input text"
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue