Merge pull request #10 from haesleinhuepf/support_remote_llm_servers

Support remote llm servers in LiteLLMModel
This commit is contained in:
Aymeric Roucher 2024-12-30 15:14:21 +01:00 committed by GitHub
commit a50f9284b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 1 deletions

View File

@ -0,0 +1,23 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, LiteLLMModel
from typing import Optional
model = LiteLLMModel(model_id="openai/llama3.2",
api_base="http://localhost:11434/v1", # replace with remote open-ai compatible server if necessary
api_key="your-api-key") # replace with API key if necessary
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
agent = ToolCallingAgent(tools=[get_weather], model=model)
print(agent.run("What's the weather like in Paris?"))

View File

@ -410,11 +410,13 @@ class TransformersModel(Model):
class LiteLLMModel(Model):
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None):
super().__init__()
self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
litellm.add_function_to_prompt = True
self.api_base = api_base
self.api_key = api_key
def __call__(
self,
@ -432,6 +434,8 @@ class LiteLLMModel(Model):
messages=messages,
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
@ -454,6 +458,8 @@ class LiteLLMModel(Model):
tool_choice="required",
stop=stop_sequences,
max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
)
tool_calls = response.choices[0].message.tool_calls[0]
self.last_input_token_count = response.usage.prompt_tokens