Fix tool calls with LiteLLM and tool optional types (#318)
parent ffaa945936
commit fe2f4e735c
@@ -27,6 +27,7 @@ import json
 import os
 import re
 import types
+from copy import copy
 from typing import (
     Any,
     Callable,
@@ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = {

 def _get_json_schema_type(param_type: str) -> Dict[str, str]:
     if param_type in _BASE_TYPE_MAPPING:
-        return _BASE_TYPE_MAPPING[param_type]
+        return copy(_BASE_TYPE_MAPPING[param_type])
     if str(param_type) == "Image" and _is_pillow_available():
         from PIL.Image import Image

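Returning the mapping entry directly hands every caller a reference to the same shared dict, so any mutation of the returned schema (for example, marking a parameter nullable for an Optional type) leaks back into _BASE_TYPE_MAPPING and corrupts all later lookups. A minimal sketch of the aliasing bug that copy() avoids; the mapping contents and the nullable mutation are illustrative assumptions, not the project's actual values:

    _BASE_TYPE_MAPPING = {"str": {"type": "string"}, "int": {"type": "integer"}}

    def _get_json_schema_type(param_type: str) -> dict:
        # Pre-fix behavior: returns the shared dict itself, not a copy.
        return _BASE_TYPE_MAPPING[param_type]

    schema = _get_json_schema_type("str")
    schema["nullable"] = True  # e.g. while describing an Optional[str] tool parameter
    print(_BASE_TYPE_MAPPING["str"])  # {'type': 'string', 'nullable': True} -- shared mapping polluted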
@@ -101,6 +101,18 @@ class ChatMessage:
             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
         return cls(role=message.role, content=message.content, tool_calls=tool_calls)

+    @classmethod
+    def from_dict(cls, data: dict) -> "ChatMessage":
+        if data.get("tool_calls"):
+            tool_calls = [
+                ChatMessageToolCall(
+                    function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
+                )
+                for tc in data["tool_calls"]
+            ]
+            data["tool_calls"] = tool_calls
+        return cls(**data)
+

 def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
     if isinstance(arguments, dict):
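from_dict exists because a serialized message carries its tool calls as plain dicts; rebuilding them as ChatMessageToolCall objects keeps attribute access such as message.tool_calls[0].function.name working downstream. A minimal round-trip sketch; the import path, the payload values, and the assumption that ChatMessageToolCallDefinition exposes name/arguments attributes are for illustration only:

    from smolagents.models import ChatMessage, ChatMessageToolCall  # assumed import path

    data = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }
    message = ChatMessage.from_dict(data)
    assert isinstance(message.tool_calls[0], ChatMessageToolCall)
    assert message.tool_calls[0].function.name == "get_weather"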
@@ -595,7 +607,9 @@ class LiteLLMModel(Model):
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens

-        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
+        message = ChatMessage.from_dict(
+            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+        )

         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
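The switch from ChatMessage(**...) to ChatMessage.from_dict(...) matters because pydantic's model_dump() recursively converts nested models to plain dicts, so tool calls come back as dicts and ChatMessage(**dumped) would leave message.tool_calls as a list of dicts. A toy stand-in demonstrating that model_dump() behavior; these model classes are illustrative, not litellm's actual response types:

    from typing import List, Optional

    from pydantic import BaseModel

    class Function(BaseModel):
        name: str
        arguments: str

    class ToolCall(BaseModel):
        id: str
        type: str
        function: Function

    class Message(BaseModel):
        role: str
        content: Optional[str] = None
        tool_calls: Optional[List[ToolCall]] = None

    message = Message(
        role="assistant",
        tool_calls=[ToolCall(id="call_0", type="function", function=Function(name="f", arguments="{}"))],
    )
    dumped = message.model_dump(include={"role", "content", "tool_calls"})
    print(type(dumped["tool_calls"][0]))  # <class 'dict'> -- nested models are dumped to plain dicts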
@@ -664,7 +678,9 @@ class OpenAIServerModel(Model):
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens

-        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
+        message = ChatMessage.from_dict(
+            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+        )
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
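Both model classes then run the rebuilt message through parse_tool_args_if_needed, which leans on parse_json_if_needed (its first lines are visible in the ChatMessage hunk above) to decode JSON-encoded argument strings. The body below past the isinstance check is an assumption reconstructed from the signature, not the repository's code:

    import json
    from typing import Union

    def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
        if isinstance(arguments, dict):
            return arguments  # already structured
        try:
            return json.loads(arguments)  # assumed: decode JSON-encoded tool arguments
        except json.JSONDecodeError:
            return arguments  # assumed: leave non-JSON strings untouched

    print(parse_json_if_needed('{"city": "Paris"}'))  # {'city': 'Paris'}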