Halve import time by removing torch dependency (#147)
* Halve import time by removing torch dependency
This commit is contained in:
		
							parent
							
								
									d8a4b831bb
								
							
						
					
					
						commit
						eca83800e3
					
				|  | @ -13,7 +13,7 @@ jobs: | |||
|       - name: Set up Python | ||||
|         uses: actions/setup-python@v2 | ||||
|         with: | ||||
|           python-version: "3.10" | ||||
|           python-version: "3.12" | ||||
| 
 | ||||
|       # Setup venv | ||||
|       # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed. | ||||
|  |  | |||
|  | @ -48,10 +48,10 @@ Run the line below to install the required dependencies: | |||
| 
 | ||||
| Let's login in order to call the HF Inference API: | ||||
| 
 | ||||
| ```py | ||||
| from huggingface_hub import notebook_login | ||||
| ``` | ||||
| from huggingface_hub import login | ||||
| 
 | ||||
| notebook_login() | ||||
| login() | ||||
| ``` | ||||
| 
 | ||||
| ⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model. | ||||
|  |  | |||
|  | @ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode | |||
| 
 | ||||
| ### Manage your agent's toolbox | ||||
| 
 | ||||
| You can manage an agent's toolbox by adding or replacing a tool. | ||||
| You can manage an agent's toolbox by adding or replacing a tool in attribute `agent.tools`, since it is a standard dictionary. | ||||
| 
 | ||||
| Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox. | ||||
| 
 | ||||
|  | @ -187,7 +187,7 @@ from smolagents import HfApiModel | |||
| model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct") | ||||
| 
 | ||||
| agent = CodeAgent(tools=[], model=model, add_base_tools=True) | ||||
| agent.tools.append(model_download_tool) | ||||
| agent.tools[model_download_tool.name] = model_download_tool | ||||
| ``` | ||||
| Now we can leverage the new tool: | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,9 +12,6 @@ authors = [ | |||
| readme = "README.md" | ||||
| requires-python = ">=3.10" | ||||
| dependencies = [ | ||||
|   "torch", | ||||
|   "torchaudio", | ||||
|   "torchvision", | ||||
|   "transformers>=4.0.0", | ||||
|   "requests>=2.32.3", | ||||
|   "rich>=13.9.4", | ||||
|  | @ -30,10 +27,22 @@ dependencies = [ | |||
| ] | ||||
| 
 | ||||
| [tool.ruff] | ||||
| ignore = ["F403"] | ||||
| lint.ignore = ["F403"] | ||||
| 
 | ||||
| [project.optional-dependencies] | ||||
| dev = [ | ||||
|   "torch", | ||||
|   "torchaudio", | ||||
|   "torchvision", | ||||
|   "sqlalchemy", | ||||
|   "accelerate", | ||||
|   "soundfile", | ||||
|   "litellm>=1.55.10", | ||||
| ] | ||||
| test = [ | ||||
|   "torch", | ||||
|   "torchaudio", | ||||
|   "torchvision", | ||||
|   "pytest>=8.1.0", | ||||
|   "sqlalchemy", | ||||
|   "ruff>=0.5.0", | ||||
|  |  | |||
|  | @ -20,11 +20,9 @@ from dataclasses import dataclass | |||
| from typing import Dict, Optional | ||||
| 
 | ||||
| from huggingface_hub import hf_hub_download, list_spaces | ||||
| from transformers.models.whisper import ( | ||||
|     WhisperForConditionalGeneration, | ||||
|     WhisperProcessor, | ||||
| ) | ||||
| from transformers.utils import is_offline_mode | ||||
| 
 | ||||
| 
 | ||||
| from transformers.utils import is_offline_mode, is_torch_available | ||||
| 
 | ||||
| from .local_python_executor import ( | ||||
|     BASE_BUILTIN_MODULES, | ||||
|  | @ -34,6 +32,15 @@ from .local_python_executor import ( | |||
| from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool | ||||
| from .types import AgentAudio | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|     from transformers.models.whisper import ( | ||||
|         WhisperForConditionalGeneration, | ||||
|         WhisperProcessor, | ||||
|     ) | ||||
| else: | ||||
|     WhisperForConditionalGeneration = object | ||||
|     WhisperProcessor = object | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class PreTool: | ||||
|  |  | |||
|  | @ -22,7 +22,6 @@ from copy import deepcopy | |||
| from enum import Enum | ||||
| from typing import Dict, List, Optional | ||||
| 
 | ||||
| import torch | ||||
| from huggingface_hub import ( | ||||
|     InferenceClient, | ||||
|     ChatCompletionOutputMessage, | ||||
|  | @ -35,6 +34,7 @@ from transformers import ( | |||
|     AutoTokenizer, | ||||
|     StoppingCriteria, | ||||
|     StoppingCriteriaList, | ||||
|     is_torch_available, | ||||
| ) | ||||
| import openai | ||||
| 
 | ||||
|  | @ -147,29 +147,12 @@ class Model: | |||
|         self.last_input_token_count = None | ||||
|         self.last_output_token_count = None | ||||
| 
 | ||||
|     def get_token_counts(self): | ||||
|     def get_token_counts(self) -> Dict[str, int]: | ||||
|         return { | ||||
|             "input_token_count": self.last_input_token_count, | ||||
|             "output_token_count": self.last_output_token_count, | ||||
|         } | ||||
| 
 | ||||
|     def generate( | ||||
|         self, | ||||
|         messages: List[Dict[str, str]], | ||||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|     ): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def get_tool_call( | ||||
|         self, | ||||
|         messages: List[Dict[str, str]], | ||||
|         available_tools: List[Tool], | ||||
|         stop_sequences, | ||||
|     ): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def __call__( | ||||
|         self, | ||||
|         messages: List[Dict[str, str]], | ||||
|  | @ -256,6 +239,10 @@ class HfApiModel(Model): | |||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Gets an LLM output message for the given list of input messages. | ||||
|         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. | ||||
|         """ | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  | @ -293,6 +280,10 @@ class TransformersModel(Model): | |||
| 
 | ||||
|     def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None): | ||||
|         super().__init__() | ||||
|         if not is_torch_available(): | ||||
|             raise ImportError("Please install torch in order to use TransformersModel.") | ||||
|         import torch | ||||
| 
 | ||||
|         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||
|         if model_id is None: | ||||
|             model_id = default_model_id | ||||
|  |  | |||
|  | @ -27,7 +27,6 @@ from functools import lru_cache, wraps | |||
| from pathlib import Path | ||||
| from typing import Callable, Dict, Optional, Union, get_type_hints | ||||
| 
 | ||||
| import torch | ||||
| from huggingface_hub import ( | ||||
|     create_repo, | ||||
|     get_collection, | ||||
|  | @ -37,7 +36,6 @@ from huggingface_hub import ( | |||
| ) | ||||
| from huggingface_hub.utils import RepositoryNotFoundError | ||||
| from packaging import version | ||||
| from transformers import AutoProcessor | ||||
| from transformers.dynamic_module_utils import get_imports | ||||
| from transformers.utils import ( | ||||
|     TypeHintParsingException, | ||||
|  | @ -54,13 +52,14 @@ from .utils import instance_to_source | |||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| if is_accelerate_available(): | ||||
|     from accelerate import PartialState | ||||
|     from accelerate.utils import send_to_device | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|     pass | ||||
| 
 | ||||
| if is_accelerate_available(): | ||||
|     pass | ||||
| 
 | ||||
|     from transformers import AutoProcessor | ||||
| else: | ||||
|     AutoProcessor = object | ||||
| 
 | ||||
| TOOL_CONFIG_FILE = "tool_config.json" | ||||
| 
 | ||||
|  | @ -1026,8 +1025,6 @@ class PipelineTool(Tool): | |||
|         """ | ||||
|         Instantiates the `pre_processor`, `model` and `post_processor` if necessary. | ||||
|         """ | ||||
|         from accelerate import PartialState | ||||
| 
 | ||||
|         if isinstance(self.pre_processor, str): | ||||
|             self.pre_processor = self.pre_processor_class.from_pretrained( | ||||
|                 self.pre_processor, **self.hub_kwargs | ||||
|  | @ -1066,6 +1063,8 @@ class PipelineTool(Tool): | |||
|         """ | ||||
|         Sends the inputs through the `model`. | ||||
|         """ | ||||
|         import torch | ||||
| 
 | ||||
|         with torch.no_grad(): | ||||
|             return self.model(**inputs) | ||||
| 
 | ||||
|  | @ -1076,6 +1075,8 @@ class PipelineTool(Tool): | |||
|         return self.post_processor(outputs) | ||||
| 
 | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         import torch | ||||
| 
 | ||||
|         args, kwargs = handle_agent_input_types(*args, **kwargs) | ||||
| 
 | ||||
|         if not self.is_initialized: | ||||
|  | @ -1083,9 +1084,6 @@ class PipelineTool(Tool): | |||
| 
 | ||||
|         encoded_inputs = self.encode(*args, **kwargs) | ||||
| 
 | ||||
|         import torch | ||||
|         from accelerate.utils import send_to_device | ||||
| 
 | ||||
|         tensor_inputs = { | ||||
|             k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) | ||||
|         } | ||||
|  |  | |||
|  | @ -22,10 +22,10 @@ from io import BytesIO | |||
| import numpy as np | ||||
| import requests | ||||
| from transformers.utils import ( | ||||
|     is_soundfile_availble, | ||||
|     is_torch_available, | ||||
|     is_vision_available, | ||||
| ) | ||||
| from transformers.utils.import_utils import _is_package_available | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -41,7 +41,7 @@ if is_torch_available(): | |||
| else: | ||||
|     Tensor = object | ||||
| 
 | ||||
| if is_soundfile_availble(): | ||||
| if _is_package_available("soundfile"): | ||||
|     import soundfile as sf | ||||
| 
 | ||||
| 
 | ||||
|  | @ -189,7 +189,7 @@ class AgentAudio(AgentType, str): | |||
|     def __init__(self, value, samplerate=16_000): | ||||
|         super().__init__(value) | ||||
| 
 | ||||
|         if not is_soundfile_availble(): | ||||
|         if not _is_package_available("soundfile"): | ||||
|             raise ImportError("soundfile must be installed in order to handle audio.") | ||||
| 
 | ||||
|         self._path = None | ||||
|  | @ -253,7 +253,7 @@ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAu | |||
| INSTANCE_TYPE_MAPPING = { | ||||
|     str: AgentText, | ||||
|     ImageType: AgentImage, | ||||
|     torch.Tensor: AgentAudio, | ||||
|     Tensor: AgentAudio, | ||||
| } | ||||
| 
 | ||||
| if is_torch_available(): | ||||
|  |  | |||
|  | @ -18,20 +18,19 @@ import unittest | |||
| import uuid | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import torch | ||||
| from PIL import Image | ||||
| from transformers.testing_utils import ( | ||||
|     require_soundfile, | ||||
|     require_torch, | ||||
|     require_vision, | ||||
| ) | ||||
| from transformers.utils import ( | ||||
|     is_soundfile_availble, | ||||
| from transformers.utils.import_utils import ( | ||||
|     _is_package_available, | ||||
| ) | ||||
| 
 | ||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | ||||
| 
 | ||||
| if is_soundfile_availble(): | ||||
| if _is_package_available("soundfile"): | ||||
|     import soundfile as sf | ||||
| 
 | ||||
| 
 | ||||
|  | @ -44,6 +43,8 @@ def get_new_path(suffix="") -> str: | |||
| @require_torch | ||||
| class AgentAudioTests(unittest.TestCase): | ||||
|     def test_from_tensor(self): | ||||
|         import torch | ||||
| 
 | ||||
|         tensor = torch.rand(12, dtype=torch.float64) - 0.5 | ||||
|         agent_type = AgentAudio(tensor) | ||||
|         path = str(agent_type.to_string()) | ||||
|  | @ -61,6 +62,8 @@ class AgentAudioTests(unittest.TestCase): | |||
|         self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) | ||||
| 
 | ||||
|     def test_from_string(self): | ||||
|         import torch | ||||
| 
 | ||||
|         tensor = torch.rand(12, dtype=torch.float64) - 0.5 | ||||
|         path = get_new_path(suffix=".wav") | ||||
|         sf.write(path, tensor, 16000) | ||||
|  | @ -75,6 +78,8 @@ class AgentAudioTests(unittest.TestCase): | |||
| @require_torch | ||||
| class AgentImageTests(unittest.TestCase): | ||||
|     def test_from_tensor(self): | ||||
|         import torch | ||||
| 
 | ||||
|         tensor = torch.randint(0, 256, (64, 64, 3)) | ||||
|         agent_type = AgentImage(tensor) | ||||
|         path = str(agent_type.to_string()) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue