MLX model support (#300)
This commit is contained in:
parent bca3a9bc13
commit 9b96199d00
@@ -28,10 +28,11 @@ To initialize a minimal agent, you need at least these two arguments:
     - [`HfApiModel`] leverages a `huggingface_hub.InferenceClient` under the hood and supports all Inference Providers on the Hub.
     - [`LiteLLMModel`] similarly lets you call 100+ different models and providers through [LiteLLM](https://docs.litellm.ai/)!
     - [`AzureOpenAIServerModel`] allows you to use OpenAI models deployed in [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+    - [`MLXModel`] creates a [mlx-lm](https://pypi.org/project/mlx-lm/) pipeline to run inference on your local machine.
 
 - `tools`, a list of `Tools` that the agent can use to solve the task. It can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
 
-Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), or [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service).
+Once you have these two arguments, `tools` and `model`, you can create an agent and run it. You can use any LLM you'd like, either through [Inference Providers](https://huggingface.co/blog/inference-providers), [transformers](https://github.com/huggingface/transformers/), [ollama](https://ollama.com/), [LiteLLM](https://www.litellm.ai/), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), or [mlx-lm](https://pypi.org/project/mlx-lm/).
 
 <hfoptions id="Pick a LLM">
 <hfoption id="HF Inference API">
@@ -148,6 +149,19 @@ agent.run(
 )
 ```
 
 </hfoption>
+<hfoption id="mlx-lm">
+
+```python
+# !pip install smolagents[mlx-lm]
+from smolagents import CodeAgent, MLXModel
+
+mlx_model = MLXModel("mlx-community/Qwen2.5-Coder-32B-Instruct-4bit")
+agent = CodeAgent(model=mlx_model, tools=[], add_base_tools=True)
+
+agent.run("Could you give me the 118th number in the Fibonacci sequence?")
+```
+
+</hfoption>
 </hfoptions>
 
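A note on this example: any chat model converted for MLX on the Hub should work as the model ID, but the 4-bit 32B Qwen2.5-Coder checkpoint used above weighs roughly 18 GB on disk, so it needs an Apple-silicon machine with correspondingly generous unified memory; a smaller mlx-community checkpoint (for example a 7B 4-bit variant) is a lighter starting point.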
@@ -147,4 +147,23 @@ model = AzureOpenAIServerModel(
 )
 ```
 
-[[autodoc]] AzureOpenAIServerModel
\ No newline at end of file
+[[autodoc]] AzureOpenAIServerModel
+
+### MLXModel
+
+```python
+from smolagents import MLXModel
+
+model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
+
+print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
+```
+```text
+>>> What a
+```
+
+> [!TIP]
+> You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not already installed.
+
+[[autodoc]] MLXModel
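As the snippet above shows, generation halts once the output ends with the stop sequence `"great"`, which is then trimmed from the returned content, hence the truncated `What a`. Any extra keyword arguments are forwarded to mlx-lm's generation call, so here is a minimal sketch for capping output length (same model ID as above; the exact reply will vary):

```python
from smolagents import MLXModel

# Extra kwargs such as max_tokens are stored at init time and forwarded
# to mlx-lm's generation call on every invocation.
model = MLXModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct", max_tokens=64)

message = model([{"role": "user", "content": "Say hi in five words."}])
print(message.content)  # at most 64 newly generated tokens
```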
@@ -46,6 +46,9 @@ mcp = [
   "mcpadapt>=0.0.6",
   "mcp",
 ]
+mlx-lm = [
+  "mlx-lm"
+]
 openai = [
   "openai>=1.58.1"
 ]
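One installation note: since `[` and `]` are glob characters in zsh (the default shell on macOS, where this extra is most relevant), the quoted form used in `MLXModel`'s error message, `pip install 'smolagents[mlx-lm]'`, is the safer spelling.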
@@ -18,6 +18,7 @@ import json
 import logging
 import os
 import random
+import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -415,6 +416,128 @@ class HfApiModel(Model):
         return message
 
 
+class MLXModel(Model):
+    """A class to interact with models loaded using MLX on Apple silicon.
+
+    > [!TIP]
+    > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not already installed.
+
+    Parameters:
+        model_id (str):
+            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+        tool_name_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving a tool name.
+        tool_arguments_key (str):
+            The key, which can usually be found in the model's chat template, for retrieving tool arguments.
+        trust_remote_code (bool):
+            Some models on the Hub require running remote code: for such models, you must set this flag to `True`.
+        kwargs (dict, *optional*):
+            Any additional keyword arguments that you want to pass to the generation call, for instance `max_tokens`.
+
+    Example:
+    ```python
+    >>> engine = MLXModel(
+    ...     model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
+    ...     max_tokens=10000,
+    ... )
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "text", "text": "Explain quantum mechanics in simple terms."}
+    ...         ]
+    ...     }
+    ... ]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        tool_name_key: str = "name",
+        tool_arguments_key: str = "arguments",
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not _is_package_available("mlx_lm"):
+            raise ModuleNotFoundError(
+                "Please install 'mlx-lm' extra to use 'MLXModel': `pip install 'smolagents[mlx-lm]'`"
+            )
+        import mlx_lm
+
+        self.model_id = model_id
+        self.model, self.tokenizer = mlx_lm.load(model_id, tokenizer_config={"trust_remote_code": trust_remote_code})
+        self.stream_generate = mlx_lm.stream_generate
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
+
+    def _to_message(self, text, tools_to_call_from):
+        if tools_to_call_from:
+            # Temporary solution: extract the outermost JSON object (from the first "{"
+            # to the last "}") without assuming a specific model output format.
+            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
+            parsed_text = json.loads(maybe_json)
+            tool_name = parsed_text.get(self.tool_name_key, None)
+            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
+            if tool_name:
+                return ChatMessage(
+                    role="assistant",
+                    content="",
+                    tool_calls=[
+                        ChatMessageToolCall(
+                            id=str(uuid.uuid4()),
+                            type="function",
+                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                        )
+                    ],
+                )
+        return ChatMessage(role="assistant", content=text)
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        completion_kwargs = self._prepare_completion_kwargs(
+            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            messages=messages,
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            add_generation_prompt=True,
+        )
+
+        self.last_input_token_count = len(prompt_ids)
+        self.last_output_token_count = 0
+        text = ""
+
+        for chunk in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs):
+            self.last_output_token_count += 1
+            text += chunk.text
+            stripped = text.rstrip()
+            for stop_sequence in prepared_stop_sequences:
+                if stripped.endswith(stop_sequence):
+                    # Trim the stop sequence (and any trailing whitespace) from the returned text.
+                    text = stripped[: -len(stop_sequence)]
+                    return self._to_message(text, tools_to_call_from)
+
+        return self._to_message(text, tools_to_call_from)
+
+
 class TransformersModel(Model):
     """A class that uses Hugging Face's Transformers library for language model interaction.
 
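An aside on the implementation: the double string reversal in `_to_message` is just a compact way of slicing from the first `{` to the last `}` of the generated text. A roughly equivalent sketch, using a hypothetical model output string:

```python
import json

# Hypothetical model output that wraps a tool call in prose:
text = 'Sure, calling the tool now: {"name": "final_answer", "arguments": {"answer": "42"}} Done!'

# Slice from the first "{" to the last "}", like MLXModel._to_message does:
maybe_json = text[text.find("{") : text.rfind("}") + 1]

parsed = json.loads(maybe_json)
assert parsed["name"] == "final_answer"
assert parsed["arguments"] == {"answer": "42"}
```

If the model emits no valid JSON while tools are expected, `json.loads` raises, so callers see an explicit error rather than a silently missed tool call.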
@@ -837,6 +960,7 @@ __all__ = [
     "tool_role_conversions",
     "get_clean_message_list",
     "Model",
+    "MLXModel",
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
 import os
+import sys
 import unittest
 from pathlib import Path
 from typing import Optional
@@ -22,7 +23,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from transformers.testing_utils import get_tests_dir
 
-from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
+from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
 from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
 
 
@@ -61,6 +62,13 @@ class ModelTests(unittest.TestCase):
         messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
         model(messages, stop_sequences=["great"])
 
+    @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
+    def test_get_mlx_message_no_tool(self):
+        model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        output = model(messages, stop_sequences=["great"]).content
+        assert output.startswith("Hello")
+
     def test_transformers_message_no_tool(self):
         model = TransformersModel(
             model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
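Assuming the test file shown here is `tests/test_models.py` (its path is not visible in this diff), the new test can be run selectively on an Apple-silicon machine with `pytest -k mlx`; on other platforms the `skipUnless` guard reports it as skipped.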