Add CLI for smolagents (#431)

---------

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
This commit is contained in:
Aymeric Roucher 2025-01-30 15:56:50 +01:00 committed by GitHub
parent 595781b82d
commit fa1f8d0154
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 190 additions and 77 deletions

View File

@ -84,12 +84,12 @@ These types have three specific purposes:
### AgentText ### AgentText
[[autodoc]] smolagents.types.AgentText [[autodoc]] smolagents.agent_types.AgentText
### AgentImage ### AgentImage
[[autodoc]] smolagents.types.AgentImage [[autodoc]] smolagents.agent_types.AgentImage
### AgentAudio ### AgentAudio
[[autodoc]] smolagents.types.AgentAudio [[autodoc]] smolagents.agent_types.AgentAudio

View File

@ -80,12 +80,12 @@ Smolagents एक experimental API है जो किसी भी समय
### AgentText ### AgentText
[[autodoc]] smolagents.types.AgentText [[autodoc]] smolagents.agent_types.AgentText
### AgentImage ### AgentImage
[[autodoc]] smolagents.types.AgentImage [[autodoc]] smolagents.agent_types.AgentImage
### AgentAudio ### AgentAudio
[[autodoc]] smolagents.types.AgentAudio [[autodoc]] smolagents.agent_types.AgentAudio

View File

@ -80,12 +80,12 @@ These types have three specific purposes:
### AgentText ### AgentText
[[autodoc]] smolagents.types.AgentText [[autodoc]] smolagents.agent_types.AgentText
### AgentImage ### AgentImage
[[autodoc]] smolagents.types.AgentImage [[autodoc]] smolagents.agent_types.AgentImage
### AgentAudio ### AgentAudio
[[autodoc]] smolagents.types.AgentAudio [[autodoc]] smolagents.agent_types.AgentAudio

View File

@ -20,6 +20,7 @@ dependencies = [
"pillow>=11.0.0", "pillow>=11.0.0",
"markdownify>=0.14.1", "markdownify>=0.14.1",
"duckduckgo-search>=6.3.7", "duckduckgo-search>=6.3.7",
"python-dotenv"
] ]
[project.optional-dependencies] [project.optional-dependencies]
@ -97,3 +98,7 @@ lint.select = ["E", "F", "I", "W"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
known-first-party = ["smolagents"] known-first-party = ["smolagents"]
lines-after-imports = 2 lines-after-imports = 2
[project.scripts]
smolagent = "smolagents.cli:main"
webagent = "smolagents.vlm_web_browser:main"

View File

@ -16,7 +16,8 @@
# limitations under the License. # limitations under the License.
__version__ = "1.7.0.dev0" __version__ = "1.7.0.dev0"
from .agents import * from .agent_types import * # noqa: I001
from .agents import * # Above noqa avoids a circular dependency due to cli.py
from .default_tools import * from .default_tools import *
from .e2b_executor import * from .e2b_executor import *
from .gradio_ui import * from .gradio_ui import *
@ -26,5 +27,5 @@ from .models import *
from .monitoring import * from .monitoring import *
from .prompts import * from .prompts import *
from .tools import * from .tools import *
from .types import *
from .utils import * from .utils import *
from .cli import *

View File

@ -25,13 +25,13 @@ from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.text import Text from rich.text import Text
from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.monitoring import ( from smolagents.monitoring import (
YELLOW_HEX, YELLOW_HEX,
AgentLogger, AgentLogger,
LogLevel, LogLevel,
) )
from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.utils import ( from smolagents.utils import (
AgentError, AgentError,
AgentExecutionError, AgentExecutionError,
@ -43,6 +43,7 @@ from smolagents.utils import (
truncate_content, truncate_content,
) )
from .agent_types import AgentType
from .default_tools import TOOL_MAPPING, FinalAnswerTool from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor from .e2b_executor import E2BExecutor
from .local_python_executor import ( from .local_python_executor import (
@ -73,7 +74,6 @@ from .tools import (
Tool, Tool,
get_tool_description_with_args, get_tool_description_with_args,
) )
from .types import AgentType
logger = getLogger(__name__) logger = getLogger(__name__)

118
src/smolagents/cli.py Normal file
View File

@ -0,0 +1,118 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from dotenv import load_dotenv
from smolagents import CodeAgent, HfApiModel, LiteLLMModel, Model, OpenAIServerModel, Tool, TransformersModel
from smolagents.default_tools import TOOL_MAPPING
leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?"
def parse_arguments(description):
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"prompt",
type=str,
nargs="?", # Makes it optional
default=leopard_prompt,
help="The prompt to run with the agent",
)
parser.add_argument(
"--model-type",
type=str,
default="HfApiModel",
help="The model type to use (e.g., HfApiModel, OpenAIServerModel, LiteLLMModel, TransformersModel)",
)
parser.add_argument(
"--model-id",
type=str,
default="Qwen/Qwen2.5-Coder-32B-Instruct",
help="The model ID to use for the specified model type",
)
parser.add_argument(
"--imports",
nargs="*", # accepts zero or more arguments
default=[],
help="Space-separated list of imports to authorize (e.g., 'numpy pandas')",
)
parser.add_argument(
"--tools",
nargs="*",
default=["web_search"],
help="Space-separated list of tools that the agent can use (e.g., 'tool1 tool2 tool3')",
)
parser.add_argument(
"--verbosity-level",
type=int,
default=1,
help="The verbosity level, as an int in [0, 1, 2].",
)
return parser.parse_args()
def load_model(model_type: str, model_id: str) -> Model:
if model_type == "OpenAIServerModel":
return OpenAIServerModel(
api_key=os.getenv("FIREWORKS_API_KEY"),
api_base="https://api.fireworks.ai/inference/v1",
model_id=model_id,
)
elif model_type == "LiteLLMModel":
return LiteLLMModel(
model_id=model_id,
api_key=os.getenv("OPENAI_API_KEY"),
)
elif model_type == "TransformersModel":
return TransformersModel(model_id=model_id, device_map="auto", flatten_messages_as_text=False)
elif model_type == "HfApiModel":
return HfApiModel(
token=os.getenv("HF_API_KEY"),
model_id=model_id,
)
else:
raise ValueError(f"Unsupported model type: {model_type}")
def main():
load_dotenv()
args = parse_arguments(description="Run a CodeAgent with all specified parameters")
model = load_model(args.model_type, args.model_id)
available_tools = []
for tool_name in args.tools:
if "/" in tool_name:
available_tools.append(Tool.from_space(tool_name))
else:
if tool_name in TOOL_MAPPING:
available_tools.append(TOOL_MAPPING[tool_name]())
else:
raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.")
print(f"Running agent with these tools: {args.tools}")
agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=args.imports)
agent.run(args.prompt)
if __name__ == "__main__":
main()

View File

@ -18,13 +18,13 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .agent_types import AgentAudio
from .local_python_executor import ( from .local_python_executor import (
BASE_BUILTIN_MODULES, BASE_BUILTIN_MODULES,
BASE_PYTHON_TOOLS, BASE_PYTHON_TOOLS,
evaluate_python_code, evaluate_python_code,
) )
from .tools import PipelineTool, Tool from .tools import PipelineTool, Tool
from .types import AgentAudio
@dataclass @dataclass

View File

@ -19,9 +19,9 @@ import re
import shutil import shutil
from typing import Optional from typing import Optional
from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.memory import MemoryStep from smolagents.memory import MemoryStep
from smolagents.types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.utils import _is_package_available from smolagents.utils import _is_package_available

View File

@ -44,8 +44,8 @@ from ._function_type_hints_utils import (
get_imports, get_imports,
get_json_schema, get_json_schema,
) )
from .agent_types import handle_agent_input_types, handle_agent_output_types
from .tool_validation import MethodChecker, validate_tool_attributes from .tool_validation import MethodChecker, validate_tool_attributes
from .types import handle_agent_input_types, handle_agent_output_types
from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source

View File

@ -1,5 +1,4 @@
import argparse import argparse
import os
from io import BytesIO from io import BytesIO
from time import sleep from time import sleep
@ -10,8 +9,9 @@ from selenium import webdriver
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from smolagents import CodeAgent, HfApiModel, LiteLLMModel, OpenAIServerModel, TransformersModel, tool # noqa: F401 from smolagents import CodeAgent, tool
from smolagents.agents import ActionStep from smolagents.agents import ActionStep
from smolagents.cli import load_model
github_request = """ github_request = """
@ -27,7 +27,7 @@ Please navigate to https://en.wikipedia.org/wiki/Chicago and give me a sentence
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="Run a web browser automation script with a specified model.") parser = argparse.ArgumentParser(description="Run a web browser automation script with a specified model.")
parser.add_argument( parser.add_argument(
"--model", "--model-type",
type=str, type=str,
default="LiteLLMModel", default="LiteLLMModel",
help="The model type to use (e.g., OpenAIServerModel, LiteLLMModel, TransformersModel, HfApiModel)", help="The model type to use (e.g., OpenAIServerModel, LiteLLMModel, TransformersModel, HfApiModel)",
@ -42,36 +42,6 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
# Load environment variables
load_dotenv()
# Parse command line arguments
args = parse_arguments()
# Initialize the model based on the provided arguments
if args.model == "OpenAIServerModel":
model = OpenAIServerModel(
api_key=os.getenv("FIREWORKS_API_KEY"),
api_base="https://api.fireworks.ai/inference/v1",
model_id=args.model_id,
)
elif args.model == "LiteLLMModel":
model = LiteLLMModel(
model_id=args.model_id,
api_key=os.getenv("OPENAI_API_KEY"),
)
elif args.model == "TransformersModel":
model = TransformersModel(model_id=args.model_id, device_map="auto", flatten_messages_as_text=False)
elif args.model == "HfApiModel":
model = HfApiModel(
token=os.getenv("HF_API_KEY"),
model_id=args.model_id,
)
else:
raise ValueError(f"Unsupported model type: {args.model}")
# Prepare callback
def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None: def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
sleep(1.0) # Let JavaScript animations happen before taking the screenshot sleep(1.0) # Let JavaScript animations happen before taking the screenshot
driver = helium.get_driver() driver = helium.get_driver()
@ -93,18 +63,6 @@ def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
return return
# Initialize driver and agent
chrome_options = webdriver.ChromeOptions()
chrome_options.add_argument("--force-device-scale-factor=1")
chrome_options.add_argument("--window-size=1000,1350")
chrome_options.add_argument("--disable-pdf-viewer")
chrome_options.add_argument("--window-position=0,0")
driver = helium.start_chrome(headless=False, options=chrome_options)
# Initialize tools
@tool @tool
def search_item_ctrl_f(text: str, nth_result: int = 1) -> str: def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
""" """
@ -137,14 +95,27 @@ def close_popups() -> str:
webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform() webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform()
agent = CodeAgent( def initialize_driver():
tools=[go_back, close_popups, search_item_ctrl_f], """Initialize the Selenium WebDriver."""
model=model, chrome_options = webdriver.ChromeOptions()
additional_authorized_imports=["helium"], chrome_options.add_argument("--force-device-scale-factor=1")
step_callbacks=[save_screenshot], chrome_options.add_argument("--window-size=1000,1350")
max_steps=20, chrome_options.add_argument("--disable-pdf-viewer")
verbosity_level=2, chrome_options.add_argument("--window-position=0,0")
) return helium.start_chrome(headless=False, options=chrome_options)
def initialize_agent(model):
"""Initialize the CodeAgent with the specified model."""
return CodeAgent(
tools=[go_back, close_popups, search_item_ctrl_f],
model=model,
additional_authorized_imports=["helium"],
step_callbacks=[save_screenshot],
max_steps=20,
verbosity_level=2,
)
helium_instructions = """ helium_instructions = """
You can use helium to access websites. Don't bother about the helium driver, it's already managed. You can use helium to access websites. Don't bother about the helium driver, it's already managed.
@ -207,7 +178,25 @@ Don't kill the browser.
When you have modals or cookie banners on screen, you should get rid of them before you can click anything else. When you have modals or cookie banners on screen, you should get rid of them before you can click anything else.
""" """
# Run the agent with the provided prompt
agent.python_executor("from helium import *", agent.state) def main():
agent.run(args.prompt + helium_instructions) # Load environment variables
load_dotenv()
# Parse command line arguments
args = parse_arguments()
# Initialize the model based on the provided arguments
model = load_model(args.model_type, args.model_id)
global driver
driver = initialize_driver()
agent = initialize_agent(model)
# Run the agent with the provided prompt
agent.python_executor("from helium import *", agent.state)
agent.run(args.prompt + helium_instructions)
if __name__ == "__main__":
main()

View File

@ -20,6 +20,7 @@ from pathlib import Path
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
from smolagents.agent_types import AgentImage, AgentText
from smolagents.agents import ( from smolagents.agents import (
AgentMaxStepsError, AgentMaxStepsError,
CodeAgent, CodeAgent,
@ -30,7 +31,6 @@ from smolagents.agents import (
from smolagents.default_tools import PythonInterpreterTool from smolagents.default_tools import PythonInterpreterTool
from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
from smolagents.tools import tool from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.utils import BASE_BUILTIN_MODULES from smolagents.utils import BASE_BUILTIN_MODULES

View File

@ -16,8 +16,8 @@ import unittest
import pytest import pytest
from smolagents.agent_types import _AGENT_TYPE_MAPPING
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
from smolagents.types import _AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin

View File

@ -21,8 +21,8 @@ from PIL import Image
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import get_tests_dir, require_torch from transformers.testing_utils import get_tests_dir, require_torch
from smolagents.agent_types import _AGENT_TYPE_MAPPING
from smolagents.default_tools import FinalAnswerTool from smolagents.default_tools import FinalAnswerTool
from smolagents.types import _AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin

View File

@ -26,8 +26,8 @@ import torch
from transformers import is_torch_available, is_vision_available from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
from smolagents.agent_types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
if is_torch_available(): if is_torch_available():

View File

@ -25,7 +25,7 @@ from transformers.testing_utils import (
require_vision, require_vision,
) )
from smolagents.types import AgentAudio, AgentImage, AgentText from smolagents.agent_types import AgentAudio, AgentImage, AgentText
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str: