MLX model support (#300)
parent bca3a9bc13
commit 9b96199d00
@@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
     - [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
     - [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
     - [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+    - [`MLXModel`] creates an [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
 
 - `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
 
-Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
 
 <hfoptions id="Pick a LLM">
 <hfoption id="HF Inference API">
@@ -148,6 +149,19 @@ agent.run(
 )
 ```
 
+</hfoption>
+<hfoption id="mlx-lm">
+
+```python
+# !pip install smolagents[mlx-lm]
+from smolagents import CodeAgent, MLXModel
+
+mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
+agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
+
+agent.run("Could you give me the 118th number in the Fibonacci sequence?")
+```
+
 </hfoption>
 </hfoptions>
 
@@ -148,3 +148,22 @@ model = AzureOpenAIServerModel(
 ```
 
 [[autodoc]] AzureOpenAIServerModel
+
+### MLXModel
+
+```python
+from smolagents import MLXModel
+
+model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
+
+print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
+```
+```text
+>>> What a
+```
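+
+Generation stops as soon as the streamed text ends with the stop sequence `great` (the model presumably continues with something like "What a great ..."), and the stop sequence itself is trimmed from the returned content, which is why only `What a` is printed.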
+
+> [!TIP]
+> You must have `mlx-lm` installed on your machine. You can install it with `pip install smolagents[mlx-lm]`.
+
+[[autodoc]] MLXModel
@@ -46,6 +46,9 @@ mcp = [
   "mcpadapt>=0.0.6",
   "mcp",
 ]
+mlx-lm = [
+  "mlx-lm"
+]
 openai = [
   "openai>=1.58.1"
 ]
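
With this extra defined, the optional dependency is pulled in via `pip install 'smolagents[mlx-lm]'` (the quotes matter in shells like zsh, which otherwise expand the square brackets).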
@@ -18,6 +18,7 @@ import json
 import logging
 import os
 import random
+import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -415,6 +416,128 @@ class HfApiModel(Model):
         return message
 
 
+class MLXModel(Model):
+    """A class to interact with models loaded using MLX on Apple silicon.
+
+    > [!TIP]
+    > You must have `mlx-lm` installed on your machine. You can install it with `pip install smolagents[mlx-lm]`.
+
+    Parameters:
+        model_id (str):
+            The Hugging Face model ID to be used for inference. This can be a path or a model identifier from the Hugging Face Hub.
+        tool_name_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving a tool name.
+        tool_arguments_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
+        trust_remote_code (bool):
+            Some models on the Hub require running remote code; for those models, you must set this flag to `True`.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to pass to `model.generate()`, for instance `max_tokens`.
+
+    Example:
+    ```python
+    >>> engine = MLXModel(
+    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
+    ...     max_tokens=10000,
+    ... )
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "text", "text": "Explain quantum mechanics in simple terms."}
+    ...         ]
+    ...     }
+    ... ]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        tool_name_key: str = "name",
+        tool_arguments_key: str = "arguments",
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not _is_package_available("mlx_lm"):
+            raise ModuleNotFoundError(
+                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
+            )
+        import mlx_lm
+
+        self.model_id = model_id
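+        # Download (if needed) and load the MLX weights and tokenizer;
+        # trust_remote_code is forwarded to the tokenizer for models shipping custom code.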
+        self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
+        self.stream_generate = mlx_lm.stream_generate
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
+
+    def _to_message(self, text, tools_to_call_from):
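+        # Wrap the generated text in an assistant ChatMessage; when tools are
+        # available, first try to parse the text as a tool call.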
+        if tools_to_call_from:
+            # Temporary solution: extract the outermost JSON object (from the first "{"
+            # to the last "}") without assuming a specific model output format.
+            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
+            parsed_text = json.loads(maybe_json)
+            tool_name = parsed_text.get(self.tool_name_key, None)
+            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
+            if tool_name:
+                return ChatMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatMessageToolCall(
+                            id=str(uuid.uuid4()),
+                            type="function",
+                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                        )
+                    ],
+                )
+        return ChatMessage(role="assistant", content=text)
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        completion_kwargs = self._prepare_completion_kwargs(
+            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
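+        # Pull out the OpenAI-style kwargs that stream_generate() does not accept:
+        # messages and tools are rendered via the chat template, stop sequences are
+        # checked manually below, and tool_choice is unsupported.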
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
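+        # Render the conversation (and tool schemas, if any) into prompt token ids.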
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            add_generation_prompt=True,
+        )
+
+        self.last_input_token_count = len(prompt_ids)
+        self.last_output_token_count = 0
+        text = ""
+
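+        # Stream tokens one at a time so generation can stop early: as soon as the
+        # accumulated text ends with a stop sequence, trim it and return.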
+        for chunk in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
+            self.last_output_token_count += 1
+            text += chunk.text
+            for stop_sequence in prepared_stop_sequences:
+                if text.strip().endswith(stop_sequence):
+                    text = text[: -len(stop_sequence)]
+                    return self._to_message(text, tools_to_call_from)
+
+        return self._to_message(text, tools_to_call_from)
+
+
 class TransformersModel(Model):
     """A class that uses Hugging Face's Transformers library for language model interaction.
 
@@ -837,6 +960,7 @@ __all__ = [
     "tool_role_conversions",
     "get_clean_message_list",
     "Model",
+    "MLXModel",
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
 import os
+import sys
 import unittest
 from pathlib import Path
 from typing import Optional
@@ -22,7 +23,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from transformers.testing_utils import get_tests_dir
 
-from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
+from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
 from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
 
 
@@ -61,6 +62,13 @@ class ModelTests(unittest.TestCase):
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         model(messages, stop_sequences=["great"])
 
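+    # MLX requires Apple silicon, so this test only runs on macOS.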
+    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
+    def test_get_mlx_message_no_tool(self):
+        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output.startswith("Hello")
+
     def test_transformers_message_no_tool(self):
         model = TransformersModel(
             model_id="HuggingFaceTB/SmolLM2-135M-Instruct",