Implement OpenAIServerModel (#109)
parent d3cd0f9e09
commit b4528d6a6f
@@ -27,6 +27,7 @@ dependencies = [
   "python-dotenv>=1.0.1",
   "e2b-code-interpreter>=1.0.3",
   "litellm>=1.55.10",
+  "openai>=1.58.1",
 ]
 
 [tool.ruff]
@@ -31,6 +31,7 @@ from transformers import (
     StoppingCriteria,
     StoppingCriteriaList,
 )
+import openai
 
 from .tools import Tool
 from .utils import parse_json_tool_call
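The import above carries the key design choice of this commit: the official openai client accepts a base_url override, so one code path can target api.openai.com or any server that speaks the same protocol. A minimal sketch of that mechanism (the URL and key are placeholders, not values from this commit):

import openai

# Hypothetical endpoint; any OpenAI-compatible server (e.g. vLLM or
# llama.cpp) exposing a /v1 route would be addressed the same way.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")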
@@ -487,6 +488,99 @@ class LiteLLMModel(Model):
         return tool_calls.function.name, arguments, tool_calls.id
 
 
+class OpenAIServerModel(Model):
+    """This engine connects to an OpenAI-compatible API server.
+
+    Parameters:
+        model_id (`str`):
+            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+        api_base (`str`):
+            The base URL of the OpenAI-compatible API server.
+        api_key (`str`):
+            The API key to use for authentication.
+        temperature (`float`, *optional*, defaults to 0.7):
+            Controls randomness in the model's responses. Values between 0 and 2.
+        **kwargs:
+            Additional keyword arguments to pass to the OpenAI API.
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        api_base: str,
+        api_key: str,
+        temperature: float = 0.7,
+        **kwargs
+    ):
+        super().__init__()
+        self.model_id = model_id
+        self.client = openai.OpenAI(
+            base_url=api_base,
+            api_key=api_key,
+        )
+        self.temperature = temperature
+        self.kwargs = kwargs
+
+    def generate(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        max_tokens: int = 1500,
+    ) -> str:
+        """Generates a text completion for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs
+        )
+
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+        return response.choices[0].message.content
+
+    def get_tool_call(
+        self,
+        messages: List[Dict[str, str]],
+        available_tools: List[Tool],
+        stop_sequences: Optional[List[str]] = None,
+        max_tokens: int = 500,
+    ) -> Tuple[str, Union[str, Dict], str]:
+        """Generates a tool call for the given message list"""
+        messages = get_clean_message_list(
+            messages, role_conversions=tool_role_conversions
+        )
+
+        response = self.client.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            tools=[get_json_schema(tool) for tool in available_tools],
+            tool_choice="auto",
+            stop=stop_sequences,
+            max_tokens=max_tokens,
+            temperature=self.temperature,
+            **self.kwargs
+        )
+
+        tool_calls = response.choices[0].message.tool_calls[0]
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
+
+        try:
+            arguments = json.loads(tool_calls.function.arguments)
+        except json.JSONDecodeError:
+            arguments = tool_calls.function.arguments
+
+        return tool_calls.function.name, arguments, tool_calls.id
+
+
 __all__ = [
     "MessageRole",
     "tool_role_conversions",
@@ -495,4 +589,5 @@ __all__ = [
     "TransformersModel",
     "HfApiModel",
     "LiteLLMModel",
+    "OpenAIServerModel",
 ]
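The tool-calling path can be exercised the same way. A hedged sketch: the Tool subclass below is invented for illustration (its fields follow the Tool interface imported in this file, but the concrete tool is not part of this commit), and the credentials are placeholders:

from smolagents import OpenAIServerModel, Tool  # package name assumed

class GetWeather(Tool):
    # Invented example tool, not part of this commit.
    name = "get_weather"
    description = "Returns a short weather report for a city."
    inputs = {"city": {"type": "string", "description": "Name of the city"}}
    output_type = "string"

    def forward(self, city: str) -> str:
        return f"Sunny in {city}"

model = OpenAIServerModel(
    model_id="gpt-3.5-turbo",            # placeholder model name
    api_base="https://api.openai.com/v1",
    api_key="sk-placeholder",
)
messages = [{"role": "user", "content": "What's the weather in Paris?"}]
tool_name, arguments, call_id = model.get_tool_call(
    messages, available_tools=[GetWeather()]
)
# arguments is a dict when the server returns valid JSON, otherwise the raw string.
print(tool_name, arguments, call_id)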