Remove dependency on LiteLLM (#126)

This commit is contained in:
Aymeric Roucher 2025-01-08 22:57:55 +01:00 committed by GitHub
parent d2f4eecba4
commit e1414f6653
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 9 deletions

View File

@ -21,13 +21,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Using the latest cached version of the dataset since m-ric/smolagentsbenchmark couldn't be found on the Hugging Face Hub\n", "Using the latest cached version of the dataset since m-ric/smolagentsbenchmark couldn't be found on the Hugging Face Hub\n",
"Found the latest cached dataset configuration 'default' at /Users/aymeric/.cache/huggingface/datasets/m-ric___smolagentsbenchmark/default/0.0.0/0ad5fb2293ab185eece723a4ac0e4a7188f71add (last modified on Wed Jan 8 17:50:13 2025).\n" "Found the latest cached dataset configuration 'default' at /Users/aymeric/.cache/huggingface/datasets/m-ric___smolagentsbenchmark/default/0.0.0/0ad5fb2293ab185eece723a4ac0e4a7188f71add (last modified on Wed Jan 8 17:50:13 2025).\n"
] ]
@ -172,7 +174,7 @@
"[132 rows x 4 columns]" "[132 rows x 4 columns]"
] ]
}, },
"execution_count": 4, "execution_count": 1,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -195,9 +197,19 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
"* 'fields' has been removed\n",
" warnings.warn(message, UserWarning)\n"
]
}
],
"source": [ "source": [
"import time\n", "import time\n",
"import json\n", "import json\n",

View File

@ -26,7 +26,6 @@ dependencies = [
"duckduckgo-search>=6.3.7", "duckduckgo-search>=6.3.7",
"python-dotenv>=1.0.1", "python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3", "e2b-code-interpreter>=1.0.3",
"litellm>=1.55.10",
"openai>=1.58.1", "openai>=1.58.1",
] ]
@ -40,4 +39,5 @@ test = [
"ruff>=0.5.0", "ruff>=0.5.0",
"accelerate", "accelerate",
"soundfile", "soundfile",
"litellm>=1.55.10",
] ]

View File

@ -22,7 +22,6 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import litellm
import torch import torch
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
from transformers import ( from transformers import (
@ -48,6 +47,13 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>", "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
} }
try:
import litellm
is_litellm_available = True
except ImportError:
is_litellm_available = False
class MessageRole(str, Enum): class MessageRole(str, Enum):
USER = "user" USER = "user"
@ -428,6 +434,10 @@ class LiteLLMModel(Model):
api_key=None, api_key=None,
**kwargs, **kwargs,
): ):
if not is_litellm_available:
raise ImportError(
"litellm not found. Install it with `pip install litellm`"
)
super().__init__() super().__init__()
self.model_id = model_id self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
@ -510,7 +520,7 @@ class OpenAIServerModel(Model):
api_base: str, api_base: str,
api_key: str, api_key: str,
temperature: float = 0.7, temperature: float = 0.7,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.model_id = model_id self.model_id = model_id
@ -539,7 +549,7 @@ class OpenAIServerModel(Model):
stop=stop_sequences, stop=stop_sequences,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=self.temperature, temperature=self.temperature,
**self.kwargs **self.kwargs,
) )
self.last_input_token_count = response.usage.prompt_tokens self.last_input_token_count = response.usage.prompt_tokens
@ -566,7 +576,7 @@ class OpenAIServerModel(Model):
stop=stop_sequences, stop=stop_sequences,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=self.temperature, temperature=self.temperature,
**self.kwargs **self.kwargs,
) )
tool_calls = response.choices[0].message.tool_calls[0] tool_calls = response.choices[0].message.tool_calls[0]