Sort imports and add test workflows
This commit is contained in:
		
							parent
							
								
									417c6685b0
								
							
						
					
					
						commit
						c22fedaee1
					
				|  | @ -15,19 +15,21 @@ jobs: | |||
|         with: | ||||
|           python-version: 3.10 | ||||
| 
 | ||||
|       - name: Install Python dependencies | ||||
|         run: pip install -e .[quality] | ||||
| 
 | ||||
|       - name: Run Quality check | ||||
|         run: make quality | ||||
|       - name: Check if failure | ||||
|         if: ${{ failure() }} | ||||
|       # Setup venv | ||||
|       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. | ||||
|       - name: Setup venv + uv | ||||
|         run: | | ||||
|           echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY | ||||
|           pip install --upgrade uv | ||||
|           uv venv | ||||
| 
 | ||||
|       - name: Run Style check | ||||
|         run: make style | ||||
|       - name: Check if failure | ||||
|         if: ${{ failure() }} | ||||
|         run: | | ||||
|           echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY | ||||
|       - name: Install dependencies | ||||
|         run: uv pip install "smolagents[test] @ ." | ||||
|       - run: uv run ruff check tests src # linter | ||||
|       - run: uv run ruff format --check tests src # formatter | ||||
| 
 | ||||
|       # Run type checking at least on smolagents root file to check all modules | ||||
|       # that can be lazy-loaded actually exist. | ||||
|       # - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback | ||||
| 
 | ||||
|       # Run mypy on full package | ||||
|       # - run: uv run mypy src | ||||
|  | @ -0,0 +1,47 @@ | |||
| name: Python tests | ||||
| 
 | ||||
| on: [pull_request] | ||||
| 
 | ||||
| jobs: | ||||
|   build-ubuntu: | ||||
|     runs-on: ubuntu-latest | ||||
|     env: | ||||
|       UV_HTTP_TIMEOUT: 600 # max 10min to install deps | ||||
| 
 | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         python-version: ["3.10", "3.12"] | ||||
| 
 | ||||
|     steps: | ||||
|       - uses: actions/checkout@v2 | ||||
|       - name: Set up Python ${{ matrix.python-version }} | ||||
|         uses: actions/setup-python@v2 | ||||
|         with: | ||||
|           python-version: ${{ matrix.python-version }} | ||||
| 
 | ||||
| 
 | ||||
|       # Setup venv | ||||
|       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. | ||||
|       - name: Setup venv + uv | ||||
|         run: | | ||||
|           pip install --upgrade uv | ||||
|           uv venv | ||||
| 
 | ||||
|       # Install dependencies | ||||
|       - name: Install dependencies | ||||
|         run: | | ||||
|           uv pip install "smolagents[test] @ ." | ||||
| 
 | ||||
|       - name: Agent tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_agents.py | ||||
|       - name: Tool tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_toolss.py | ||||
|       - name: Python interpreter tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_python_interpreter.py | ||||
|       - name: Final answer tests | ||||
|         run: | | ||||
|           uv run pytest -sv ./tests/test_final_answer.py | ||||
|  | @ -346,7 +346,7 @@ | |||
|     "import glob\n", | ||||
|     "\n", | ||||
|     "res = []\n", | ||||
|     "for f in glob.glob(f\"output/*.jsonl\"):\n", | ||||
|     "for f in glob.glob(\"output/*.jsonl\"):\n", | ||||
|     "    res.append(pd.read_json(f, lines=True))\n", | ||||
|     "result_df = pd.concat(res)\n", | ||||
|     "\n", | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| from smolagents.agents import ToolCallingAgent | ||||
| from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel | ||||
| from smolagents import tool, LiteLLMModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| # Choose which LLM engine to use! | ||||
|  |  | |||
|  | @ -29,8 +29,12 @@ dependencies = [ | |||
|   "litellm>=1.55.10", | ||||
| ] | ||||
| 
 | ||||
| [tool.ruff] | ||||
| ignore = ["F403"] | ||||
| 
 | ||||
| [project.optional-dependencies] | ||||
| test = [ | ||||
|   "pytest>=8.1.0", | ||||
|     "sqlalchemy" | ||||
|   "sqlalchemy", | ||||
|   "ruff>=0.5.0", | ||||
| ] | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ def execute_code(code): | |||
| 
 | ||||
|     try: | ||||
|         exec(code, exec_globals, exec_locals) | ||||
|     except Exception as e: | ||||
|     except Exception: | ||||
|         traceback.print_exc(file=stderr) | ||||
|      | ||||
|     output = stdout.getvalue() | ||||
|  |  | |||
|  | @ -21,14 +21,13 @@ from typing import TYPE_CHECKING | |||
| from transformers.utils import _LazyModule | ||||
| from transformers.utils.import_utils import define_import_structure | ||||
| 
 | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from .agents import * | ||||
|     from .default_tools import * | ||||
|     from .gradio_ui import * | ||||
|     from .models import * | ||||
|     from .local_python_executor import * | ||||
|     from .e2b_executor import * | ||||
|     from .gradio_ui import * | ||||
|     from .local_python_executor import * | ||||
|     from .models import * | ||||
|     from .monitoring import * | ||||
|     from .prompts import * | ||||
|     from .tools import * | ||||
|  |  | |||
|  | @ -15,49 +15,50 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import time | ||||
| from typing import Any, Callable, Dict, List, Optional, Union, Tuple | ||||
| from dataclasses import dataclass | ||||
| from rich.syntax import Syntax | ||||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||
| 
 | ||||
| from rich.console import Group | ||||
| from rich.panel import Panel | ||||
| from rich.rule import Rule | ||||
| from rich.syntax import Syntax | ||||
| from rich.text import Text | ||||
| 
 | ||||
