From fe2f4e735caae669949dae31905756ad70fcf63e Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Wed, 22 Jan 2025 18:42:10 +0100
Subject: [PATCH] Fix tool calls with LiteLLM and tool optional types (#318)

---
 src/smolagents/_function_type_hints_utils.py |  3 ++-
 src/smolagents/models.py                     | 20 ++++++++++++++++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/smolagents/_function_type_hints_utils.py b/src/smolagents/_function_type_hints_utils.py
index b076941..5eb9502 100644
--- a/src/smolagents/_function_type_hints_utils.py
+++ b/src/smolagents/_function_type_hints_utils.py
@@ -27,6 +27,7 @@ import json
 import os
 import re
 import types
+from copy import copy
 from typing import (
     Any,
     Callable,
@@ -381,7 +382,7 @@ _BASE_TYPE_MAPPING = {

 def _get_json_schema_type(param_type: str) -> Dict[str, str]:
     if param_type in _BASE_TYPE_MAPPING:
-        return _BASE_TYPE_MAPPING[param_type]
+        return copy(_BASE_TYPE_MAPPING[param_type])
     if str(param_type) == "Image" and _is_pillow_available():
         from PIL.Image import Image
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 575d2a5..6638783 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -101,6 +101,18 @@ class ChatMessage:
         tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
         return cls(role=message.role, content=message.content, tool_calls=tool_calls)

+    @classmethod
+    def from_dict(cls, data: dict) -> "ChatMessage":
+        if data.get("tool_calls"):
+            tool_calls = [
+                ChatMessageToolCall(
+                    function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
+                )
+                for tc in data["tool_calls"]
+            ]
+            data["tool_calls"] = tool_calls
+        return cls(**data)
+

 def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
     if isinstance(arguments, dict):
@@ -595,7 +607,9 @@ class LiteLLMModel(Model):
             self.last_input_token_count = response.usage.prompt_tokens
             self.last_output_token_count = response.usage.completion_tokens
-        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
+        message = ChatMessage.from_dict(
+            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+        )
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
@@ -664,7 +678,9 @@ class OpenAIServerModel(Model):
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = ChatMessage(**response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}))
+        message = ChatMessage.from_dict(
+            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+        )
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message