Add standard call to LiteLLM engine
parent 1e357cee7f
commit 162d4dc362
@@ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo
 # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
 # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
 # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
-llm_engine = LiteLLMEngine()
+llm_engine = LiteLLMEngine("gpt-4o")
 
 @tool
 def get_weather(location: str) -> str:
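For context: LiteLLM routes requests by model string, so the one-line swap above is the whole provider change. A minimal sketch, assuming LiteLLMEngine is exported from smolagents like the other engines and that the matching API key is set in the environment:

import os
from smolagents import LiteLLMEngine

# LiteLLM infers the provider from the model id; only the string changes.
assert "OPENAI_API_KEY" in os.environ  # "gpt-4o" routes to OpenAI
llm_engine = LiteLLMEngine("gpt-4o")
# llm_engine = LiteLLMEngine("anthropic/claude-3-5-sonnet-20240620")  # would need ANTHROPIC_API_KEY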
@@ -19,6 +19,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Tuple
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 
+import litellm
 import logging
 import os
 import random
@@ -566,11 +567,32 @@ class AnthropicEngine:
 class LiteLLMEngine():
     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
         self.model_id = model_id
-        import os, litellm
 
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
+        self.last_input_token_count = 0
+        self.last_output_token_count = 0
 
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+        response = litellm.completion(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+        )
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
     def get_tool_call(
         self,
         messages: List[Dict[str, str]],
@@ -578,19 +600,20 @@ class LiteLLMEngine():
         stop_sequences: Optional[List[str]] = None,
         max_tokens: int = 1500,
     ):
-        from litellm import completion
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
-        response = completion(
+        response = litellm.completion(
             model=self.model_id,
             messages=messages,
             tools=[get_json_schema(tool) for tool in available_tools],
             tool_choice="required",
-            max_tokens=max_tokens,
             stop=stop_sequences,
+            max_tokens=max_tokens,
         )
         tool_calls = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id
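For reference, a minimal usage sketch of the new standard call path added above; the engine instance and message contents are illustrative, not part of this commit:

engine = LiteLLMEngine("gpt-4o")
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Say hello in one word."},
]
# __call__ cleans the message list, calls litellm.completion, and records usage.
reply = engine(messages, stop_sequences=["Observation:"], max_tokens=64)
print(reply)
print(engine.last_input_token_count, engine.last_output_token_count)  # set as a side effect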
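Similarly, a sketch of consuming get_tool_call's return value. litellm follows the OpenAI response format, where function.arguments is a JSON-encoded string, so the caller is assumed to decode it before dispatching; the available_tools keyword and the json.loads step below are assumptions about the caller, not part of this diff:

import json

name, arguments, call_id = engine.get_tool_call(
    messages, available_tools=[get_weather]  # assumes available_tools is accepted as used in the body
)
args = json.loads(arguments)   # e.g. '{"location": "Paris"}' -> {"location": "Paris"}
result = get_weather(**args)   # dispatch to the matching tool by hand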