Fix litellm flatten_messages_as_text detection
This commit is contained in:
parent
d33bc2dd9e
commit
d34e0c81d9
|
@ -22,7 +22,6 @@ import uuid
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
@ -857,16 +856,11 @@ class LiteLLMModel(Model):
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.custom_role_conversions = custom_role_conversions
|
self.custom_role_conversions = custom_role_conversions
|
||||||
|
self.flatten_messages_as_text = (
|
||||||
@cached_property
|
kwargs.get("flatten_messages_as_text")
|
||||||
def _flatten_messages_as_text(self):
|
if "flatten_messages_as_text" in kwargs
|
||||||
import litellm
|
else self.model_id.startswith(("ollama", "groq", "cerebras"))
|
||||||
|
)
|
||||||
model_info: dict = litellm.get_model_info(self.model_id)
|
|
||||||
if model_info["litellm_provider"] == "ollama":
|
|
||||||
return model_info["key"] != "llava"
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -887,7 +881,7 @@ class LiteLLMModel(Model):
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
convert_images_to_image_urls=True,
|
convert_images_to_image_urls=True,
|
||||||
flatten_messages_as_text=self._flatten_messages_as_text,
|
flatten_messages_as_text=self.flatten_messages_as_text,
|
||||||
custom_role_conversions=self.custom_role_conversions,
|
custom_role_conversions=self.custom_role_conversions,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,7 +23,7 @@ from unittest.mock import MagicMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents import ChatMessage, HfApiModel, MLXModel, TransformersModel, models, tool
|
from smolagents import ChatMessage, HfApiModel, LiteLLMModel, MLXModel, TransformersModel, models, tool
|
||||||
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
|
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,18 +50,6 @@ class ModelTests(unittest.TestCase):
|
||||||
data = json.loads(message.model_dump_json())
|
data = json.loads(message.model_dump_json())
|
||||||
assert data["content"] == [{"type": "text", "text": "Hello!"}]
|
assert data["content"] == [{"type": "text", "text": "Hello!"}]
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
|
||||||
def test_get_hfapi_message_no_tool(self):
|
|
||||||
model = HfApiModel(max_tokens=10)
|
|
||||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
|
||||||
model(messages, stop_sequences=["great"])
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
|
||||||
def test_get_hfapi_message_no_tool_external_provider(self):
|
|
||||||
model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
|
||||||
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
|
||||||
model(messages, stop_sequences=["great"])
|
|
||||||
|
|
||||||
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
@unittest.skipUnless(sys.platform.startswith("darwin"), "requires macOS")
|
||||||
def test_get_mlx_message_no_tool(self):
|
def test_get_mlx_message_no_tool(self):
|
||||||
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
|
model = MLXModel(model_id="HuggingFaceTB/SmolLM2-135M-Instruct", max_tokens=10)
|
||||||
|
@ -142,6 +130,43 @@ class TestHfApiModel:
|
||||||
"role conversion should be applied"
|
"role conversion should be applied"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||||
|
def test_get_hfapi_message_no_tool(self):
|
||||||
|
model = HfApiModel(max_tokens=10)
|
||||||
|
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||||
|
model(messages, stop_sequences=["great"])
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not os.getenv("RUN_ALL"), reason="RUN_ALL environment variable not set")
|
||||||
|
def test_get_hfapi_message_no_tool_external_provider(self):
|
||||||
|
model = HfApiModel(model="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together", max_tokens=10)
|
||||||
|
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
|
||||||
|
model(messages, stop_sequences=["great"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestLiteLLMModel:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_id, error_flag",
|
||||||
|
[
|
||||||
|
("groq/llama-3.3-70b", "Missing API Key"),
|
||||||
|
("cerebras/llama-3.3-70b", "Wrong API Key"),
|
||||||
|
("ollama/llama2", "not found"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_call_different_providers_without_key(self, model_id, error_flag):
|
||||||
|
model = LiteLLMModel(model_id=model_id)
|
||||||
|
messages = [{"role": "user", "content": [{"type": "text", "text": "Test message"}]}]
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
# This should raise 401 error because of missing API key, not fail for any "bad format" reason
|
||||||
|
model(messages)
|
||||||
|
assert error_flag in str(e)
|
||||||
|
|
||||||
|
def test_passing_flatten_messages(self):
|
||||||
|
model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
|
||||||
|
assert not model.flatten_messages_as_text
|
||||||
|
|
||||||
|
model = LiteLLMModel(model_id="fal/llama-3.3-70b", flatten_messages_as_text=True)
|
||||||
|
assert model.flatten_messages_as_text
|
||||||
|
|
||||||
|
|
||||||
def test_get_clean_message_list_basic():
|
def test_get_clean_message_list_basic():
|
||||||
messages = [
|
messages = [
|
||||||
|
|
Loading…
Reference in New Issue