Remove dependency on _is_package_available from transformers (#247)
This commit is contained in:
		
							parent
							
								
									6db75183ff
								
							
						
					
					
						commit
						58b18f5655
					
				|  | @ -32,7 +32,6 @@ from transformers import ( | ||||||
|     StoppingCriteriaList, |     StoppingCriteriaList, | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
| ) | ) | ||||||
| from transformers.utils.import_utils import _is_package_available |  | ||||||
| 
 | 
 | ||||||
| from .tools import Tool | from .tools import Tool | ||||||
| 
 | 
 | ||||||
|  | @ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = { | ||||||
|     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>", |     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| if _is_package_available("litellm"): |  | ||||||
|     import litellm |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def get_dict_from_nested_dataclasses(obj): | def get_dict_from_nested_dataclasses(obj): | ||||||
|     def convert(obj): |     def convert(obj): | ||||||
|  | @ -508,9 +504,11 @@ class LiteLLMModel(Model): | ||||||
|         api_key=None, |         api_key=None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         if not _is_package_available("litellm"): |         try: | ||||||
|             raise ImportError( |             import litellm | ||||||
|                 "litellm not found. Install it with `pip install litellm`" |         except ModuleNotFoundError: | ||||||
|  |             raise ModuleNotFoundError( | ||||||
|  |                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" | ||||||
|             ) |             ) | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.model_id = model_id |         self.model_id = model_id | ||||||
|  | @ -530,6 +528,8 @@ class LiteLLMModel(Model): | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|  |         import litellm | ||||||
|  | 
 | ||||||
|         if tools_to_call_from: |         if tools_to_call_from: | ||||||
|             response = litellm.completion( |             response = litellm.completion( | ||||||
|                 model=self.model_id, |                 model=self.model_id, | ||||||
|  |  | ||||||
|  | @ -12,6 +12,7 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | import importlib.util | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import pathlib | import pathlib | ||||||
|  | @ -25,7 +26,6 @@ from transformers.utils import ( | ||||||
|     is_torch_available, |     is_torch_available, | ||||||
|     is_vision_available, |     is_vision_available, | ||||||
| ) | ) | ||||||
| from transformers.utils.import_utils import _is_package_available |  | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|  | @ -41,9 +41,6 @@ if is_torch_available(): | ||||||
| else: | else: | ||||||
|     Tensor = object |     Tensor = object | ||||||
| 
 | 
 | ||||||
| if _is_package_available("soundfile"): |  | ||||||
|     import soundfile as sf |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class AgentType: | class AgentType: | ||||||
|     """ |     """ | ||||||
|  | @ -187,11 +184,12 @@ class AgentAudio(AgentType, str): | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, value, samplerate=16_000): |     def __init__(self, value, samplerate=16_000): | ||||||
|  |         if importlib.util.find_spec("soundfile") is None: | ||||||
|  |             raise ModuleNotFoundError( | ||||||
|  |                 "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`" | ||||||
|  |             ) | ||||||
|         super().__init__(value) |         super().__init__(value) | ||||||
| 
 | 
 | ||||||
|         if not _is_package_available("soundfile"): |  | ||||||
|             raise ImportError("soundfile must be installed in order to handle audio.") |  | ||||||
| 
 |  | ||||||
|         self._path = None |         self._path = None | ||||||
|         self._tensor = None |         self._tensor = None | ||||||
| 
 | 
 | ||||||
|  | @ -221,6 +219,8 @@ class AgentAudio(AgentType, str): | ||||||
|         """ |         """ | ||||||
|         Returns the "raw" version of that object. It is a `torch.Tensor` object. |         Returns the "raw" version of that object. It is a `torch.Tensor` object. | ||||||
|         """ |         """ | ||||||
|  |         import soundfile as sf | ||||||
|  | 
 | ||||||
|         if self._tensor is not None: |         if self._tensor is not None: | ||||||
|             return self._tensor |             return self._tensor | ||||||
| 
 | 
 | ||||||
|  | @ -239,6 +239,8 @@ class AgentAudio(AgentType, str): | ||||||
|         Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized |         Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized | ||||||
|         version of the audio. |         version of the audio. | ||||||
|         """ |         """ | ||||||
|  |         import soundfile as sf | ||||||
|  | 
 | ||||||
|         if self._path is not None: |         if self._path is not None: | ||||||
|             return self._path |             return self._path | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import ast | import ast | ||||||
|  | import importlib.util | ||||||
| import inspect | import inspect | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
|  | @ -22,13 +23,10 @@ import types | ||||||
| from typing import Dict, Tuple, Union | from typing import Dict, Tuple, Union | ||||||
| 
 | 
 | ||||||
| from rich.console import Console | from rich.console import Console | ||||||
| from transformers.utils.import_utils import _is_package_available |  | ||||||
| 
 |  | ||||||
| _pygments_available = _is_package_available("pygments") |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_pygments_available(): | def is_pygments_available(): | ||||||
|     return _pygments_available |     return importlib.util.find_spec("soundfile") is not None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| console = Console(width=200) | console = Console(width=200) | ||||||
|  |  | ||||||
|  | @ -24,15 +24,9 @@ from transformers.testing_utils import ( | ||||||
|     require_torch, |     require_torch, | ||||||
|     require_vision, |     require_vision, | ||||||
| ) | ) | ||||||
| from transformers.utils.import_utils import ( |  | ||||||
|     _is_package_available, |  | ||||||
| ) |  | ||||||
| 
 | 
 | ||||||
| from smolagents.types import AgentAudio, AgentImage, AgentText | from smolagents.types import AgentAudio, AgentImage, AgentText | ||||||
| 
 | 
 | ||||||
| if _is_package_available("soundfile"): |  | ||||||
|     import soundfile as sf |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|     directory = tempfile.mkdtemp() |     directory = tempfile.mkdtemp() | ||||||
|  | @ -43,6 +37,7 @@ def get_new_path(suffix="") -> str: | ||||||
| @require_torch | @require_torch | ||||||
| class AgentAudioTests(unittest.TestCase): | class AgentAudioTests(unittest.TestCase): | ||||||
|     def test_from_tensor(self): |     def test_from_tensor(self): | ||||||
|  |         import soundfile as sf | ||||||
|         import torch |         import torch | ||||||
| 
 | 
 | ||||||
|         tensor = torch.rand(12, dtype=torch.float64) - 0.5 |         tensor = torch.rand(12, dtype=torch.float64) - 0.5 | ||||||
|  | @ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase): | ||||||
|         self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) |         self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) | ||||||
| 
 | 
 | ||||||
|     def test_from_string(self): |     def test_from_string(self): | ||||||
|  |         import soundfile as sf | ||||||
|         import torch |         import torch | ||||||
| 
 | 
 | ||||||
|         tensor = torch.rand(12, dtype=torch.float64) - 0.5 |         tensor = torch.rand(12, dtype=torch.float64) - 0.5 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue