diff --git a/docs/source/en/reference/tools.md b/docs/source/en/reference/tools.md index 9d78774..b8d11d0 100644 --- a/docs/source/en/reference/tools.md +++ b/docs/source/en/reference/tools.md @@ -84,12 +84,12 @@ These types have three specific purposes: ### AgentText -[[autodoc]] smolagents.types.AgentText +[[autodoc]] smolagents.agent_types.AgentText ### AgentImage -[[autodoc]] smolagents.types.AgentImage +[[autodoc]] smolagents.agent_types.AgentImage ### AgentAudio -[[autodoc]] smolagents.types.AgentAudio +[[autodoc]] smolagents.agent_types.AgentAudio diff --git a/docs/source/hi/reference/tools.md b/docs/source/hi/reference/tools.md index ddb24d1..6c27032 100644 --- a/docs/source/hi/reference/tools.md +++ b/docs/source/hi/reference/tools.md @@ -80,12 +80,12 @@ Smolagents एक experimental API है जो किसी भी समय ### AgentText -[[autodoc]] smolagents.types.AgentText +[[autodoc]] smolagents.agent_types.AgentText ### AgentImage -[[autodoc]] smolagents.types.AgentImage +[[autodoc]] smolagents.agent_types.AgentImage ### AgentAudio -[[autodoc]] smolagents.types.AgentAudio +[[autodoc]] smolagents.agent_types.AgentAudio diff --git a/docs/source/zh/reference/tools.md b/docs/source/zh/reference/tools.md index 022ad35..847c497 100644 --- a/docs/source/zh/reference/tools.md +++ b/docs/source/zh/reference/tools.md @@ -80,12 +80,12 @@ These types have three specific purposes: ### AgentText -[[autodoc]] smolagents.types.AgentText +[[autodoc]] smolagents.agent_types.AgentText ### AgentImage -[[autodoc]] smolagents.types.AgentImage +[[autodoc]] smolagents.agent_types.AgentImage ### AgentAudio -[[autodoc]] smolagents.types.AgentAudio +[[autodoc]] smolagents.agent_types.AgentAudio diff --git a/pyproject.toml b/pyproject.toml index db766ef..1efb73f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "pillow>=11.0.0", "markdownify>=0.14.1", "duckduckgo-search>=6.3.7", + "python-dotenv" ] [project.optional-dependencies] @@ -97,3 +98,7 @@ lint.select = ["E", "F", "I", "W"] [tool.ruff.lint.isort] known-first-party = ["smolagents"] lines-after-imports = 2 + +[project.scripts] +smolagent = "smolagents.cli:main" +webagent = "smolagents.vlm_web_browser:main" \ No newline at end of file diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py index 7cbf191..4ccc6b4 100644 --- a/src/smolagents/__init__.py +++ b/src/smolagents/__init__.py @@ -16,7 +16,8 @@ # limitations under the License. __version__ = "1.7.0.dev0" -from .agents import * +from .agent_types import * # noqa: I001 +from .agents import * # Above noqa avoids a circular dependency due to cli.py from .default_tools import * from .e2b_executor import * from .gradio_ui import * @@ -26,5 +27,5 @@ from .models import * from .monitoring import * from .prompts import * from .tools import * -from .types import * from .utils import * +from .cli import * diff --git a/src/smolagents/types.py b/src/smolagents/agent_types.py similarity index 100% rename from src/smolagents/types.py rename to src/smolagents/agent_types.py diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 6114601..0a9ec6c 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -25,13 +25,13 @@ from rich.panel import Panel from rich.rule import Rule from rich.text import Text +from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall from smolagents.monitoring import ( YELLOW_HEX, AgentLogger, LogLevel, ) -from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types from smolagents.utils import ( AgentError, AgentExecutionError, @@ -43,6 +43,7 @@ from smolagents.utils import ( truncate_content, ) +from .agent_types import AgentType from .default_tools import TOOL_MAPPING, FinalAnswerTool from .e2b_executor import E2BExecutor from .local_python_executor import ( @@ -73,7 +74,6 @@ from .tools import ( Tool, get_tool_description_with_args, ) -from .types import AgentType logger = getLogger(__name__) diff --git a/src/smolagents/cli.py b/src/smolagents/cli.py new file mode 100644 index 0000000..255a880 --- /dev/null +++ b/src/smolagents/cli.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os + +from dotenv import load_dotenv + +from smolagents import CodeAgent, HfApiModel, LiteLLMModel, Model, OpenAIServerModel, Tool, TransformersModel +from smolagents.default_tools import TOOL_MAPPING + + +leopard_prompt = "How many seconds would it take for a leopard at full speed to run through Pont des Arts?" + + +def parse_arguments(description): + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "prompt", + type=str, + nargs="?", # Makes it optional + default=leopard_prompt, + help="The prompt to run with the agent", + ) + parser.add_argument( + "--model-type", + type=str, + default="HfApiModel", + help="The model type to use (e.g., HfApiModel, OpenAIServerModel, LiteLLMModel, TransformersModel)", + ) + parser.add_argument( + "--model-id", + type=str, + default="Qwen/Qwen2.5-Coder-32B-Instruct", + help="The model ID to use for the specified model type", + ) + parser.add_argument( + "--imports", + nargs="*", # accepts zero or more arguments + default=[], + help="Space-separated list of imports to authorize (e.g., 'numpy pandas')", + ) + parser.add_argument( + "--tools", + nargs="*", + default=["web_search"], + help="Space-separated list of tools that the agent can use (e.g., 'tool1 tool2 tool3')", + ) + parser.add_argument( + "--verbosity-level", + type=int, + default=1, + help="The verbosity level, as an int in [0, 1, 2].", + ) + return parser.parse_args() + + +def load_model(model_type: str, model_id: str) -> Model: + if model_type == "OpenAIServerModel": + return OpenAIServerModel( + api_key=os.getenv("FIREWORKS_API_KEY"), + api_base="https://api.fireworks.ai/inference/v1", + model_id=model_id, + ) + elif model_type == "LiteLLMModel": + return LiteLLMModel( + model_id=model_id, + api_key=os.getenv("OPENAI_API_KEY"), + ) + elif model_type == "TransformersModel": + return TransformersModel(model_id=model_id, device_map="auto", flatten_messages_as_text=False) + elif model_type == "HfApiModel": + return HfApiModel( + token=os.getenv("HF_API_KEY"), + model_id=model_id, + ) + else: + raise ValueError(f"Unsupported model type: {model_type}") + + +def main(): + load_dotenv() + + args = parse_arguments(description="Run a CodeAgent with all specified parameters") + + model = load_model(args.model_type, args.model_id) + + available_tools = [] + for tool_name in args.tools: + if "/" in tool_name: + available_tools.append(Tool.from_space(tool_name)) + else: + if tool_name in TOOL_MAPPING: + available_tools.append(TOOL_MAPPING[tool_name]()) + else: + raise ValueError(f"Tool {tool_name} is not recognized either as a default tool or a Space.") + + print(f"Running agent with these tools: {args.tools}") + agent = CodeAgent(tools=available_tools, model=model, additional_authorized_imports=args.imports) + + agent.run(args.prompt) + + +if __name__ == "__main__": + main() diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 650a30e..d3cf1e6 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -18,13 +18,13 @@ import re from dataclasses import dataclass from typing import Any, Dict, Optional +from .agent_types import AgentAudio from .local_python_executor import ( BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code, ) from .tools import PipelineTool, Tool -from .types import AgentAudio @dataclass diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 9491746..c33b60f 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -19,9 +19,9 @@ import re import shutil from typing import Optional +from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from smolagents.agents import ActionStep, MultiStepAgent from smolagents.memory import MemoryStep -from smolagents.types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from smolagents.utils import _is_package_available diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 10b22ea..3285efb 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -44,8 +44,8 @@ from ._function_type_hints_utils import ( get_imports, get_json_schema, ) +from .agent_types import handle_agent_input_types, handle_agent_output_types from .tool_validation import MethodChecker, validate_tool_attributes -from .types import handle_agent_input_types, handle_agent_output_types from .utils import _is_package_available, _is_pillow_available, get_source, instance_to_source diff --git a/examples/vlm_web_browser.py b/src/smolagents/vlm_web_browser.py similarity index 77% rename from examples/vlm_web_browser.py rename to src/smolagents/vlm_web_browser.py index d650154..950f5fc 100644 --- a/examples/vlm_web_browser.py +++ b/src/smolagents/vlm_web_browser.py @@ -1,5 +1,4 @@ import argparse -import os from io import BytesIO from time import sleep @@ -10,8 +9,9 @@ from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys -from smolagents import CodeAgent, HfApiModel, LiteLLMModel, OpenAIServerModel, TransformersModel, tool # noqa: F401 +from smolagents import CodeAgent, tool from smolagents.agents import ActionStep +from smolagents.cli import load_model github_request = """ @@ -27,7 +27,7 @@ Please navigate to https://en.wikipedia.org/wiki/Chicago and give me a sentence def parse_arguments(): parser = argparse.ArgumentParser(description="Run a web browser automation script with a specified model.") parser.add_argument( - "--model", + "--model-type", type=str, default="LiteLLMModel", help="The model type to use (e.g., OpenAIServerModel, LiteLLMModel, TransformersModel, HfApiModel)", @@ -42,36 +42,6 @@ def parse_arguments(): return parser.parse_args() -# Load environment variables -load_dotenv() - -# Parse command line arguments -args = parse_arguments() - -# Initialize the model based on the provided arguments -if args.model == "OpenAIServerModel": - model = OpenAIServerModel( - api_key=os.getenv("FIREWORKS_API_KEY"), - api_base="https://api.fireworks.ai/inference/v1", - model_id=args.model_id, - ) -elif args.model == "LiteLLMModel": - model = LiteLLMModel( - model_id=args.model_id, - api_key=os.getenv("OPENAI_API_KEY"), - ) -elif args.model == "TransformersModel": - model = TransformersModel(model_id=args.model_id, device_map="auto", flatten_messages_as_text=False) -elif args.model == "HfApiModel": - model = HfApiModel( - token=os.getenv("HF_API_KEY"), - model_id=args.model_id, - ) -else: - raise ValueError(f"Unsupported model type: {args.model}") - - -# Prepare callback def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None: sleep(1.0) # Let JavaScript animations happen before taking the screenshot driver = helium.get_driver() @@ -93,18 +63,6 @@ def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None: return -# Initialize driver and agent -chrome_options = webdriver.ChromeOptions() -chrome_options.add_argument("--force-device-scale-factor=1") -chrome_options.add_argument("--window-size=1000,1350") -chrome_options.add_argument("--disable-pdf-viewer") -chrome_options.add_argument("--window-position=0,0") - -driver = helium.start_chrome(headless=False, options=chrome_options) - -# Initialize tools - - @tool def search_item_ctrl_f(text: str, nth_result: int = 1) -> str: """ @@ -137,14 +95,27 @@ def close_popups() -> str: webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform() -agent = CodeAgent( - tools=[go_back, close_popups, search_item_ctrl_f], - model=model, - additional_authorized_imports=["helium"], - step_callbacks=[save_screenshot], - max_steps=20, - verbosity_level=2, -) +def initialize_driver(): + """Initialize the Selenium WebDriver.""" + chrome_options = webdriver.ChromeOptions() + chrome_options.add_argument("--force-device-scale-factor=1") + chrome_options.add_argument("--window-size=1000,1350") + chrome_options.add_argument("--disable-pdf-viewer") + chrome_options.add_argument("--window-position=0,0") + return helium.start_chrome(headless=False, options=chrome_options) + + +def initialize_agent(model): + """Initialize the CodeAgent with the specified model.""" + return CodeAgent( + tools=[go_back, close_popups, search_item_ctrl_f], + model=model, + additional_authorized_imports=["helium"], + step_callbacks=[save_screenshot], + max_steps=20, + verbosity_level=2, + ) + helium_instructions = """ You can use helium to access websites. Don't bother about the helium driver, it's already managed. @@ -207,7 +178,25 @@ Don't kill the browser. When you have modals or cookie banners on screen, you should get rid of them before you can click anything else. """ -# Run the agent with the provided prompt -agent.python_executor("from helium import *", agent.state) -agent.run(args.prompt + helium_instructions) +def main(): + # Load environment variables + load_dotenv() + + # Parse command line arguments + args = parse_arguments() + + # Initialize the model based on the provided arguments + model = load_model(args.model_type, args.model_id) + + global driver + driver = initialize_driver() + agent = initialize_agent(model) + + # Run the agent with the provided prompt + agent.python_executor("from helium import *", agent.state) + agent.run(args.prompt + helium_instructions) + + +if __name__ == "__main__": + main() diff --git a/tests/test_agents.py b/tests/test_agents.py index d123d83..53f2cfd 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -20,6 +20,7 @@ from pathlib import Path from transformers.testing_utils import get_tests_dir +from smolagents.agent_types import AgentImage, AgentText from smolagents.agents import ( AgentMaxStepsError, CodeAgent, @@ -30,7 +31,6 @@ from smolagents.agents import ( from smolagents.default_tools import PythonInterpreterTool from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel from smolagents.tools import tool -from smolagents.types import AgentImage, AgentText from smolagents.utils import BASE_BUILTIN_MODULES diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py index 89b7e85..e3cb4a6 100644 --- a/tests/test_default_tools.py +++ b/tests/test_default_tools.py @@ -16,8 +16,8 @@ import unittest import pytest +from smolagents.agent_types import _AGENT_TYPE_MAPPING from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool -from smolagents.types import _AGENT_TYPE_MAPPING from .test_tools import ToolTesterMixin diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 7bb1e5e..fcfb02a 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -21,8 +21,8 @@ from PIL import Image from transformers import is_torch_available from transformers.testing_utils import get_tests_dir, require_torch +from smolagents.agent_types import _AGENT_TYPE_MAPPING from smolagents.default_tools import FinalAnswerTool -from smolagents.types import _AGENT_TYPE_MAPPING from .test_tools import ToolTesterMixin diff --git a/tests/test_tools.py b/tests/test_tools.py index 4948a1e..4df4b4d 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -26,8 +26,8 @@ import torch from transformers import is_torch_available, is_vision_available from transformers.testing_utils import get_tests_dir +from smolagents.agent_types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool -from smolagents.types import _AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText if is_torch_available(): diff --git a/tests/test_types.py b/tests/test_types.py index 9350da1..73465d0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -25,7 +25,7 @@ from transformers.testing_utils import ( require_vision, ) -from smolagents.types import AgentAudio, AgentImage, AgentText +from smolagents.agent_types import AgentAudio, AgentImage, AgentText def get_new_path(suffix="") -> str: