From d34e0c81d90134c84f2743fea2e226f697a855ba Mon Sep 17 00:00:00 2001
From: Aymeric
Date: Sat, 15 Feb 2025 11:17:16 +0100
Subject: [PATCH] Fix litellm flatten_messages_as_text detection

---
 src/smolagents/models.py | 18 +++++---------
 tests/test_models.py     | 51 ++++++++++++++++++++++++++++++----------
 2 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 172757e..cb825b4 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -22,7 +22,6 @@ import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from huggingface_hub import InferenceClient
@@ -857,16 +856,11 @@ class LiteLLMModel(Model):
         self.api_base = api_base
         self.api_key = api_key
         self.custom_role_conversions = custom_role_conversions
-
-    @cached_property
-    def _flatten_messages_as_text(self):
-        import litellm
-
-        model_info: dict = litellm.get_model_info(self.model_id)
-        if model_info["litellm_provider"] == "ollama":
-            return model_info["key"] != "llava"
-
-        return False
+        self.flatten_messages_as_text = (
+            kwargs.get("flatten_messages_as_text")
+            if "flatten_messages_as_text" in kwargs
+            else self.model_id.startswith(("ollama", "groq", "cerebras"))
+        )
 
     def __call__(
         self,
@@ -887,7 +881,7 @@
             api_base=self.api_base,
             api_key=self.api_key,
             convert_images_to_image_urls=True,
-            flatten_messages_as_text=self._flatten_messages_as_text,
+            flatten_messages_as_text=self.flatten_messages_as_text,
             custom_role_conversions=self.custom_role_conversions,
             **kwargs,
         )
diff --git a/tests/test_models.py b/tests/test_models.py
index 8c83a9f..a77c956 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from transformers.testing_utils import get_tests_dir
 
-from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
+from smolagents import ChatMessage, HfApiModel, LiteLLMModel, MLXModel, TransformersModel, models, tool
 from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
 
 
@@ -50,18 +50,6 @@ class ModelTests(unittest.TestCase):
         data = json.loads(message.model_dump_json())
         assert data["content"] == [{"type": "text", "text": "Hello!"}]
 
-    @pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
-    def test_get_hfapi_message_no_tool(self):
-        model = HfApiModel(max_tokens=10)
-        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
-        model(messages, stop_sequences=["great"])
-
-    @pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
-    def test_get_hfapi_message_no_tool_external_provider(self):
-        model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
-        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
-        model(messages, stop_sequences=["great"])
-
     @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
     def test_get_mlx_message_no_tool(self):
         model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
@@ -142,6 +130,43 @@ class TestHfApiModel:
             "role conversion should be applied"
         )
 
+    @pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
+    def test_get_hfapi_message_no_tool(self):
+        model = HfApiModel(max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        model(messages, stop_sequences=["great"])
+
+    @pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
+    def test_get_hfapi_message_no_tool_external_provider(self):
+        model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
+        model(messages, stop_sequences=["great"])
+
+
+class TestLiteLLMModel:
+    @pytest.mark.parametrize(
+        "model_id, error_flag",
+        [
+            ("groq/llama-3.3-70b", "Missing API Key"),
+            ("cerebras/llama-3.3-70b", "Wrong API Key"),
+            ("ollama/llama2", "not found"),
+        ],
+    )
+    def test_call_different_providers_without_key(self, model_id, error_flag):
+        model = LiteLLMModel(model_id=model_id)
+        messages = [{"role": "user", "content": [{"type": "text", "text": "Test message"}]}]
+        with pytest.raises(Exception) as e:
+            # This should raise 401 error because of missing API key, not fail for any "bad format" reason
+            model(messages)
+        assert error_flag in str(e)
+
+    def test_passing_flatten_messages(self):
+        model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
+        assert not model.flatten_messages_as_text
+
+        model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
+        assert model.flatten_messages_as_text
+
 
 def test_get_clean_message_list_basic():
     messages = [