| from .utils import ( | ||||
|     console, | ||||
|     parse_code_blob, | ||||
|     parse_json_tool_call, | ||||
|     truncate_content, | ||||
|     AgentError, | ||||
|     AgentParsingError, | ||||
|     AgentExecutionError, | ||||
|     AgentGenerationError, | ||||
|     AgentMaxStepsError, | ||||
| ) | ||||
| from .types import AgentAudio, AgentImage, handle_agent_output_types | ||||
| from .default_tools import FinalAnswerTool | ||||
| from .e2b_executor import E2BExecutor | ||||
| from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter | ||||
| from .models import MessageRole | ||||
| from .monitoring import Monitor | ||||
| from .prompts import ( | ||||
|     CODE_SYSTEM_PROMPT, | ||||
|     TOOL_CALLING_SYSTEM_PROMPT, | ||||
|     MANAGED_AGENT_PROMPT, | ||||
|     PLAN_UPDATE_FINAL_PLAN_REDACTION, | ||||
|     SYSTEM_PROMPT_FACTS, | ||||
|     SYSTEM_PROMPT_FACTS_UPDATE, | ||||
|     USER_PROMPT_FACTS_UPDATE, | ||||
|     USER_PROMPT_PLAN_UPDATE, | ||||
|     USER_PROMPT_PLAN, | ||||
|     SYSTEM_PROMPT_PLAN_UPDATE, | ||||
|     SYSTEM_PROMPT_PLAN, | ||||
|     MANAGED_AGENT_PROMPT, | ||||
|     SYSTEM_PROMPT_PLAN_UPDATE, | ||||
|     TOOL_CALLING_SYSTEM_PROMPT, | ||||
|     USER_PROMPT_FACTS_UPDATE, | ||||
|     USER_PROMPT_PLAN, | ||||
|     USER_PROMPT_PLAN_UPDATE, | ||||
| ) | ||||
| from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter | ||||
| from .e2b_executor import E2BExecutor | ||||
| from .tools import ( | ||||
|     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, | ||||
|     Tool, | ||||
|     get_tool_description_with_args, | ||||
|     Toolbox, | ||||
|     get_tool_description_with_args, | ||||
| ) | ||||
| from .types import AgentAudio, AgentImage, handle_agent_output_types | ||||
| from .utils import ( | ||||
|     AgentError, | ||||
|     AgentExecutionError, | ||||
|     AgentGenerationError, | ||||
|     AgentMaxStepsError, | ||||
|     AgentParsingError, | ||||
|     console, | ||||
|     parse_code_blob, | ||||
|     parse_json_tool_call, | ||||
|     truncate_content, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,20 +18,20 @@ import json | |||
| import re | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, Optional | ||||
| from huggingface_hub import hf_hub_download, list_spaces | ||||
| 
 | ||||
| from transformers.utils import is_offline_mode | ||||
| from huggingface_hub import hf_hub_download, list_spaces | ||||
| from transformers.models.whisper import ( | ||||
|     WhisperProcessor, | ||||
|     WhisperForConditionalGeneration, | ||||
|     WhisperProcessor, | ||||
| ) | ||||
| from transformers.utils import is_offline_mode | ||||
| 
 | ||||
| from .local_python_executor import ( | ||||
|     BASE_BUILTIN_MODULES, | ||||
|     BASE_PYTHON_TOOLS, | ||||
|     evaluate_python_code, | ||||
| ) | ||||
| from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool | ||||
| from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | ||||
| from .types import AgentAudio | ||||
| 
 | ||||
| 
 | ||||
|  | @ -271,8 +271,8 @@ class VisitWebpageTool(Tool): | |||
| 
 | ||||
|     def forward(self, url: str) -> str: | ||||
|         try: | ||||
|             from markdownify import markdownify | ||||
|             import requests | ||||
|             from markdownify import markdownify | ||||
|             from requests.exceptions import RequestException | ||||
|         except ImportError: | ||||
|             raise ImportError( | ||||
|  |  | |||
|  | @ -14,18 +14,19 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from dotenv import load_dotenv | ||||
| import textwrap | ||||
| import base64 | ||||
| import pickle | ||||
| import textwrap | ||||
| from io import BytesIO | ||||
| from typing import Any, List, Tuple | ||||
| 
 | ||||
| from dotenv import load_dotenv | ||||
| from e2b_code_interpreter import Sandbox | ||||
| from PIL import Image | ||||
| 
 | ||||
| from e2b_code_interpreter import Sandbox | ||||
| from typing import List, Tuple, Any | ||||
| from .tool_validation import validate_tool_attributes | ||||
| from .utils import instance_to_source, BASE_BUILTIN_MODULES, console | ||||
| from .tools import Tool | ||||
| from .utils import BASE_BUILTIN_MODULES, console, instance_to_source | ||||
| 
 | ||||
| load_dotenv() | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,10 +14,11 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | ||||
| from .agents import MultiStepAgent, AgentStep, ActionStep | ||||
| import gradio as gr | ||||
| 
 | ||||
| from .agents import ActionStep, AgentStep, MultiStepAgent | ||||
| from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | ||||
| 
 | ||||
| 
 | ||||
| def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||
|     """Extract ChatMessage objects from agent steps""" | ||||
|  |  | |||
|  | @ -17,14 +17,15 @@ | |||
| import ast | ||||
| import builtins | ||||
| import difflib | ||||
| import math | ||||
| from collections.abc import Mapping | ||||
| from importlib import import_module | ||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | ||||
| import math | ||||
| 
 | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| 
 | ||||
| from .utils import truncate_content, BASE_BUILTIN_MODULES | ||||
| from .utils import BASE_BUILTIN_MODULES, truncate_content | ||||
| 
 | ||||
| 
 | ||||
| class InterpreterError(ValueError): | ||||
|  |  | |||
|  | @ -14,24 +14,23 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from copy import deepcopy | ||||
| from enum import Enum | ||||
| import json | ||||
| from typing import Dict, List, Optional | ||||
| from transformers import ( | ||||
|     AutoTokenizer, | ||||
|     AutoModelForCausalLM, | ||||
|     StoppingCriteria, | ||||
|     StoppingCriteriaList, | ||||
| ) | ||||
| 
 | ||||
| import litellm | ||||
| import logging | ||||
| import os | ||||
| import random | ||||
| import torch | ||||
| from copy import deepcopy | ||||
| from enum import Enum | ||||
| from typing import Dict, List, Optional, Tuple, Union | ||||
| 
 | ||||
| import litellm | ||||
| import torch | ||||
| from huggingface_hub import InferenceClient | ||||
| from transformers import ( | ||||
|     AutoModelForCausalLM, | ||||
|     AutoTokenizer, | ||||
|     StoppingCriteria, | ||||
|     StoppingCriteriaList, | ||||
| ) | ||||
| 
 | ||||
| from .tools import Tool | ||||
| from .utils import parse_json_tool_call | ||||
|  | @ -352,16 +351,16 @@ class TransformersModel(Model): | |||
|         ) | ||||
| 
 | ||||
|         # Get LLM output | ||||
|         prompt = self.tokenizer.apply_chat_template( | ||||
|         prompt_tensor = self.tokenizer.apply_chat_template( | ||||
|             messages, | ||||
|             return_tensors="pt", | ||||
|             return_dict=True, | ||||
|         ) | ||||
|         prompt = prompt.to(self.model.device) | ||||
|         count_prompt_tokens = prompt["input_ids"].shape[1] | ||||
|         prompt_tensor = prompt_tensor.to(self.model.device) | ||||
|         count_prompt_tokens = prompt_tensor["input_ids"].shape[1] | ||||
| 
 | ||||
|         out = self.model.generate( | ||||
|             **prompt, | ||||
|             **prompt_tensor, | ||||
|             max_new_tokens=max_tokens, | ||||
|             stopping_criteria=( | ||||
|                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None | ||||
|  | @ -383,7 +382,7 @@ class TransformersModel(Model): | |||
|         available_tools: List[Tool], | ||||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         max_tokens: int = 500, | ||||
|     ) -> str: | ||||
|     ) -> Tuple[str, Union[str, None], str]: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  |  | |||
|  | @ -14,9 +14,10 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from .utils import console | ||||
| from rich.text import Text | ||||
| 
 | ||||
| from .utils import console | ||||
| 
 | ||||
| 
 | ||||
| class Monitor: | ||||
|     def __init__(self, tracked_model): | ||||
|  |  | |||
|  | @ -1,8 +1,9 @@ | |||
| import ast | ||||
| import inspect | ||||
| import builtins | ||||
| from typing import Set | ||||
| import inspect | ||||
| import textwrap | ||||
| from typing import Set | ||||
| 
 | ||||
| from .utils import BASE_BUILTIN_MODULES | ||||
| 
 | ||||
| _BUILTIN_NAMES = set(vars(builtins)) | ||||
|  |  | |||
|  | @ -18,14 +18,16 @@ import ast | |||
| import importlib | ||||
| import inspect | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| import sys | ||||
| import tempfile | ||||
| import torch | ||||
| import textwrap | ||||
| from functools import lru_cache, wraps | ||||
| from pathlib import Path | ||||
| from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||
| 
 | ||||
| import torch | ||||
| from huggingface_hub import ( | ||||
|     create_repo, | ||||
|     get_collection, | ||||
|  | @ -35,7 +37,8 @@ from huggingface_hub import ( | |||
| ) | ||||
| from huggingface_hub.utils import RepositoryNotFoundError | ||||
| from packaging import version | ||||
| import logging | ||||
| from transformers import AutoProcessor | ||||
| from transformers.dynamic_module_utils import get_imports | ||||
| from transformers.utils import ( | ||||
|     TypeHintParsingException, | ||||
|     cached_file, | ||||
|  | @ -45,13 +48,9 @@ from transformers.utils import ( | |||
| ) | ||||
| from transformers.utils.chat_template_utils import _parse_type_hint | ||||
| 
 | ||||
| from transformers.dynamic_module_utils import get_imports | ||||
| from transformers import AutoProcessor | ||||
| 
 | ||||
| from .tool_validation import MethodChecker, validate_tool_attributes | ||||
| from .types import ImageType, handle_agent_input_types, handle_agent_output_types | ||||
| from .utils import instance_to_source | ||||
| from .tool_validation import validate_tool_attributes, MethodChecker | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,21 +12,20 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| import os | ||||
| import pathlib | ||||
| import tempfile | ||||
| import uuid | ||||
| from io import BytesIO | ||||
| import requests | ||||
| import numpy as np | ||||
| 
 | ||||
| import numpy as np | ||||
| import requests | ||||
| from transformers.utils import ( | ||||
|     is_soundfile_availble, | ||||
|     is_torch_available, | ||||
|     is_vision_available, | ||||
| ) | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,14 +14,14 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import ast | ||||
| import inspect | ||||
| import json | ||||
| import re | ||||
| from typing import Tuple, Dict, Union | ||||
| import ast | ||||
| from rich.console import Console | ||||
| import inspect | ||||
| import types | ||||
| from typing import Dict, Tuple, Union | ||||
| 
 | ||||
| from rich.console import Console | ||||
| from transformers.utils.import_utils import _is_package_available | ||||
| 
 | ||||
| _pygments_available = _is_package_available("pygments") | ||||
|  |  | |||
|  | @ -16,22 +16,22 @@ import os | |||
| import tempfile | ||||
| import unittest | ||||
| import uuid | ||||
| import pytest | ||||
| 
 | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from smolagents.types import AgentText, AgentImage | ||||
| import pytest | ||||
| from transformers.testing_utils import get_tests_dir | ||||
| 
 | ||||
| from smolagents.agents import ( | ||||
|     AgentMaxStepsError, | ||||
|     ManagedAgent, | ||||
|     CodeAgent, | ||||
|     ToolCallingAgent, | ||||
|     ManagedAgent, | ||||
|     Toolbox, | ||||
|     ToolCall, | ||||
|     ToolCallingAgent, | ||||
| ) | ||||
| from smolagents.tools import tool | ||||
| from smolagents.default_tools import PythonInterpreterTool | ||||
| from transformers.testing_utils import get_tests_dir | ||||
| from smolagents.tools import tool | ||||
| from smolagents.types import AgentImage, AgentText | ||||
| 
 | ||||
| 
 | ||||
| def get_new_path(suffix="") -> str: | ||||
|  |  | |||
|  | @ -17,12 +17,13 @@ import ast | |||
| import os | ||||
| import re | ||||
| import shutil | ||||
| import tempfile | ||||
| import subprocess | ||||
| import tempfile | ||||
| import traceback | ||||
| import pytest | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
| 
 | ||||
| import pytest | ||||
| from dotenv import load_dotenv | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,16 +18,14 @@ from pathlib import Path | |||
| 
 | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| 
 | ||||
| from transformers import is_torch_available | ||||
| from transformers.testing_utils import get_tests_dir, require_torch | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| 
 | ||||
| from smolagents.default_tools import FinalAnswerTool | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| 
 | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|     import torch | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,9 +13,10 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import unittest | ||||
| from smolagents import models, tool | ||||
| from typing import Optional | ||||
| 
 | ||||
| from smolagents import models, tool | ||||
| 
 | ||||
| 
 | ||||
| class ModelTests(unittest.TestCase): | ||||
|     def test_get_json_schema_has_nullable_args(self): | ||||
|  |  | |||
|  | @ -16,8 +16,8 @@ | |||
| import unittest | ||||
| 
 | ||||
| from smolagents import ( | ||||
|     AgentImage, | ||||
|     AgentError, | ||||
|     AgentImage, | ||||
|     CodeAgent, | ||||
|     ToolCallingAgent, | ||||
|     stream_to_gradio, | ||||
|  |  | |||
|  | @ -19,12 +19,12 @@ import numpy as np | |||
| import pytest | ||||
| 
 | ||||
| from smolagents import load_tool | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||
| from smolagents.local_python_executor import ( | ||||
|     InterpreterError, | ||||
|     evaluate_python_code, | ||||
| ) | ||||
| from smolagents.types import AGENT_TYPE_MAPPING | ||||
| 
 | ||||
| from .test_tools import ToolTesterMixin | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,21 +14,20 @@ | |||
| # limitations under the License. | ||||
| import unittest | ||||
| from pathlib import Path | ||||
| from typing import Dict, Union, Optional | ||||
| from typing import Dict, Optional, Union | ||||
| 
 | ||||
| import numpy as np | ||||
| import pytest | ||||
| 
 | ||||
| from transformers import is_torch_available, is_vision_available | ||||
| from transformers.testing_utils import get_tests_dir | ||||
| 
 | ||||
| from smolagents.tools import AUTHORIZED_TYPES, Tool, tool | ||||
| from smolagents.types import ( | ||||
|     AGENT_TYPE_MAPPING, | ||||
|     AgentAudio, | ||||
|     AgentImage, | ||||
|     AgentText, | ||||
| ) | ||||
| from smolagents.tools import Tool, tool, AUTHORIZED_TYPES | ||||
| from transformers.testing_utils import get_tests_dir | ||||
| 
 | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|     import torch | ||||
|  |  | |||
|  | @ -18,7 +18,8 @@ import unittest | |||
| import uuid | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | ||||
| import torch | ||||
| from PIL import Image | ||||
| from transformers.testing_utils import ( | ||||
|     require_soundfile, | ||||
|     require_torch, | ||||
|  | @ -28,9 +29,7 @@ from transformers.utils import ( | |||
|     is_soundfile_availble, | ||||
| ) | ||||
| 
 | ||||
| import torch | ||||
| from PIL import Image | ||||
| 
 | ||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | ||||
| 
 | ||||
| if is_soundfile_availble(): | ||||
|     import soundfile as sf | ||||
|  |  | |||
|  | @ -1,8 +1,7 @@ | |||
| import os | ||||
| import unittest | ||||
| import shutil | ||||
| import tempfile | ||||
| 
 | ||||
| import unittest | ||||
| from pathlib import Path | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue