Fix litellm flatten_messages_as_text detection

This commit is contained in:
Aymeric — 2025-02-15 11:17:16 +01:00
parent d33bc2dd9e
commit d34e0c81d9
2 changed files with 44 additions and 25 deletions

View File

@ -22,7 +22,6 @@ import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient
@ -857,16 +856,11 @@ class LiteLLMModel(Model):
self.api_base = api_base
self.api_key = api_key
self.custom_role_conversions = custom_role_conversions
@cached_property
def _flatten_messages_as_text(self):
import litellm
model_info: dict = litellm.get_model_info(self.model_id)
if model_info["litellm_provider"] == "ollama":
return model_info["key"] != "llava"
return False
self.flatten_messages_as_text = (
kwargs.get("flatten_messages_as_text")
if "flatten_messages_as_text" in kwargs
else self.model_id.startswith(("ollama", "groq", "cerebras"))
)
def __call__(
self,
@ -887,7 +881,7 @@ class LiteLLMModel(Model):
api_base=self.api_base,
api_key=self.api_key,
convert_images_to_image_urls=True,
flatten_messages_as_text=self._flatten_messages_as_text,
flatten_messages_as_text=self.flatten_messages_as_text,
custom_role_conversions=self.custom_role_conversions,
**kwargs,
)

View File

@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch
import pytest
from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
from smolagents import ChatMessage, HfApiModel, LiteLLMModel, MLXModel, TransformersModel, models, tool
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
@ -50,18 +50,6 @@ class ModelTests(unittest.TestCase):
data = json.loads(message.model_dump_json())
assert data["content"] == [{"type": "text", "text": "Hello!"}]
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool(self):
    """Smoke-test a plain, tool-free text completion against the HF Inference API."""
    hf_model = HfApiModel(max_tokens=10)
    user_turn = {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
    hf_model([user_turn], stop_sequences=["great"])
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool_external_provider(self):
    """Smoke-test a tool-free call routed through an external provider (together)."""
    hf_model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
    user_turn = {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
    hf_model([user_turn], stop_sequences=["great"])
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
def test_get_mlx_message_no_tool(self):
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
@ -142,6 +130,43 @@ class TestHfApiModel:
"role conversion should be applied"
)
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool(self):
    """Live smoke test: a bare text prompt (no tools) against the default HF API model."""
    api_model = HfApiModel(max_tokens=10)
    prompt = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    api_model(prompt, stop_sequences=["great"])
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
def test_get_hfapi_message_no_tool_external_provider(self):
    """Live smoke test: same tool-free prompt, but served via the 'together' provider."""
    api_model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
    prompt = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    api_model(prompt, stop_sequences=["great"])
class TestLiteLLMModel:
    """Tests for LiteLLMModel construction and error propagation across providers."""

    @pytest.mark.parametrize(
        "model_id, error_flag",
        [
            ("groq/llama-3.3-70b", "Missing API Key"),
            ("cerebras/llama-3.3-70b", "Wrong API Key"),
            ("ollama/llama2", "not found"),
        ],
    )
    def test_call_different_providers_without_key(self, model_id, error_flag):
        """Calling without credentials must surface an auth/availability error,
        not a message-format error (i.e. messages were accepted as well-formed)."""
        model = LiteLLMModel(model_id=model_id)
        messages = [{"role": "user", "content": [{"type": "text", "text": "Test message"}]}]
        with pytest.raises(Exception) as e:
            # This should raise 401 error because of missing API key, not fail for any "bad format" reason
            model(messages)
        # Fix: match against the raised exception itself (e.value); str() on the
        # pytest ExceptionInfo wrapper is not guaranteed to contain the message.
        assert error_flag in str(e.value)

    def test_passing_flatten_messages(self):
        """An explicit flatten_messages_as_text kwarg overrides the provider-prefix default."""
        model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
        assert not model.flatten_messages_as_text
        model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
        assert model.flatten_messages_as_text
def test_get_clean_message_list_basic():
messages = [