Fix tool_calls parsing error in ToolCallingAgent (#160)
This commit is contained in:
parent
c611dfc7e5
commit
ad18041078
|
@ -15,6 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -806,9 +807,17 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tools_to_call_from=list(self.tools.values()),
|
||||
stop_sequences=["Observation:"],
|
||||
)
|
||||
tool_calls = model_message.tool_calls[0]
|
||||
tool_arguments = tool_calls.function.arguments
|
||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||
|
||||
# Extract tool call from model output
|
||||
if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
|
||||
tool_calls = model_message.tool_calls[0]
|
||||
tool_arguments = tool_calls.function.arguments
|
||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||
else:
|
||||
start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
|
||||
tool_calls = json.loads(model_message.content[start:end])
|
||||
tool_arguments = tool_calls["tool_arguments"]
|
||||
tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(
|
||||
|
|
Loading…
Reference in New Issue