Sort imports and add test workflows
This commit is contained in:
		
							parent
							
								
									417c6685b0
								
							
						
					
					
						commit
						c22fedaee1
					
				|  | @ -15,19 +15,21 @@ jobs: | ||||||
|         with: |         with: | ||||||
|           python-version: 3.10 |           python-version: 3.10 | ||||||
| 
 | 
 | ||||||
|       - name: Install Python dependencies |       # Setup venv | ||||||
|         run: pip install -e .[quality] |       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. | ||||||
| 
 |       - name: Setup venv + uv | ||||||
|       - name: Run Quality check |  | ||||||
|         run: make quality |  | ||||||
|       - name: Check if failure |  | ||||||
|         if: ${{ failure() }} |  | ||||||
|         run: | |         run: | | ||||||
|           echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY |           pip install --upgrade uv | ||||||
|  |           uv venv | ||||||
| 
 | 
 | ||||||
|       - name: Run Style check |       - name: Install dependencies | ||||||
|         run: make style |         run: uv pip install "smolagents[test] @ ." | ||||||
|       - name: Check if failure |       - run: uv run ruff check tests src # linter | ||||||
|         if: ${{ failure() }} |       - run: uv run ruff format --check tests src # formatter | ||||||
|         run: | | 
 | ||||||
|           echo "Style check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and rerun 'make style; make quality;'" >> $GITHUB_STEP_SUMMARY |       # Run type checking at least on smolagents root file to check all modules | ||||||
|  |       # that can be lazy-loaded actually exist. | ||||||
|  |       # - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback | ||||||
|  | 
 | ||||||
|  |       # Run mypy on full package | ||||||
|  |       # - run: uv run mypy src | ||||||
|  | @ -0,0 +1,47 @@ | ||||||
|  | name: Python tests | ||||||
|  | 
 | ||||||
|  | on: [pull_request] | ||||||
|  | 
 | ||||||
|  | jobs: | ||||||
|  |   build-ubuntu: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     env: | ||||||
|  |       UV_HTTP_TIMEOUT: 600 # max 10min to install deps | ||||||
|  | 
 | ||||||
|  |     strategy: | ||||||
|  |       fail-fast: false | ||||||
|  |       matrix: | ||||||
|  |         python-version: ["3.10", "3.12"] | ||||||
|  | 
 | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v2 | ||||||
|  |       - name: Set up Python ${{ matrix.python-version }} | ||||||
|  |         uses: actions/setup-python@v2 | ||||||
|  |         with: | ||||||
|  |           python-version: ${{ matrix.python-version }} | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |       # Setup venv | ||||||
|  |       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. | ||||||
|  |       - name: Setup venv + uv | ||||||
|  |         run: | | ||||||
|  |           pip install --upgrade uv | ||||||
|  |           uv venv | ||||||
|  | 
 | ||||||
|  |       # Install dependencies | ||||||
|  |       - name: Install dependencies | ||||||
|  |         run: | | ||||||
|  |           uv pip install "smolagents[test] @ ." | ||||||
|  | 
 | ||||||
|  |       - name: Agent tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_agents.py | ||||||
|  |       - name: Tool tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_toolss.py | ||||||
|  |       - name: Python interpreter tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_python_interpreter.py | ||||||
|  |       - name: Final answer tests | ||||||
|  |         run: | | ||||||
|  |           uv run pytest -sv ./tests/test_final_answer.py | ||||||
|  | @ -346,7 +346,7 @@ | ||||||
|     "import glob\n", |     "import glob\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "res = []\n", |     "res = []\n", | ||||||
|     "for f in glob.glob(f\"output/*.jsonl\"):\n", |     "for f in glob.glob(\"output/*.jsonl\"):\n", | ||||||
|     "    res.append(pd.read_json(f, lines=True))\n", |     "    res.append(pd.read_json(f, lines=True))\n", | ||||||
|     "result_df = pd.concat(res)\n", |     "result_df = pd.concat(res)\n", | ||||||
|     "\n", |     "\n", | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| from smolagents.agents import ToolCallingAgent | from smolagents.agents import ToolCallingAgent | ||||||
| from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel | from smolagents import tool, LiteLLMModel | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
| # Choose which LLM engine to use! | # Choose which LLM engine to use! | ||||||
|  |  | ||||||
|  | @ -12,25 +12,29 @@ authors = [ | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
| requires-python = ">=3.10" | requires-python = ">=3.10" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|     "torch", |   "torch", | ||||||
|     "torchaudio", |   "torchaudio", | ||||||
|     "torchvision", |   "torchvision", | ||||||
|     "transformers>=4.0.0", |   "transformers>=4.0.0", | ||||||
|     "requests>=2.32.3", |   "requests>=2.32.3", | ||||||
|     "rich>=13.9.4", |   "rich>=13.9.4", | ||||||
|     "pandas>=2.2.3", |   "pandas>=2.2.3", | ||||||
|     "jinja2>=3.1.4", |   "jinja2>=3.1.4", | ||||||
|     "pillow>=11.0.0", |   "pillow>=11.0.0", | ||||||
|     "markdownify>=0.14.1", |   "markdownify>=0.14.1", | ||||||
|     "gradio>=5.8.0", |   "gradio>=5.8.0", | ||||||
|     "duckduckgo-search>=6.3.7", |   "duckduckgo-search>=6.3.7", | ||||||
|     "python-dotenv>=1.0.1", |   "python-dotenv>=1.0.1", | ||||||
|     "e2b-code-interpreter>=1.0.3", |   "e2b-code-interpreter>=1.0.3", | ||||||
|     "litellm>=1.55.10", |   "litellm>=1.55.10", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [tool.ruff] | ||||||
|  | ignore = ["F403"] | ||||||
|  | 
 | ||||||
| [project.optional-dependencies] | [project.optional-dependencies] | ||||||
| test = [ | test = [ | ||||||
|     "pytest>=8.1.0", |   "pytest>=8.1.0", | ||||||
|     "sqlalchemy" |   "sqlalchemy", | ||||||
|  |   "ruff>=0.5.0", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ def execute_code(code): | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         exec(code, exec_globals, exec_locals) |         exec(code, exec_globals, exec_locals) | ||||||
|     except Exception as e: |     except Exception: | ||||||
|         traceback.print_exc(file=stderr) |         traceback.print_exc(file=stderr) | ||||||
|      |      | ||||||
|     output = stdout.getvalue() |     output = stdout.getvalue() | ||||||
|  |  | ||||||
|  | @ -21,14 +21,13 @@ from typing import TYPE_CHECKING | ||||||
| from transformers.utils import _LazyModule | from transformers.utils import _LazyModule | ||||||
| from transformers.utils.import_utils import define_import_structure | from transformers.utils.import_utils import define_import_structure | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from .agents import * |     from .agents import * | ||||||
|     from .default_tools import * |     from .default_tools import * | ||||||
|     from .gradio_ui import * |  | ||||||
|     from .models import * |  | ||||||
|     from .local_python_executor import * |  | ||||||
|     from .e2b_executor import * |     from .e2b_executor import * | ||||||
|  |     from .gradio_ui import * | ||||||
|  |     from .local_python_executor import * | ||||||
|  |     from .models import * | ||||||
|     from .monitoring import * |     from .monitoring import * | ||||||
|     from .prompts import * |     from .prompts import * | ||||||
|     from .tools import * |     from .tools import * | ||||||
|  |  | ||||||
|  | @ -15,49 +15,50 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import time | import time | ||||||
| from typing import Any, Callable, Dict, List, Optional, Union, Tuple |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from rich.syntax import Syntax | from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||||
|  | 
 | ||||||
| from rich.console import Group | from rich.console import Group | ||||||
| from rich.panel import Panel | from rich.panel import Panel | ||||||
| from rich.rule import Rule | from rich.rule import Rule | ||||||
|  | from rich.syntax import Syntax | ||||||
| from rich.text import Text | from rich.text import Text | ||||||
| 
 | 
 | ||||||
| from .utils import ( |  | ||||||
|     console, |  | ||||||
|     parse_code_blob, |  | ||||||
|     parse_json_tool_call, |  | ||||||
|     truncate_content, |  | ||||||
|     AgentError, |  | ||||||
|     AgentParsingError, |  | ||||||
|     AgentExecutionError, |  | ||||||
|     AgentGenerationError, |  | ||||||
|     AgentMaxStepsError, |  | ||||||
| ) |  | ||||||
| from .types import AgentAudio, AgentImage, handle_agent_output_types |  | ||||||
| from .default_tools import FinalAnswerTool | from .default_tools import FinalAnswerTool | ||||||
|  | from .e2b_executor import E2BExecutor | ||||||
|  | from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter | ||||||
| from .models import MessageRole | from .models import MessageRole | ||||||
| from .monitoring import Monitor | from .monitoring import Monitor | ||||||
| from .prompts import ( | from .prompts import ( | ||||||
|     CODE_SYSTEM_PROMPT, |     CODE_SYSTEM_PROMPT, | ||||||
|     TOOL_CALLING_SYSTEM_PROMPT, |     MANAGED_AGENT_PROMPT, | ||||||
|     PLAN_UPDATE_FINAL_PLAN_REDACTION, |     PLAN_UPDATE_FINAL_PLAN_REDACTION, | ||||||
|     SYSTEM_PROMPT_FACTS, |     SYSTEM_PROMPT_FACTS, | ||||||
|     SYSTEM_PROMPT_FACTS_UPDATE, |     SYSTEM_PROMPT_FACTS_UPDATE, | ||||||
|     USER_PROMPT_FACTS_UPDATE, |  | ||||||
|     USER_PROMPT_PLAN_UPDATE, |  | ||||||
|     USER_PROMPT_PLAN, |  | ||||||
|     SYSTEM_PROMPT_PLAN_UPDATE, |  | ||||||
|     SYSTEM_PROMPT_PLAN, |     SYSTEM_PROMPT_PLAN, | ||||||
|     MANAGED_AGENT_PROMPT, |     SYSTEM_PROMPT_PLAN_UPDATE, | ||||||
|  |     TOOL_CALLING_SYSTEM_PROMPT, | ||||||
|  |     USER_PROMPT_FACTS_UPDATE, | ||||||
|  |     USER_PROMPT_PLAN, | ||||||
|  |     USER_PROMPT_PLAN_UPDATE, | ||||||
| ) | ) | ||||||
| from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter |  | ||||||
| from .e2b_executor import E2BExecutor |  | ||||||
| from .tools import ( | from .tools import ( | ||||||
|     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, |     DEFAULT_TOOL_DESCRIPTION_TEMPLATE, | ||||||
|     Tool, |     Tool, | ||||||
|     get_tool_description_with_args, |  | ||||||
|     Toolbox, |     Toolbox, | ||||||
|  |     get_tool_description_with_args, | ||||||
|  | ) | ||||||
|  | from .types import AgentAudio, AgentImage, handle_agent_output_types | ||||||
|  | from .utils import ( | ||||||
|  |     AgentError, | ||||||
|  |     AgentExecutionError, | ||||||
|  |     AgentGenerationError, | ||||||
|  |     AgentMaxStepsError, | ||||||
|  |     AgentParsingError, | ||||||
|  |     console, | ||||||
|  |     parse_code_blob, | ||||||
|  |     parse_json_tool_call, | ||||||
|  |     truncate_content, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,20 +18,20 @@ import json | ||||||
| import re | import re | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Dict, Optional | from typing import Dict, Optional | ||||||
| from huggingface_hub import hf_hub_download, list_spaces |  | ||||||
| 
 | 
 | ||||||
| from transformers.utils import is_offline_mode | from huggingface_hub import hf_hub_download, list_spaces | ||||||
| from transformers.models.whisper import ( | from transformers.models.whisper import ( | ||||||
|     WhisperProcessor, |  | ||||||
|     WhisperForConditionalGeneration, |     WhisperForConditionalGeneration, | ||||||
|  |     WhisperProcessor, | ||||||
| ) | ) | ||||||
|  | from transformers.utils import is_offline_mode | ||||||
| 
 | 
 | ||||||
| from .local_python_executor import ( | from .local_python_executor import ( | ||||||
|     BASE_BUILTIN_MODULES, |     BASE_BUILTIN_MODULES, | ||||||
|     BASE_PYTHON_TOOLS, |     BASE_PYTHON_TOOLS, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
| ) | ) | ||||||
| from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool | from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | ||||||
| from .types import AgentAudio | from .types import AgentAudio | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -271,8 +271,8 @@ class VisitWebpageTool(Tool): | ||||||
| 
 | 
 | ||||||
|     def forward(self, url: str) -> str: |     def forward(self, url: str) -> str: | ||||||
|         try: |         try: | ||||||
|             from markdownify import markdownify |  | ||||||
|             import requests |             import requests | ||||||
|  |             from markdownify import markdownify | ||||||
|             from requests.exceptions import RequestException |             from requests.exceptions import RequestException | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             raise ImportError( |             raise ImportError( | ||||||
|  |  | ||||||
|  | @ -14,18 +14,19 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from dotenv import load_dotenv |  | ||||||
| import textwrap |  | ||||||
| import base64 | import base64 | ||||||
| import pickle | import pickle | ||||||
|  | import textwrap | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
|  | from typing import Any, List, Tuple | ||||||
|  | 
 | ||||||
|  | from dotenv import load_dotenv | ||||||
|  | from e2b_code_interpreter import Sandbox | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
| from e2b_code_interpreter import Sandbox |  | ||||||
| from typing import List, Tuple, Any |  | ||||||
| from .tool_validation import validate_tool_attributes | from .tool_validation import validate_tool_attributes | ||||||
| from .utils import instance_to_source, BASE_BUILTIN_MODULES, console |  | ||||||
| from .tools import Tool | from .tools import Tool | ||||||
|  | from .utils import BASE_BUILTIN_MODULES, console, instance_to_source | ||||||
| 
 | 
 | ||||||
| load_dotenv() | load_dotenv() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,10 +14,11 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types |  | ||||||
| from .agents import MultiStepAgent, AgentStep, ActionStep |  | ||||||
| import gradio as gr | import gradio as gr | ||||||
| 
 | 
 | ||||||
|  | from .agents import ActionStep, AgentStep, MultiStepAgent | ||||||
|  | from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||||
|     """Extract ChatMessage objects from agent steps""" |     """Extract ChatMessage objects from agent steps""" | ||||||
|  |  | ||||||
|  | @ -17,14 +17,15 @@ | ||||||
| import ast | import ast | ||||||
| import builtins | import builtins | ||||||
| import difflib | import difflib | ||||||
|  | import math | ||||||
| from collections.abc import Mapping | from collections.abc import Mapping | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple | from typing import Any, Callable, Dict, List, Optional, Tuple | ||||||
| import math | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| import pandas as pd | import pandas as pd | ||||||
| 
 | 
 | ||||||
| from .utils import truncate_content, BASE_BUILTIN_MODULES | from .utils import BASE_BUILTIN_MODULES, truncate_content | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class InterpreterError(ValueError): | class InterpreterError(ValueError): | ||||||
|  |  | ||||||
|  | @ -14,24 +14,23 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from copy import deepcopy |  | ||||||
| from enum import Enum |  | ||||||
| import json | import json | ||||||
| from typing import Dict, List, Optional |  | ||||||
| from transformers import ( |  | ||||||
|     AutoTokenizer, |  | ||||||
|     AutoModelForCausalLM, |  | ||||||
|     StoppingCriteria, |  | ||||||
|     StoppingCriteriaList, |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| import litellm |  | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import random | import random | ||||||
| import torch | from copy import deepcopy | ||||||
|  | from enum import Enum | ||||||
|  | from typing import Dict, List, Optional, Tuple, Union | ||||||
| 
 | 
 | ||||||
|  | import litellm | ||||||
|  | import torch | ||||||
| from huggingface_hub import InferenceClient | from huggingface_hub import InferenceClient | ||||||
|  | from transformers import ( | ||||||
|  |     AutoModelForCausalLM, | ||||||
|  |     AutoTokenizer, | ||||||
|  |     StoppingCriteria, | ||||||
|  |     StoppingCriteriaList, | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| from .tools import Tool | from .tools import Tool | ||||||
| from .utils import parse_json_tool_call | from .utils import parse_json_tool_call | ||||||
|  | @ -352,16 +351,16 @@ class TransformersModel(Model): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Get LLM output |         # Get LLM output | ||||||
|         prompt = self.tokenizer.apply_chat_template( |         prompt_tensor = self.tokenizer.apply_chat_template( | ||||||
|             messages, |             messages, | ||||||
|             return_tensors="pt", |             return_tensors="pt", | ||||||
|             return_dict=True, |             return_dict=True, | ||||||
|         ) |         ) | ||||||
|         prompt = prompt.to(self.model.device) |         prompt_tensor = prompt_tensor.to(self.model.device) | ||||||
|         count_prompt_tokens = prompt["input_ids"].shape[1] |         count_prompt_tokens = prompt_tensor["input_ids"].shape[1] | ||||||
| 
 | 
 | ||||||
|         out = self.model.generate( |         out = self.model.generate( | ||||||
|             **prompt, |             **prompt_tensor, | ||||||
|             max_new_tokens=max_tokens, |             max_new_tokens=max_tokens, | ||||||
|             stopping_criteria=( |             stopping_criteria=( | ||||||
|                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None |                 self.make_stopping_criteria(stop_sequences) if stop_sequences else None | ||||||
|  | @ -383,7 +382,7 @@ class TransformersModel(Model): | ||||||
|         available_tools: List[Tool], |         available_tools: List[Tool], | ||||||
|         stop_sequences: Optional[List[str]] = None, |         stop_sequences: Optional[List[str]] = None, | ||||||
|         max_tokens: int = 500, |         max_tokens: int = 500, | ||||||
|     ) -> str: |     ) -> Tuple[str, Union[str, None], str]: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | @ -14,9 +14,10 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from .utils import console |  | ||||||
| from rich.text import Text | from rich.text import Text | ||||||
| 
 | 
 | ||||||
|  | from .utils import console | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class Monitor: | class Monitor: | ||||||
|     def __init__(self, tracked_model): |     def __init__(self, tracked_model): | ||||||
|  |  | ||||||
|  | @ -1,8 +1,9 @@ | ||||||
| import ast | import ast | ||||||
| import inspect |  | ||||||
| import builtins | import builtins | ||||||
| from typing import Set | import inspect | ||||||
| import textwrap | import textwrap | ||||||
|  | from typing import Set | ||||||
|  | 
 | ||||||
| from .utils import BASE_BUILTIN_MODULES | from .utils import BASE_BUILTIN_MODULES | ||||||
| 
 | 
 | ||||||
| _BUILTIN_NAMES = set(vars(builtins)) | _BUILTIN_NAMES = set(vars(builtins)) | ||||||
|  |  | ||||||
|  | @ -18,14 +18,16 @@ import ast | ||||||
| import importlib | import importlib | ||||||
| import inspect | import inspect | ||||||
| import json | import json | ||||||
|  | import logging | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
| import tempfile | import tempfile | ||||||
| import torch |  | ||||||
| import textwrap | import textwrap | ||||||
| from functools import lru_cache, wraps | from functools import lru_cache, wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Callable, Dict, List, Optional, Union, get_type_hints | from typing import Callable, Dict, List, Optional, Union, get_type_hints | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
| from huggingface_hub import ( | from huggingface_hub import ( | ||||||
|     create_repo, |     create_repo, | ||||||
|     get_collection, |     get_collection, | ||||||
|  | @ -35,7 +37,8 @@ from huggingface_hub import ( | ||||||
| ) | ) | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError | from huggingface_hub.utils import RepositoryNotFoundError | ||||||
| from packaging import version | from packaging import version | ||||||
| import logging | from transformers import AutoProcessor | ||||||
|  | from transformers.dynamic_module_utils import get_imports | ||||||
| from transformers.utils import ( | from transformers.utils import ( | ||||||
|     TypeHintParsingException, |     TypeHintParsingException, | ||||||
|     cached_file, |     cached_file, | ||||||
|  | @ -45,13 +48,9 @@ from transformers.utils import ( | ||||||
| ) | ) | ||||||
| from transformers.utils.chat_template_utils import _parse_type_hint | from transformers.utils.chat_template_utils import _parse_type_hint | ||||||
| 
 | 
 | ||||||
| from transformers.dynamic_module_utils import get_imports | from .tool_validation import MethodChecker, validate_tool_attributes | ||||||
| from transformers import AutoProcessor |  | ||||||
| 
 |  | ||||||
| from .types import ImageType, handle_agent_input_types, handle_agent_output_types | from .types import ImageType, handle_agent_input_types, handle_agent_output_types | ||||||
| from .utils import instance_to_source | from .utils import instance_to_source | ||||||
| from .tool_validation import validate_tool_attributes, MethodChecker |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,21 +12,20 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import logging | ||||||
| import os | import os | ||||||
| import pathlib | import pathlib | ||||||
| import tempfile | import tempfile | ||||||
| import uuid | import uuid | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| import requests |  | ||||||
| import numpy as np |  | ||||||
| 
 | 
 | ||||||
|  | import numpy as np | ||||||
|  | import requests | ||||||
| from transformers.utils import ( | from transformers.utils import ( | ||||||
|     is_soundfile_availble, |     is_soundfile_availble, | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
|     is_vision_available, |     is_vision_available, | ||||||
| ) | ) | ||||||
| import logging |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,14 +14,14 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import ast | ||||||
|  | import inspect | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
| from typing import Tuple, Dict, Union |  | ||||||
| import ast |  | ||||||
| from rich.console import Console |  | ||||||
| import inspect |  | ||||||
| import types | import types | ||||||
|  | from typing import Dict, Tuple, Union | ||||||
| 
 | 
 | ||||||
|  | from rich.console import Console | ||||||
| from transformers.utils.import_utils import _is_package_available | from transformers.utils.import_utils import _is_package_available | ||||||
| 
 | 
 | ||||||
| _pygments_available = _is_package_available("pygments") | _pygments_available = _is_package_available("pygments") | ||||||
|  |  | ||||||
|  | @ -16,22 +16,22 @@ import os | ||||||
| import tempfile | import tempfile | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
| import pytest |  | ||||||
| 
 |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from smolagents.types import AgentText, AgentImage | import pytest | ||||||
|  | from transformers.testing_utils import get_tests_dir | ||||||
|  | 
 | ||||||
| from smolagents.agents import ( | from smolagents.agents import ( | ||||||
|     AgentMaxStepsError, |     AgentMaxStepsError, | ||||||
|     ManagedAgent, |  | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     ToolCallingAgent, |     ManagedAgent, | ||||||
|     Toolbox, |     Toolbox, | ||||||
|     ToolCall, |     ToolCall, | ||||||
|  |     ToolCallingAgent, | ||||||
| ) | ) | ||||||
| from smolagents.tools import tool |  | ||||||
| from smolagents.default_tools import PythonInterpreterTool | from smolagents.default_tools import PythonInterpreterTool | ||||||
| from transformers.testing_utils import get_tests_dir | from smolagents.tools import tool | ||||||
|  | from smolagents.types import AgentImage, AgentText | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|  |  | ||||||
|  | @ -17,12 +17,13 @@ import ast | ||||||
| import os | import os | ||||||
| import re | import re | ||||||
| import shutil | import shutil | ||||||
| import tempfile |  | ||||||
| import subprocess | import subprocess | ||||||
|  | import tempfile | ||||||
| import traceback | import traceback | ||||||
| import pytest |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import List | from typing import List | ||||||
|  | 
 | ||||||
|  | import pytest | ||||||
| from dotenv import load_dotenv | from dotenv import load_dotenv | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,16 +18,14 @@ from pathlib import Path | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 |  | ||||||
| from transformers import is_torch_available | from transformers import is_torch_available | ||||||
| from transformers.testing_utils import get_tests_dir, require_torch | from transformers.testing_utils import get_tests_dir, require_torch | ||||||
| from smolagents.types import AGENT_TYPE_MAPPING |  | ||||||
| 
 | 
 | ||||||
| from smolagents.default_tools import FinalAnswerTool | from smolagents.default_tools import FinalAnswerTool | ||||||
|  | from smolagents.types import AGENT_TYPE_MAPPING | ||||||
| 
 | 
 | ||||||
| from .test_tools import ToolTesterMixin | from .test_tools import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|     import torch |     import torch | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -13,9 +13,10 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
| from smolagents import models, tool |  | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
|  | from smolagents import models, tool | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class ModelTests(unittest.TestCase): | class ModelTests(unittest.TestCase): | ||||||
|     def test_get_json_schema_has_nullable_args(self): |     def test_get_json_schema_has_nullable_args(self): | ||||||
|  |  | ||||||
|  | @ -16,8 +16,8 @@ | ||||||
| import unittest | import unittest | ||||||
| 
 | 
 | ||||||
| from smolagents import ( | from smolagents import ( | ||||||
|     AgentImage, |  | ||||||
|     AgentError, |     AgentError, | ||||||
|  |     AgentImage, | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     ToolCallingAgent, |     ToolCallingAgent, | ||||||
|     stream_to_gradio, |     stream_to_gradio, | ||||||
|  |  | ||||||
|  | @ -19,12 +19,12 @@ import numpy as np | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from smolagents import load_tool | from smolagents import load_tool | ||||||
| from smolagents.types import AGENT_TYPE_MAPPING |  | ||||||
| from smolagents.default_tools import BASE_PYTHON_TOOLS | from smolagents.default_tools import BASE_PYTHON_TOOLS | ||||||
| from smolagents.local_python_executor import ( | from smolagents.local_python_executor import ( | ||||||
|     InterpreterError, |     InterpreterError, | ||||||
|     evaluate_python_code, |     evaluate_python_code, | ||||||
| ) | ) | ||||||
|  | from smolagents.types import AGENT_TYPE_MAPPING | ||||||
| 
 | 
 | ||||||
| from .test_tools import ToolTesterMixin | from .test_tools import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,21 +14,20 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Dict, Union, Optional | from typing import Dict, Optional, Union | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| import pytest | import pytest | ||||||
| 
 |  | ||||||
| from transformers import is_torch_available, is_vision_available | from transformers import is_torch_available, is_vision_available | ||||||
|  | from transformers.testing_utils import get_tests_dir | ||||||
|  | 
 | ||||||
|  | from smolagents.tools import AUTHORIZED_TYPES, Tool, tool | ||||||
| from smolagents.types import ( | from smolagents.types import ( | ||||||
|     AGENT_TYPE_MAPPING, |     AGENT_TYPE_MAPPING, | ||||||
|     AgentAudio, |     AgentAudio, | ||||||
|     AgentImage, |     AgentImage, | ||||||
|     AgentText, |     AgentText, | ||||||
| ) | ) | ||||||
| from smolagents.tools import Tool, tool, AUTHORIZED_TYPES |  | ||||||
| from transformers.testing_utils import get_tests_dir |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): | if is_torch_available(): | ||||||
|     import torch |     import torch | ||||||
|  |  | ||||||
|  | @ -18,7 +18,8 @@ import unittest | ||||||
| import uuid | import uuid | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | import torch | ||||||
|  | from PIL import Image | ||||||
| from transformers.testing_utils import ( | from transformers.testing_utils import ( | ||||||
|     require_soundfile, |     require_soundfile, | ||||||
|     require_torch, |     require_torch, | ||||||
|  | @ -28,9 +29,7 @@ from transformers.utils import ( | ||||||
|     is_soundfile_availble, |     is_soundfile_availble, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| import torch | from smolagents.types import AgentAudio, AgentImage, AgentText | ||||||
| from PIL import Image |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| if is_soundfile_availble(): | if is_soundfile_availble(): | ||||||
|     import soundfile as sf |     import soundfile as sf | ||||||
|  |  | ||||||
|  | @ -1,8 +1,7 @@ | ||||||
| import os | import os | ||||||
| import unittest |  | ||||||
| import shutil | import shutil | ||||||
| import tempfile | import tempfile | ||||||
| 
 | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue