Fix litellm flatten_messages_as_text detection

This commit is contained in:
Aymeric 2025-02-15 11:17:16 +01:00
parent d33bc2dd9e
commit d34e0c81d9
2 changed files with 44 additions and 25 deletions

View File

@ -22,7 +22,6 @@ import uuid
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
@ -857,16 +856,11 @@ class LiteLLMModel(Model):
self.api_base = api_base self.api_base = api_base
self.api_key = api_key self.api_key = api_key
self.custom_role_conversions = custom_role_conversions self.custom_role_conversions = custom_role_conversions
self.flatten_messages_as_text = (
@cached_property kwargs.get("flatten_messages_as_text")
def _flatten_messages_as_text(self): if "flatten_messages_as_text" in kwargs
import litellm else self.model_id.startswith(("ollama", "groq", "cerebras"))
)
model_info: dict = litellm.get_model_info(self.model_id)
if model_info["litellm_provider"] == "ollama":
return model_info["key"] != "llava"
return False
def __call__( def __call__(
self, self,
@ -887,7 +881,7 @@ class LiteLLMModel(Model):
api_base=self.api_base, api_base=self.api_base,
api_key=self.api_key, api_key=self.api_key,
convert_images_to_image_urls=True, convert_images_to_image_urls=True,
flatten_messages_as_text=self._flatten_messages_as_text, flatten_messages_as_text=self.flatten_messages_as_text,
custom_role_conversions=self.custom_role_conversions, custom_role_conversions=self.custom_role_conversions,
**kwargs, **kwargs,
) )

View File

@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool from smolagents import ChatMessage, HfApiModel, LiteLLMModel, MLXModel, TransformersModel, models, tool
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
@ -50,18 +50,6 @@ class ModelTests(unittest.TestCase):
data = json.loads(message.model_dump_json()) data = json.loads(message.model_dump_json())
assert data["content"] == [{"type": "text", "text": "Hello!"}] assert data["content"] == [{"type": "text", "text": "Hello!"}]
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool(self):
model = HfApiModel(max_tokens=10)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
model(messages, stop_sequences=["great"])
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool_external_provider(self):
model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
model(messages, stop_sequences=["great"])
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS") @unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_no_tool(self): def test_get_mlx_message_no_tool(self):
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10) model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
@ -142,6 +130,43 @@ class TestHfApiModel:
"role conversion should be applied" "role conversion should be applied"
) )
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool(self):
    """Smoke-test a tool-free HfApiModel call against the live default endpoint."""
    model = HfApiModel(max_tokens=10)
    chat = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    # Only checks that the call round-trips without raising; output is not asserted.
    model(chat, stop_sequences=["great"])
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool_external_provider(self):
    """Smoke-test a tool-free HfApiModel call routed through an external provider."""
    model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
    chat = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    # Only checks that the call round-trips without raising; output is not asserted.
    model(chat, stop_sequences=["great"])
class TestLiteLLMModel:
    """Tests for LiteLLMModel provider dispatch and flatten_messages_as_text handling."""

    @pytest.mark.parametrize(
        "model_id, error_flag",
        [
            ("groq/llama-3.3-70b", "Missing API Key"),
            ("cerebras/llama-3.3-70b", "Wrong API Key"),
            ("ollama/llama2", "not found"),
        ],
    )
    def test_call_different_providers_without_key(self, model_id, error_flag):
        """Without credentials, the call must fail on the provider side (auth /
        model availability), proving the message payload itself was well-formed."""
        model = LiteLLMModel(model_id=model_id)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Test message"}]}]
        with pytest.raises(Exception) as e:
            # Should raise a provider error (missing key / model not found),
            # not fail for any "bad format" reason.
            model(messages)
        # Fix: inspect the raised exception (`e.value`), not the ExceptionInfo
        # wrapper — str(ExceptionInfo) is pytest-internal and not guaranteed to
        # contain the exception message.
        assert error_flag in str(e.value)

    def test_passing_flatten_messages(self):
        """An explicit flatten_messages_as_text kwarg overrides the model-id heuristic."""
        model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
        assert not model.flatten_messages_as_text
        model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
        assert model.flatten_messages_as_text
def test_get_clean_message_list_basic(): def test_get_clean_message_list_basic():
messages = [ messages = [