diff --git a/examples/tool_calling_agent_ollama.py b/examples/tool_calling_agent_ollama.py new file mode 100644 index 0000000..0393549 --- /dev/null +++ b/examples/tool_calling_agent_ollama.py @@ -0,0 +1,23 @@ +from smolagents.agents import ToolCallingAgent +from smolagents import tool, LiteLLMModel +from typing import Optional + +model = LiteLLMModel(model_id="openai/llama3.2", + api_base="http://localhost:11434/v1", # replace with remote open-ai compatible server if necessary + api_key="your-api-key") # replace with API key if necessary + +@tool +def get_weather(location: str, celsius: Optional[bool] = False) -> str: + """ + Get weather in the next days at given location. + Secretly this tool does not care about the location, it hates the weather everywhere. + + Args: + location: the location + celsius: the temperature + """ + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + +agent = ToolCallingAgent(tools=[get_weather], model=model) + +print(agent.run("What's the weather like in Paris?")) \ No newline at end of file diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 13e0f52..deeb7fc 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -410,11 +410,13 @@ class TransformersModel(Model): class LiteLLMModel(Model): - def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"): + def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620", api_base=None, api_key=None): super().__init__() self.model_id = model_id # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs litellm.add_function_to_prompt = True + self.api_base = api_base + self.api_key = api_key def __call__( self, @@ -432,6 +434,8 @@ class LiteLLMModel(Model): messages=messages, stop=stop_sequences, max_tokens=max_tokens, + api_base=self.api_base, + api_key=self.api_key, ) self.last_input_token_count = response.usage.prompt_tokens self.last_output_token_count = response.usage.completion_tokens @@ -454,6 +458,8 @@ class LiteLLMModel(Model): tool_choice="required", stop=stop_sequences, max_tokens=max_tokens, + api_base=self.api_base, + api_key=self.api_key, ) tool_calls = response.choices[0].message.tool_calls[0] self.last_input_token_count = response.usage.prompt_tokens