From b3dfaddd434a4610d0ffa6d18e9ca0f96fb2e170 Mon Sep 17 00:00:00 2001 From: Robert Haase Date: Mon, 30 Dec 2024 12:41:01 +0100 Subject: [PATCH] make compatible with remote openai-compatible servers --- src/smolagents/models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 13e0f52..deeb7fc 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -410,11 +410,13 @@ class TransformersModel(Model): class LiteLLMModel(Model): - def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"): + def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None): super().__init__() self.model_id = model_id # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs litellm.add_function_to_prompt = True + self.api_base = api_base + self.api_key = api_key def __call__( self, @@ -432,6 +434,8 @@ class LiteLLMModel(Model): messages=messages, stop=stop_sequences, max_tokens=max_tokens, + api_base=self.api_base, + api_key=self.api_key, ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens @@ -454,6 +458,8 @@ class LiteLLMModel(Model): tool_choice="required", stop=stop_sequences, max_tokens=max_tokens, + api_base=self.api_base, + api_key=self.api_key, ) tool_calls = response.choices[0].message.tool_calls[0] self.last_input_token_count = response.usage.prompt_tokens