LiteLLMModel - detect message flatenning based on model information (#553)
This commit is contained in:
parent
41a388dac6
commit
392fc5ade5
|
@ -22,6 +22,7 @@ import uuid
|
|||
from copy import deepcopy
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
@ -799,6 +800,16 @@ class LiteLLMModel(Model):
|
|||
self.api_key = api_key
|
||||
self.custom_role_conversions = custom_role_conversions
|
||||
|
||||
@cached_property
|
||||
def _flatten_messages_as_text(self):
|
||||
import litellm
|
||||
|
||||
model_info: dict = litellm.get_model_info(self.model_id)
|
||||
if model_info["litellm_provider"] == "ollama":
|
||||
return model_info["key"] != "llava"
|
||||
|
||||
return False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
@ -818,7 +829,7 @@ class LiteLLMModel(Model):
|
|||
api_base=self.api_base,
|
||||
api_key=self.api_key,
|
||||
convert_images_to_image_urls=True,
|
||||
flatten_messages_as_text=self.model_id.startswith("ollama"),
|
||||
flatten_messages_as_text=self._flatten_messages_as_text,
|
||||
custom_role_conversions=self.custom_role_conversions,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue