From ad180410789af353b0a13e601b8598813fde2ffa Mon Sep 17 00:00:00 2001 From: tanhuajie <68807603+tanhuajie@users.noreply.github.com> Date: Tue, 14 Jan 2025 00:24:18 +0800 Subject: [PATCH] Fix tool_calls parsing error in ToolCallingAgent (#160) --- src/smolagents/agents.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 832ac8e..5741ce9 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +import json from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -806,9 +807,17 @@ class ToolCallingAgent(MultiStepAgent): tools_to_call_from=list(self.tools.values()), stop_sequences=["Observation:"], ) - tool_calls = model_message.tool_calls[0] - tool_arguments = tool_calls.function.arguments - tool_name, tool_call_id = tool_calls.function.name, tool_calls.id + + # Extract tool call from model output + if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0: + tool_calls = model_message.tool_calls[0] + tool_arguments = tool_calls.function.arguments + tool_name, tool_call_id = tool_calls.function.name, tool_calls.id + else: + start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1 + tool_calls = json.loads(model_message.content[start:end]) + tool_arguments = tool_calls["tool_arguments"] + tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}" except Exception as e: raise AgentGenerationError(