Merge pull request #10 from haesleinhuepf/support_remote_llm_servers

Support remote llm servers in LiteLLMModel

commit a50f9284b3
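The change adds optional `api_base` and `api_key` arguments to `LiteLLMModel` and forwards them to `litellm.completion`, so an agent can talk to any OpenAI-compatible server (for example, a local Ollama instance) instead of the default provider endpoints. A new example script demonstrates the setup.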
					
@@ -0,0 +1,23 @@
+from smolagents.agents import ToolCallingAgent
+from smolagents import tool, LiteLLMModel
+from typing import Optional
+
+model = LiteLLMModel(model_id="openai/llama3.2",
+                     api_base="http://localhost:11434/v1",  # replace with a remote OpenAI-compatible server if necessary
+                     api_key="your-api-key")                # replace with your API key if necessary
+
+@tool
+def get_weather(location: str, celsius: Optional[bool] = False) -> str:
+    """
+    Get the weather for the coming days at the given location.
+    Secretly, this tool does not care about the location; it hates the weather everywhere.
+
+    Args:
+        location: the location
+        celsius: whether to return the temperature in Celsius
+    """
+    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+agent = ToolCallingAgent(tools=[get_weather], model=model)
+
+print(agent.run("What's the weather like in Paris?"))
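In the example, the `openai/` prefix in `model_id` routes the request through LiteLLM's OpenAI-compatible provider, and `http://localhost:11434/v1` is Ollama's default OpenAI-compatible endpoint, so the same two parameters cover both local and hosted servers.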
@@ -410,11 +410,13 @@ class TransformersModel(Model):


 class LiteLLMModel(Model):
-    def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"):
+    def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None):
         super().__init__()
         self.model_id = model_id
         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
         litellm.add_function_to_prompt = True
+        self.api_base = api_base
+        self.api_key = api_key

     def __call__(
         self,
@@ -432,6 +434,8 @@ class LiteLLMModel(Model):
             messages=messages,
             stop=stop_sequences,
             max_tokens=max_tokens,
+            api_base=self.api_base,
+            api_key=self.api_key,
         )
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
@@ -454,6 +458,8 @@ class LiteLLMModel(Model):
             tool_choice="required",
             stop=stop_sequences,
             max_tokens=max_tokens,
+            api_base=self.api_base,
+            api_key=self.api_key,
         )
         tool_calls = response.choices[0].message.tool_calls[0]
         self.last_input_token_count = response.usage.prompt_tokens
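With the new parameters, pointing an agent at a remote server is only a constructor change. A minimal sketch, assuming a hypothetical remote OpenAI-compatible endpoint and an API key supplied via an environment variable (both placeholders, not part of this commit):

import os

from smolagents import LiteLLMModel

# Hypothetical endpoint; replace with your server's OpenAI-compatible URL.
model = LiteLLMModel(
    model_id="openai/llama3.2",
    api_base="https://llm.example.com/v1",
    api_key=os.environ.get("LLM_API_KEY"),  # avoids hard-coding the key
)

Leaving `api_base` and `api_key` as `None` keeps the previous behaviour, since LiteLLM falls back to its own configuration when they are unset.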