Sort imports and add test workflows
This commit is contained in:
parent
417c6685b0
commit
c22fedaee1
|
@ -14,20 +14,22 @@ jobs:
|
|||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.10
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: pip install -e .[quality]
|
||||
|
||||
- name: Run Quality check
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
# Setup venv
|
||||
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
||||
- name: Setup venv + uv
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
|
||||
pip install --upgrade uv
|
||||
uv venv
|
||||
|
||||
- name: Run Style check
|
||||
run: make style
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY
|
||||
- name: Install dependencies
|
||||
run: uv pip install "smolagents[test] @ ."
|
||||
- run: uv run ruff check tests src # linter
|
||||
- run: uv run ruff format --check tests src # formatter
|
||||
|
||||
# Run type checking at least on smolagents root file to check all modules
|
||||
# that can be lazy-loaded actually exist.
|
||||
# - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
|
||||
|
||||
# Run mypy on full package
|
||||
# - run: uv run mypy src
|
|
@ -0,0 +1,47 @@
|
|||
name: Python tests
|
||||
|
||||
on: [pull_request]
|
||||
|
||||
jobs:
|
||||
build-ubuntu:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
UV_HTTP_TIMEOUT: 600 # max 10min to install deps
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
|
||||
# Setup venv
|
||||
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
|
||||
- name: Setup venv + uv
|
||||
run: |
|
||||
pip install --upgrade uv
|
||||
uv venv
|
||||
|
||||
# Install dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install "smolagents[test] @ ."
|
||||
|
||||
- name: Agent tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_agents.py
|
||||
- name: Tool tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_toolss.py
|
||||
- name: Python interpreter tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_python_interpreter.py
|
||||
- name: Final answer tests
|
||||
run: |
|
||||
uv run pytest -sv ./tests/test_final_answer.py
|
|
@ -346,7 +346,7 @@
|
|||
"import glob\n",
|
||||
"\n",
|
||||
"res = []\n",
|
||||
"for f in glob.glob(f\"output/*.jsonl\"):\n",
|
||||
"for f in glob.glob(\"output/*.jsonl\"):\n",
|
||||
" res.append(pd.read_json(f, lines=True))\n",
|
||||
"result_df = pd.concat(res)\n",
|
||||
"\n",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from smolagents.agents import ToolCallingAgent
|
||||
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
|
||||
from smolagents import tool, LiteLLMModel
|
||||
from typing import Optional
|
||||
|
||||
# Choose which LLM engine to use!
|
||||
|
|
|
@ -12,25 +12,29 @@ authors = [
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"torchvision",
|
||||
"transformers>=4.0.0",
|
||||
"requests>=2.32.3",
|
||||
"rich>=13.9.4",
|
||||
"pandas>=2.2.3",
|
||||
"jinja2>=3.1.4",
|
||||
"pillow>=11.0.0",
|
||||
"markdownify>=0.14.1",
|
||||
"gradio>=5.8.0",
|
||||
"duckduckgo-search>=6.3.7",
|
||||
"python-dotenv>=1.0.1",
|
||||
"e2b-code-interpreter>=1.0.3",
|
||||
"litellm>=1.55.10",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"torchvision",
|
||||
"transformers>=4.0.0",
|
||||
"requests>=2.32.3",
|
||||
"rich>=13.9.4",
|
||||
"pandas>=2.2.3",
|
||||
"jinja2>=3.1.4",
|
||||
"pillow>=11.0.0",
|
||||
"markdownify>=0.14.1",
|
||||
"gradio>=5.8.0",
|
||||
"duckduckgo-search>=6.3.7",
|
||||
"python-dotenv>=1.0.1",
|
||||
"e2b-code-interpreter>=1.0.3",
|
||||
"litellm>=1.55.10",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
ignore = ["F403"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.1.0",
|
||||
"sqlalchemy"
|
||||
"pytest>=8.1.0",
|
||||
"sqlalchemy",
|
||||
"ruff>=0.5.0",
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@ def execute_code(code):
|
|||
|
||||
try:
|
||||
exec(code, exec_globals, exec_locals)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
traceback.print_exc(file=stderr)
|
||||
|
||||
output = stdout.getvalue()
|
||||
|
|
|
@ -21,14 +21,13 @@ from typing import TYPE_CHECKING
|
|||
from transformers.utils import _LazyModule
|
||||
from transformers.utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import *
|
||||
from .default_tools import *
|
||||
from .gradio_ui import *
|
||||
from .models import *
|
||||
from .local_python_executor import *
|
||||
from .e2b_executor import *
|
||||
from .gradio_ui import *
|
||||
from .local_python_executor import *
|
||||
from .models import *
|
||||
from .monitoring import *
|
||||
from .prompts import *
|
||||
from .tools import *
|
||||
|
|
|
@ -15,49 +15,50 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
from rich.syntax import Syntax
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
|
||||
from .utils import (
|
||||
console,
|
||||
parse_code_blob,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
AgentError,
|
||||
AgentParsingError,
|
||||
AgentExecutionError,
|
||||
AgentGenerationError,
|
||||
AgentMaxStepsError,
|
||||
)
|
||||
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
||||
from .default_tools import FinalAnswerTool
|
||||
from .e2b_executor import E2BExecutor
|
||||
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
|
||||
from .models import MessageRole
|
||||
from .monitoring import Monitor
|
||||
from .prompts import (
|
||||
CODE_SYSTEM_PROMPT,
|
||||
TOOL_CALLING_SYSTEM_PROMPT,
|
||||
MANAGED_AGENT_PROMPT,
|
||||
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
||||
SYSTEM_PROMPT_FACTS,
|
||||
SYSTEM_PROMPT_FACTS_UPDATE,
|
||||
USER_PROMPT_FACTS_UPDATE,
|
||||
USER_PROMPT_PLAN_UPDATE,
|
||||
USER_PROMPT_PLAN,
|
||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
SYSTEM_PROMPT_PLAN,
|
||||
MANAGED_AGENT_PROMPT,
|
||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
TOOL_CALLING_SYSTEM_PROMPT,
|
||||
USER_PROMPT_FACTS_UPDATE,
|
||||
USER_PROMPT_PLAN,
|
||||
USER_PROMPT_PLAN_UPDATE,
|
||||
)
|
||||
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter
|
||||
from .e2b_executor import E2BExecutor
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
get_tool_description_with_args,
|
||||
Toolbox,
|
||||
get_tool_description_with_args,
|
||||
)
|
||||
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
||||
from .utils import (
|
||||
AgentError,
|
||||
AgentExecutionError,
|
||||
AgentGenerationError,
|
||||
AgentMaxStepsError,
|
||||
AgentParsingError,
|
||||
console,
|
||||
parse_code_blob,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -18,20 +18,20 @@ import json
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
from transformers.utils import is_offline_mode
|
||||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
from transformers.models.whisper import (
|
||||
WhisperProcessor,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.utils import is_offline_mode
|
||||
|
||||
from .local_python_executor import (
|
||||
BASE_BUILTIN_MODULES,
|
||||
BASE_PYTHON_TOOLS,
|
||||
evaluate_python_code,
|
||||
)
|
||||
from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool
|
||||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||
from .types import AgentAudio
|
||||
|
||||
|
||||
|
@ -271,8 +271,8 @@ class VisitWebpageTool(Tool):
|
|||
|
||||
def forward(self, url: str) -> str:
|
||||
try:
|
||||
from markdownify import markdownify
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from requests.exceptions import RequestException
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
|
|
|
@ -14,18 +14,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dotenv import load_dotenv
|
||||
import textwrap
|
||||
import base64
|
||||
import pickle
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from e2b_code_interpreter import Sandbox
|
||||
from PIL import Image
|
||||
|
||||
from e2b_code_interpreter import Sandbox
|
||||
from typing import List, Tuple, Any
|
||||
from .tool_validation import validate_tool_attributes
|
||||
from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
|
||||
from .tools import Tool
|
||||
from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||
from .agents import MultiStepAgent, AgentStep, ActionStep
|
||||
import gradio as gr
|
||||
|
||||
from .agents import ActionStep, AgentStep, MultiStepAgent
|
||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||
|
||||
|
||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||
"""Extract ChatMessage objects from agent steps"""
|
||||
|
|
|
@ -17,14 +17,15 @@
|
|||
import ast
|
||||
import builtins
|
||||
import difflib
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .utils import truncate_content, BASE_BUILTIN_MODULES
|
||||
from .utils import BASE_BUILTIN_MODULES, truncate_content
|
||||
|
||||
|
||||
class InterpreterError(ValueError):
|
||||
|
|
|
@ -14,24 +14,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
|
||||
import litellm
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
import torch
|
||||
from huggingface_hub import InferenceClient
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
|
||||
from .tools import Tool
|
||||
from .utils import parse_json_tool_call
|
||||
|
@ -352,16 +351,16 @@ class TransformersModel(Model):
|
|||
)
|
||||
|
||||
# Get LLM output
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
prompt = prompt.to(self.model.device)
|
||||
count_prompt_tokens = prompt["input_ids"].shape[1]
|
||||
prompt_tensor = prompt_tensor.to(self.model.device)
|
||||
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
|
||||
|
||||
out = self.model.generate(
|
||||
**prompt,
|
||||
**prompt_tensor,
|
||||
max_new_tokens=max_tokens,
|
||||
stopping_criteria=(
|
||||
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
|
||||
|
@ -383,7 +382,7 @@ class TransformersModel(Model):
|
|||
available_tools: List[Tool],
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
max_tokens: int = 500,
|
||||
) -> str:
|
||||
) -> Tuple[str, Union[str, None], str]:
|
||||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .utils import console
|
||||
from rich.text import Text
|
||||
|
||||
from .utils import console
|
||||
|
||||
|
||||
class Monitor:
|
||||
def __init__(self, tracked_model):
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import ast
|
||||
import inspect
|
||||
import builtins
|
||||
from typing import Set
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Set
|
||||
|
||||
from .utils import BASE_BUILTIN_MODULES
|
||||
|
||||
_BUILTIN_NAMES = set(vars(builtins))
|
||||
|
|
|
@ -18,14 +18,16 @@ import ast
|
|||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import torch
|
||||
import textwrap
|
||||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
create_repo,
|
||||
get_collection,
|
||||
|
@ -35,7 +37,8 @@ from huggingface_hub import (
|
|||
)
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from packaging import version
|
||||
import logging
|
||||
from transformers import AutoProcessor
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from transformers.utils import (
|
||||
TypeHintParsingException,
|
||||
cached_file,
|
||||
|
@ -45,13 +48,9 @@ from transformers.utils import (
|
|||
)
|
||||
from transformers.utils.chat_template_utils import _parse_type_hint
|
||||
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from .tool_validation import MethodChecker, validate_tool_attributes
|
||||
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
||||
from .utils import instance_to_source
|
||||
from .tool_validation import validate_tool_attributes, MethodChecker
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -12,21 +12,20 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from transformers.utils import (
|
||||
is_soundfile_availble,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -14,14 +14,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ast
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple, Dict, Union
|
||||
import ast
|
||||
from rich.console import Console
|
||||
import inspect
|
||||
import types
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from rich.console import Console
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
_pygments_available = _is_package_available("pygments")
|
||||
|
|
|
@ -16,22 +16,22 @@ import os
|
|||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from smolagents.types import AgentText, AgentImage
|
||||
import pytest
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents.agents import (
|
||||
AgentMaxStepsError,
|
||||
ManagedAgent,
|
||||
CodeAgent,
|
||||
ToolCallingAgent,
|
||||
ManagedAgent,
|
||||
Toolbox,
|
||||
ToolCall,
|
||||
ToolCallingAgent,
|
||||
)
|
||||
from smolagents.tools import tool
|
||||
from smolagents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
from smolagents.tools import tool
|
||||
from smolagents.types import AgentImage, AgentText
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
|
|
|
@ -17,12 +17,13 @@ import ast
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import tempfile
|
||||
import traceback
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
|
|
|
@ -18,16 +18,14 @@ from pathlib import Path
|
|||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
|
||||
from smolagents.default_tools import FinalAnswerTool
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
|
|
@ -13,9 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from smolagents import models, tool
|
||||
from typing import Optional
|
||||
|
||||
from smolagents import models, tool
|
||||
|
||||
|
||||
class ModelTests(unittest.TestCase):
|
||||
def test_get_json_schema_has_nullable_args(self):
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
import unittest
|
||||
|
||||
from smolagents import (
|
||||
AgentImage,
|
||||
AgentError,
|
||||
AgentImage,
|
||||
CodeAgent,
|
||||
ToolCallingAgent,
|
||||
stream_to_gradio,
|
||||
|
|
|
@ -19,12 +19,12 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from smolagents import load_tool
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
||||
from smolagents.local_python_executor import (
|
||||
InterpreterError,
|
||||
evaluate_python_code,
|
||||
)
|
||||
from smolagents.types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
|
|
@ -14,21 +14,20 @@
|
|||
# limitations under the License.
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
|
||||
from smolagents.types import (
|
||||
AGENT_TYPE_MAPPING,
|
||||
AgentAudio,
|
||||
AgentImage,
|
||||
AgentText,
|
||||
)
|
||||
from smolagents.tools import Tool, tool, AUTHORIZED_TYPES
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
|
|
@ -18,7 +18,8 @@ import unittest
|
|||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.testing_utils import (
|
||||
require_soundfile,
|
||||
require_torch,
|
||||
|
@ -28,9 +29,7 @@ from transformers.utils import (
|
|||
is_soundfile_availble,
|
||||
)
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||
|
||||
if is_soundfile_availble():
|
||||
import soundfile as sf
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
import unittest
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue