Add standard call to LiteLLM engine
This commit is contained in:
		
							parent
							
								
									1e357cee7f
								
							
						
					
					
						commit
						162d4dc362
					
				|  | @ -6,7 +6,7 @@ from smolagents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine, Transfo | ||||||
| # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620") | # llm_engine = AnthropicEngine("claude-3-5-sonnet-20240620") | ||||||
| # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") | # llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") | ||||||
| # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct") | # llm_engine = TransformersEngine("meta-llama/Llama-3.2-2B-Instruct") | ||||||
| llm_engine = LiteLLMEngine() | llm_engine = LiteLLMEngine("gpt-4o") | ||||||
| 
 | 
 | ||||||
| @tool | @tool | ||||||
| def get_weather(location: str) -> str: | def get_weather(location: str) -> str: | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ from enum import Enum | ||||||
| from typing import Dict, List, Optional, Tuple | from typing import Dict, List, Optional, Tuple | ||||||
| from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList | ||||||
| 
 | 
 | ||||||
|  | import litellm | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import random | import random | ||||||
|  | @ -566,10 +567,31 @@ class AnthropicEngine: | ||||||
| class LiteLLMEngine(): | class LiteLLMEngine(): | ||||||
|     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"): |     def __init__(self, model_id="anthropic/claude-3-5-sonnet-20240620"): | ||||||
|         self.model_id = model_id |         self.model_id = model_id | ||||||
|         import os, litellm |  | ||||||
| 
 |  | ||||||
|         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs |         # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs | ||||||
|         litellm.add_function_to_prompt = True |         litellm.add_function_to_prompt = True | ||||||
|  |         self.last_input_token_count = 0 | ||||||
|  |         self.last_output_token_count = 0 | ||||||
|  | 
 | ||||||
|  |     def __call__( | ||||||
|  |         self, | ||||||
|  |         messages: List[Dict[str, str]], | ||||||
|  |         stop_sequences: Optional[List[str]] = None, | ||||||
|  |         grammar: Optional[str] = None, | ||||||
|  |         max_tokens: int = 1500, | ||||||
|  |     ) -> str: | ||||||
|  |         messages = get_clean_message_list( | ||||||
|  |             messages, role_conversions=tool_role_conversions | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         response = litellm.completion( | ||||||
|  |             model=self.model_id, | ||||||
|  |             messages=messages, | ||||||
|  |             stop=stop_sequences, | ||||||
|  |             max_tokens=max_tokens, | ||||||
|  |         ) | ||||||
|  |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|  |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|  |         return response.choices[0].message.content | ||||||
| 
 | 
 | ||||||
|     def get_tool_call( |     def get_tool_call( | ||||||
|             self, |             self, | ||||||
|  | @ -578,19 +600,20 @@ class LiteLLMEngine(): | ||||||
|             stop_sequences: Optional[List[str]] = None, |             stop_sequences: Optional[List[str]] = None, | ||||||
|             max_tokens: int = 1500, |             max_tokens: int = 1500, | ||||||
|         ): |         ): | ||||||
|         from litellm import completion |  | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=tool_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|         response = completion( |         response = litellm.completion( | ||||||
|             model=self.model_id, |             model=self.model_id, | ||||||
|             messages=messages, |             messages=messages, | ||||||
|             tools=[get_json_schema(tool) for tool in available_tools], |             tools=[get_json_schema(tool) for tool in available_tools], | ||||||
|             tool_choice="required", |             tool_choice="required", | ||||||
|             max_tokens=max_tokens, |  | ||||||
|             stop=stop_sequences, |             stop=stop_sequences, | ||||||
|  |             max_tokens=max_tokens, | ||||||
|         ) |         ) | ||||||
|         tool_calls = response.choices[0].message.tool_calls[0] |         tool_calls = response.choices[0].message.tool_calls[0] | ||||||
|  |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|  |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|         return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id |         return tool_calls.function.name, tool_calls.function.arguments, tool_calls.id | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue