Fix litellm flatten_messages_as_text detection
This commit is contained in:
parent
d33bc2dd9e
commit
d34e0c81d9
|
@ -22,7 +22,6 @@ import uuid
|
|||
from copy import deepcopy
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
@ -857,16 +856,11 @@ class LiteLLMModel(Model):
|
|||
self.api_base = api_base
|
||||
self.api_key = api_key
|
||||
self.custom_role_conversions = custom_role_conversions
|
||||
|
||||
@cached_property
|
||||
def _flatten_messages_as_text(self):
|
||||
import litellm
|
||||
|
||||
model_info: dict = litellm.get_model_info(self.model_id)
|
||||
if model_info["litellm_provider"] == "ollama":
|
||||
return model_info["key"] != "llava"
|
||||
|
||||
return False
|
||||
self.flatten_messages_as_text = (
|
||||
kwargs.get("flatten_messages_as_text")
|
||||
if "flatten_messages_as_text" in kwargs
|
||||
else self.model_id.startswith(("ollama", "groq", "cerebras"))
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -887,7 +881,7 @@ class LiteLLMModel(Model):
|
|||
api_base=self.api_base,
|
||||
api_key=self.api_key,
|
||||
convert_images_to_image_urls=True,
|
||||
flatten_messages_as_text=self._flatten_messages_as_text,
|
||||
flatten_messages_as_text=self.flatten_messages_as_text,
|
||||
custom_role_conversions=self.custom_role_conversions,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
|
||||
from smolagents import ChatMessage, HfApiModel, LiteLLMModel, MLXModel, TransformersModel, models, tool
|
||||
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
|
||||
|
||||
|
||||
|
@ -50,18 +50,6 @@ class ModelTests(unittest.TestCase):
|
|||
data = json.loads(message.model_dump_json())
|
||||
assert data["content"] == [{"type": "text", "text": "Hello!"}]
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||
def test_get_hfapi_message_no_tool(self):
|
||||
model = HfApiModel(max_tokens=10)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||
model(messages, stop_sequences=["great"])
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||
def test_get_hfapi_message_no_tool_external_provider(self):
|
||||
model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||
model(messages, stop_sequences=["great"])
|
||||
|
||||
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
||||
def test_get_mlx_message_no_tool(self):
|
||||
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
|
||||
|
@ -142,6 +130,43 @@ class TestHfApiModel:
|
|||
"role conversion should be applied"
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||
def test_get_hfapi_message_no_tool(self):
|
||||
model = HfApiModel(max_tokens=10)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||
model(messages, stop_sequences=["great"])
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||
def test_get_hfapi_message_no_tool_external_provider(self):
|
||||
model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||
model(messages, stop_sequences=["great"])
|
||||
|
||||
|
||||
class TestLiteLLMModel:
|
||||
@pytest.mark.parametrize(
|
||||
"model_id, error_flag",
|
||||
[
|
||||
("groq/llama-3.3-70b", "Missing API Key"),
|
||||
("cerebras/llama-3.3-70b", "Wrong API Key"),
|
||||
("ollama/llama2", "not found"),
|
||||
],
|
||||
)
|
||||
def test_call_different_providers_without_key(self, model_id, error_flag):
|
||||
model = LiteLLMModel(model_id=model_id)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Test message"}]}]
|
||||
with pytest.raises(Exception) as e:
|
||||
# This should raise 401 error because of missing API key, not fail for any "bad format" reason
|
||||
model(messages)
|
||||
assert error_flag in str(e)
|
||||
|
||||
def test_passing_flatten_messages(self):
|
||||
model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
|
||||
assert not model.flatten_messages_as_text
|
||||
|
||||
model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
|
||||
assert model.flatten_messages_as_text
|
||||
|
||||
|
||||
def test_get_clean_message_list_basic():
|
||||
messages = [
|
||||
|
|
Loading…
Reference in New Issue