Fix tool_calls parsing error in ToolCallingAgent (#160)
This commit is contained in:
		
							parent
							
								
									c611dfc7e5
								
							
						
					
					
						commit
						ad18041078
					
				|  | @ -15,6 +15,7 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import time | import time | ||||||
|  | import json | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||||
| 
 | 
 | ||||||
|  | @ -806,9 +807,17 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|                 tools_to_call_from=list(self.tools.values()), |                 tools_to_call_from=list(self.tools.values()), | ||||||
|                 stop_sequences=["Observation:"], |                 stop_sequences=["Observation:"], | ||||||
|             ) |             ) | ||||||
|             tool_calls = model_message.tool_calls[0] |              | ||||||
|             tool_arguments = tool_calls.function.arguments |             # Extract tool call from model output | ||||||
|             tool_name, tool_call_id = tool_calls.function.name, tool_calls.id |             if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0: | ||||||
|  |                 tool_calls = model_message.tool_calls[0] | ||||||
|  |                 tool_arguments = tool_calls.function.arguments | ||||||
|  |                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id | ||||||
|  |             else: | ||||||
|  |                 start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1 | ||||||
|  |                 tool_calls = json.loads(model_message.content[start:end]) | ||||||
|  |                 tool_arguments = tool_calls["tool_arguments"] | ||||||
|  |                 tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}" | ||||||
| 
 | 
 | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError( |             raise AgentGenerationError( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue