diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 832ac8e..5741ce9 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +import json from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -806,9 +807,17 @@ class ToolCallingAgent(MultiStepAgent): tools_to_call_from=list(self.tools.values()), stop_sequences=["Observation:"], ) - tool_calls = model_message.tool_calls[0] - tool_arguments = tool_calls.function.arguments - tool_name, tool_call_id = tool_calls.function.name, tool_calls.id + + # Extract tool call from model output + if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0: + tool_calls = model_message.tool_calls[0] + tool_arguments = tool_calls.function.arguments + tool_name, tool_call_id = tool_calls.function.name, tool_calls.id + else: + start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1 + tool_calls = json.loads(model_message.content[start:end]) + tool_arguments = tool_calls["tool_arguments"] + tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}" except Exception as e: raise AgentGenerationError(