Add standard call to LiteLLM engine

This commit is contained in:
Aymeric 2024-12-24 19:55:34 +01:00
parent 1e357cee7f
commit 162d4dc362
2 changed files with 30 additions and 7 deletions

View File

@ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo
# llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620") # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620")
# llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct")
# llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct") # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct")
llm_engine = LiteLLMEngine() llm_engine = LiteLLMEngine("gpt-4o")
@tool @tool
def get_weather(location: str) -> str: def get_weather(location: str) -> str:

View File

@ -19,6 +19,7 @@ from enum import Enum
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import litellm
import logging import logging
import os import os
import random import random
@ -566,10 +567,31 @@ class AnthropicEngine:
class LiteLLMEngine(): class LiteLLMEngine():
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"): def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
self.model_id = model_id self.model_id = model_id
import os, litellm
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
litellm.add_function_to_prompt = True litellm.add_function_to_prompt = True
self.last_input_token_count = 0
self.last_output_token_count = 0
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> str:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
response = litellm.completion(
model=self.model_id,
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return response.choices[0].message.content
def get_tool_call( def get_tool_call(
self, self,
@ -578,19 +600,20 @@ class LiteLLMEngine():
stop_sequences: Optional[List[str]] = None, stop_sequences: Optional[List[str]] = None,
max_tokens: int = 1500, max_tokens: int = 1500,
): ):
from litellm import completion
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )
response = completion( response = litellm.completion(
model=self.model_id, model=self.model_id,
messages=messages, messages=messages,
tools=[get_json_schema(tool) for tool in available_tools], tools=[get_json_schema(tool) for tool in available_tools],
tool_choice="required", tool_choice="required",
max_tokens=max_tokens,
stop=stop_sequences, stop=stop_sequences,
max_tokens=max_tokens,
) )
tool_calls = response.choices[0].message.tool_calls[0] tool_calls = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id