diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 0f2f39a..6c1df46 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -14,20 +14,22 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.10 - - - name: Install Python dependencies - run: pip install -e .[quality] - - name: Run Quality check - run: make quality - - name: Check if failure - if: ${{ failure() }} + # Setup venv + # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. + - name: Setup venv + uv run: | - echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY + pip install --upgrade uv + uv venv - - name: Run Style check - run: make style - - name: Check if failure - if: ${{ failure() }} - run: | - echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY \ No newline at end of file + - name: Install dependencies + run: uv pip install "smolagents[test] @ ." + - run: uv run ruff check tests src # linter + - run: uv run ruff format --check tests src # formatter + + # Run type checking at least on smolagents root file to check all modules + # that can be lazy-loaded actually exist. + # - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback + + # Run mypy on full package + # - run: uv run mypy src \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..18983ef --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,47 @@ +name: Python tests + +on: [pull_request] + +jobs: + build-ubuntu: + runs-on: ubuntu-latest + env: + UV_HTTP_TIMEOUT: 600 # max 10min to install deps + + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.12"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + + # Setup venv + # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. + - name: Setup venv + uv + run: | + pip install --upgrade uv + uv venv + + # Install dependencies + - name: Install dependencies + run: | + uv pip install "smolagents[test] @ ." + + - name: Agent tests + run: | + uv run pytest -sv ./tests/test_agents.py + - name: Tool tests + run: | + uv run pytest -sv ./tests/test_toolss.py + - name: Python interpreter tests + run: | + uv run pytest -sv ./tests/test_python_interpreter.py + - name: Final answer tests + run: | + uv run pytest -sv ./tests/test_final_answer.py \ No newline at end of file diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index 0ec669c..dcb3532 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -346,7 +346,7 @@ "import glob\n", "\n", "res = []\n", - "for f in glob.glob(f\"output/*.jsonl\"):\n", + "for f in glob.glob(\"output/*.jsonl\"):\n", " res.append(pd.read_json(f, lines=True))\n", "result_df = pd.concat(res)\n", "\n", diff --git a/examples/tool_calling_agent_from_any_llm.py b/examples/tool_calling_agent_from_any_llm.py index 3e69da7..68155fe 100644 --- a/examples/tool_calling_agent_from_any_llm.py +++ b/examples/tool_calling_agent_from_any_llm.py @@ -1,5 +1,5 @@ from smolagents.agents import ToolCallingAgent -from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel +from smolagents import tool, LiteLLMModel from typing import Optional # Choose which LLM engine to use! diff --git a/pyproject.toml b/pyproject.toml index 7c664a1..1706213 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,25 +12,29 @@ authors = [ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "torch", - "torchaudio", - "torchvision", - "transformers>=4.0.0", - "requests>=2.32.3", - "rich>=13.9.4", - "pandas>=2.2.3", - "jinja2>=3.1.4", - "pillow>=11.0.0", - "markdownify>=0.14.1", - "gradio>=5.8.0", - "duckduckgo-search>=6.3.7", - "python-dotenv>=1.0.1", - "e2b-code-interpreter>=1.0.3", - "litellm>=1.55.10", + "torch", + "torchaudio", + "torchvision", + "transformers>=4.0.0", + "requests>=2.32.3", + "rich>=13.9.4", + "pandas>=2.2.3", + "jinja2>=3.1.4", + "pillow>=11.0.0", + "markdownify>=0.14.1", + "gradio>=5.8.0", + "duckduckgo-search>=6.3.7", + "python-dotenv>=1.0.1", + "e2b-code-interpreter>=1.0.3", + "litellm>=1.55.10", ] +[tool.ruff] +ignore = ["F403"] + [project.optional-dependencies] test = [ - "pytest>=8.1.0", - "sqlalchemy" + "pytest>=8.1.0", + "sqlalchemy", + "ruff>=0.5.0", ] diff --git a/server.py b/server.py index ebdc61d..b381d53 100644 --- a/server.py +++ b/server.py @@ -14,7 +14,7 @@ def execute_code(code): try: exec(code, exec_globals, exec_locals) - except Exception as e: + except Exception: traceback.print_exc(file=stderr) output = stdout.getvalue() diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py index 85cf6da..9243205 100644 --- a/src/smolagents/__init__.py +++ b/src/smolagents/__init__.py @@ -21,14 +21,13 @@ from typing import TYPE_CHECKING from transformers.utils import _LazyModule from transformers.utils.import_utils import define_import_structure - if TYPE_CHECKING: from .agents import * from .default_tools import * - from .gradio_ui import * - from .models import * - from .local_python_executor import * from .e2b_executor import * + from .gradio_ui import * + from .local_python_executor import * + from .models import * from .monitoring import * from .prompts import * from .tools import * diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index ab21b7c..f388c04 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -15,49 +15,50 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import Any, Callable, Dict, List, Optional, Union, Tuple from dataclasses import dataclass -from rich.syntax import Syntax +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + from rich.console import Group from rich.panel import Panel from rich.rule import Rule +from rich.syntax import Syntax from rich.text import Text -from .utils import ( - console, - parse_code_blob, - parse_json_tool_call, - truncate_content, - AgentError, - AgentParsingError, - AgentExecutionError, - AgentGenerationError, - AgentMaxStepsError, -) -from .types import AgentAudio, AgentImage, handle_agent_output_types from .default_tools import FinalAnswerTool +from .e2b_executor import E2BExecutor +from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter from .models import MessageRole from .monitoring import Monitor from .prompts import ( CODE_SYSTEM_PROMPT, - TOOL_CALLING_SYSTEM_PROMPT, + MANAGED_AGENT_PROMPT, PLAN_UPDATE_FINAL_PLAN_REDACTION, SYSTEM_PROMPT_FACTS, SYSTEM_PROMPT_FACTS_UPDATE, - USER_PROMPT_FACTS_UPDATE, - USER_PROMPT_PLAN_UPDATE, - USER_PROMPT_PLAN, - SYSTEM_PROMPT_PLAN_UPDATE, SYSTEM_PROMPT_PLAN, - MANAGED_AGENT_PROMPT, + SYSTEM_PROMPT_PLAN_UPDATE, + TOOL_CALLING_SYSTEM_PROMPT, + USER_PROMPT_FACTS_UPDATE, + USER_PROMPT_PLAN, + USER_PROMPT_PLAN_UPDATE, ) -from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter -from .e2b_executor import E2BExecutor from .tools import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, Tool, - get_tool_description_with_args, Toolbox, + get_tool_description_with_args, +) +from .types import AgentAudio, AgentImage, handle_agent_output_types +from .utils import ( + AgentError, + AgentExecutionError, + AgentGenerationError, + AgentMaxStepsError, + AgentParsingError, + console, + parse_code_blob, + parse_json_tool_call, + truncate_content, ) diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 362b338..8b6f44f 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -18,20 +18,20 @@ import json import re from dataclasses import dataclass from typing import Dict, Optional -from huggingface_hub import hf_hub_download, list_spaces -from transformers.utils import is_offline_mode +from huggingface_hub import hf_hub_download, list_spaces from transformers.models.whisper import ( - WhisperProcessor, WhisperForConditionalGeneration, + WhisperProcessor, ) +from transformers.utils import is_offline_mode from .local_python_executor import ( BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code, ) -from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool +from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .types import AgentAudio @@ -271,8 +271,8 @@ class VisitWebpageTool(Tool): def forward(self, url: str) -> str: try: - from markdownify import markdownify import requests + from markdownify import markdownify from requests.exceptions import RequestException except ImportError: raise ImportError( diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py index cf539a6..68f5579 100644 --- a/src/smolagents/e2b_executor.py +++ b/src/smolagents/e2b_executor.py @@ -14,18 +14,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dotenv import load_dotenv -import textwrap import base64 import pickle +import textwrap from io import BytesIO +from typing import Any, List, Tuple + +from dotenv import load_dotenv +from e2b_code_interpreter import Sandbox from PIL import Image -from e2b_code_interpreter import Sandbox -from typing import List, Tuple, Any from .tool_validation import validate_tool_attributes -from .utils import instance_to_source, BASE_BUILTIN_MODULES, console from .tools import Tool +from .utils import BASE_BUILTIN_MODULES, console, instance_to_source load_dotenv() diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 4a724db..b745a80 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -14,10 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types -from .agents import MultiStepAgent, AgentStep, ActionStep import gradio as gr +from .agents import ActionStep, AgentStep, MultiStepAgent +from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types + def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): """Extract ChatMessage objects from agent steps""" diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index db935a6..6c7cb4a 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -17,14 +17,15 @@ import ast import builtins import difflib +import math from collections.abc import Mapping from importlib import import_module from typing import Any, Callable, Dict, List, Optional, Tuple -import math + import numpy as np import pandas as pd -from .utils import truncate_content, BASE_BUILTIN_MODULES +from .utils import BASE_BUILTIN_MODULES, truncate_content class InterpreterError(ValueError): diff --git a/src/smolagents/models.py b/src/smolagents/models.py index f83b948..9b1fd55 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -14,24 +14,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from enum import Enum import json -from typing import Dict, List, Optional -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - StoppingCriteria, - StoppingCriteriaList, -) - -import litellm import logging import os import random -import torch +from copy import deepcopy +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union +import litellm +import torch from huggingface_hub import InferenceClient +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + StoppingCriteria, + StoppingCriteriaList, +) from .tools import Tool from .utils import parse_json_tool_call @@ -352,16 +351,16 @@ class TransformersModel(Model): ) # Get LLM output - prompt = self.tokenizer.apply_chat_template( + prompt_tensor = self.tokenizer.apply_chat_template( messages, return_tensors="pt", return_dict=True, ) - prompt = prompt.to(self.model.device) - count_prompt_tokens = prompt["input_ids"].shape[1] + prompt_tensor = prompt_tensor.to(self.model.device) + count_prompt_tokens = prompt_tensor["input_ids"].shape[1] out = self.model.generate( - **prompt, + **prompt_tensor, max_new_tokens=max_tokens, stopping_criteria=( self.make_stopping_criteria(stop_sequences) if stop_sequences else None @@ -383,7 +382,7 @@ class TransformersModel(Model): available_tools: List[Tool], stop_sequences: Optional[List[str]] = None, max_tokens: int = 500, - ) -> str: + ) -> Tuple[str, Union[str, None], str]: messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index a31140c..daa53cd 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -14,9 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .utils import console from rich.text import Text +from .utils import console + class Monitor: def __init__(self, tracked_model): diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index e410e89..821c315 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -1,8 +1,9 @@ import ast -import inspect import builtins -from typing import Set +import inspect import textwrap +from typing import Set + from .utils import BASE_BUILTIN_MODULES _BUILTIN_NAMES = set(vars(builtins)) diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 0cbc6aa..7acff0d 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -18,14 +18,16 @@ import ast import importlib import inspect import json +import logging import os import sys import tempfile -import torch import textwrap from functools import lru_cache, wraps from pathlib import Path from typing import Callable, Dict, List, Optional, Union, get_type_hints + +import torch from huggingface_hub import ( create_repo, get_collection, @@ -35,7 +37,8 @@ from huggingface_hub import ( ) from huggingface_hub.utils import RepositoryNotFoundError from packaging import version -import logging +from transformers import AutoProcessor +from transformers.dynamic_module_utils import get_imports from transformers.utils import ( TypeHintParsingException, cached_file, @@ -45,13 +48,9 @@ from transformers.utils import ( ) from transformers.utils.chat_template_utils import _parse_type_hint -from transformers.dynamic_module_utils import get_imports -from transformers import AutoProcessor - +from .tool_validation import MethodChecker, validate_tool_attributes from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .utils import instance_to_source -from .tool_validation import validate_tool_attributes, MethodChecker - logger = logging.getLogger(__name__) diff --git a/src/smolagents/types.py b/src/smolagents/types.py index d817608..dbc5d5b 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -12,21 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import pathlib import tempfile import uuid from io import BytesIO -import requests -import numpy as np +import numpy as np +import requests from transformers.utils import ( is_soundfile_availble, is_torch_available, is_vision_available, ) -import logging - logger = logging.getLogger(__name__) diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 9c32e51..902ebb7 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -14,14 +14,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import ast +import inspect import json import re -from typing import Tuple, Dict, Union -import ast -from rich.console import Console -import inspect import types +from typing import Dict, Tuple, Union +from rich.console import Console from transformers.utils.import_utils import _is_package_available _pygments_available = _is_package_available("pygments") diff --git a/tests/test_agents.py b/tests/test_agents.py index 8804042..3ac4bed 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -16,22 +16,22 @@ import os import tempfile import unittest import uuid -import pytest - from pathlib import Path -from smolagents.types import AgentText, AgentImage +import pytest +from transformers.testing_utils import get_tests_dir + from smolagents.agents import ( AgentMaxStepsError, - ManagedAgent, CodeAgent, - ToolCallingAgent, + ManagedAgent, Toolbox, ToolCall, + ToolCallingAgent, ) -from smolagents.tools import tool from smolagents.default_tools import PythonInterpreterTool -from transformers.testing_utils import get_tests_dir +from smolagents.tools import tool +from smolagents.types import AgentImage, AgentText def get_new_path(suffix="") -> str: diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index cf88d5a..9177df2 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -17,12 +17,13 @@ import ast import os import re import shutil -import tempfile import subprocess +import tempfile import traceback -import pytest from pathlib import Path from typing import List + +import pytest from dotenv import load_dotenv diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 65a9728..873dcdc 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -18,16 +18,14 @@ from pathlib import Path import numpy as np from PIL import Image - from transformers import is_torch_available from transformers.testing_utils import get_tests_dir, require_torch -from smolagents.types import AGENT_TYPE_MAPPING from smolagents.default_tools import FinalAnswerTool +from smolagents.types import AGENT_TYPE_MAPPING from .test_tools import ToolTesterMixin - if is_torch_available(): import torch diff --git a/tests/test_models.py b/tests/test_models.py index 3e73393..dbd93ce 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -from smolagents import models, tool from typing import Optional +from smolagents import models, tool + class ModelTests(unittest.TestCase): def test_get_json_schema_has_nullable_args(self): diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index d642817..5f2401d 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -16,8 +16,8 @@ import unittest from smolagents import ( - AgentImage, AgentError, + AgentImage, CodeAgent, ToolCallingAgent, stream_to_gradio, diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index de044c5..3508161 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -19,12 +19,12 @@ import numpy as np import pytest from smolagents import load_tool -from smolagents.types import AGENT_TYPE_MAPPING from smolagents.default_tools import BASE_PYTHON_TOOLS from smolagents.local_python_executor import ( InterpreterError, evaluate_python_code, ) +from smolagents.types import AGENT_TYPE_MAPPING from .test_tools import ToolTesterMixin diff --git a/tests/test_tools.py b/tests/test_tools.py index 9e8e3df..cfa61c1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -14,21 +14,20 @@ # limitations under the License. import unittest from pathlib import Path -from typing import Dict, Union, Optional +from typing import Dict, Optional, Union import numpy as np import pytest - from transformers import is_torch_available, is_vision_available +from transformers.testing_utils import get_tests_dir + +from smolagents.tools import AUTHORIZED_TYPES, Tool, tool from smolagents.types import ( AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText, ) -from smolagents.tools import Tool, tool, AUTHORIZED_TYPES -from transformers.testing_utils import get_tests_dir - if is_torch_available(): import torch diff --git a/tests/test_types.py b/tests/test_types.py index 8026a57..e988e8b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -18,7 +18,8 @@ import unittest import uuid from pathlib import Path -from smolagents.types import AgentAudio, AgentImage, AgentText +import torch +from PIL import Image from transformers.testing_utils import ( require_soundfile, require_torch, @@ -28,9 +29,7 @@ from transformers.utils import ( is_soundfile_availble, ) -import torch -from PIL import Image - +from smolagents.types import AgentAudio, AgentImage, AgentText if is_soundfile_availble(): import soundfile as sf diff --git a/tests/test_utils.py b/tests/test_utils.py index 34ca2db..4bd0f81 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,7 @@ import os -import unittest import shutil import tempfile - +import unittest from pathlib import Path