Fix tool calls with LiteLLM and tool optional types (#318)
This commit is contained in:
		
							parent
							
								
									ffaa945936
								
							
						
					
					
						commit
						fe2f4e735c
					
				|  | @ -27,6 +27,7 @@ import json | ||||||
| import os | import os | ||||||
| import re | import re | ||||||
| import types | import types | ||||||
|  | from copy import copy | ||||||
| from typing import ( | from typing import ( | ||||||
|     Any, |     Any, | ||||||
|     Callable, |     Callable, | ||||||
|  | @ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = { | ||||||
| 
 | 
 | ||||||
| def _get_json_schema_type(param_type: str) -> Dict[str, str]: | def _get_json_schema_type(param_type: str) -> Dict[str, str]: | ||||||
|     if param_type in _BASE_TYPE_MAPPING: |     if param_type in _BASE_TYPE_MAPPING: | ||||||
|         return _BASE_TYPE_MAPPING[param_type] |         return copy(_BASE_TYPE_MAPPING[param_type]) | ||||||
|     if str(param_type) == "Image" and _is_pillow_available(): |     if str(param_type) == "Image" and _is_pillow_available(): | ||||||
|         from PIL.Image import Image |         from PIL.Image import Image | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -101,6 +101,18 @@ class ChatMessage: | ||||||
|             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls] |             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls] | ||||||
|         return cls(role=message.role, content=message.content, tool_calls=tool_calls) |         return cls(role=message.role, content=message.content, tool_calls=tool_calls) | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_dict(cls, data: dict) -> "ChatMessage": | ||||||
|  |         if data.get("tool_calls"): | ||||||
|  |             tool_calls = [ | ||||||
|  |                 ChatMessageToolCall( | ||||||
|  |                     function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"] | ||||||
|  |                 ) | ||||||
|  |                 for tc in data["tool_calls"] | ||||||
|  |             ] | ||||||
|  |             data["tool_calls"] = tool_calls | ||||||
|  |         return cls(**data) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: | def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: | ||||||
|     if isinstance(arguments, dict): |     if isinstance(arguments, dict): | ||||||
|  | @ -595,7 +607,9 @@ class LiteLLMModel(Model): | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
| 
 | 
 | ||||||
|         message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) |         message = ChatMessage.from_dict( | ||||||
|  |             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}) | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if tools_to_call_from is not None: |         if tools_to_call_from is not None: | ||||||
|             return parse_tool_args_if_needed(message) |             return parse_tool_args_if_needed(message) | ||||||
|  | @ -664,7 +678,9 @@ class OpenAIServerModel(Model): | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
| 
 | 
 | ||||||
|         message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})) |         message = ChatMessage.from_dict( | ||||||
|  |             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}) | ||||||
|  |         ) | ||||||
|         if tools_to_call_from is not None: |         if tools_to_call_from is not None: | ||||||
|             return parse_tool_args_if_needed(message) |             return parse_tool_args_if_needed(message) | ||||||
|         return message |         return message | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue