diff --git a/src/smolagents/models.py b/src/smolagents/models.py index cb825b4..3159cb4 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -842,17 +842,8 @@ class LiteLLMModel(Model): custom_role_conversions: Optional[Dict[str, str]] = None, **kwargs, ): - try: - import litellm - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" - ) - super().__init__(**kwargs) self.model_id = model_id - # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs - litellm.add_function_to_prompt = True self.api_base = api_base self.api_key = api_key self.custom_role_conversions = custom_role_conversions @@ -870,7 +861,12 @@ class LiteLLMModel(Model): tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: - import litellm + try: + import litellm + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" + ) completion_kwargs = self._prepare_completion_kwargs( messages=messages, diff --git a/tests/test_models.py b/tests/test_models.py index a77c956..10f36ae 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -148,8 +148,8 @@ class TestLiteLLMModel: "model_id, error_flag", [ ("groq/llama-3.3-70b", "Missing API Key"), - ("cerebras/llama-3.3-70b", "Wrong API Key"), - ("ollama/llama2", "not found"), + ("cerebras/llama-3.3-70b", "The api_key client option must be set"), + ("mistral/mistral-tiny", "The api_key client option must be set"), ], ) def test_call_different_providers_without_key(self, model_id, error_flag):