Fix tool_calls parsing error in ToolCallingAgent (#160)
This commit is contained in:
parent
c611dfc7e5
commit
ad18041078
|
@ -15,6 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -806,9 +807,17 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tools_to_call_from=list(self.tools.values()),
|
tools_to_call_from=list(self.tools.values()),
|
||||||
stop_sequences=["Observation:"],
|
stop_sequences=["Observation:"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract tool call from model output
|
||||||
|
if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
|
||||||
tool_calls = model_message.tool_calls[0]
|
tool_calls = model_message.tool_calls[0]
|
||||||
tool_arguments = tool_calls.function.arguments
|
tool_arguments = tool_calls.function.arguments
|
||||||
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
|
||||||
|
else:
|
||||||
|
start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
|
||||||
|
tool_calls = json.loads(model_message.content[start:end])
|
||||||
|
tool_arguments = tool_calls["tool_arguments"]
|
||||||
|
tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(
|
raise AgentGenerationError(
|
||||||
|
|
Loading…
Reference in New Issue