Improve tool call argument parsing (#267)
* Improve tool call argument parsing
This commit is contained in:
		
							parent
							
								
									89a6350fe2
								
							
						
					
					
						commit
						0abd91cf72
					
				|  | @ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent): | |||
|                 tools_to_call_from=list(self.tools.values()), | ||||
|                 stop_sequences=["Observation:"], | ||||
|             ) | ||||
|             if model_message.tool_calls is None or len(model_message.tool_calls) == 0: | ||||
|                 raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.") | ||||
|             tool_call = model_message.tool_calls[0] | ||||
|             tool_name, tool_call_id = tool_call.function.name, tool_call.id | ||||
|             tool_arguments = tool_call.function.arguments | ||||
|  | @ -1022,9 +1024,9 @@ class ManagedAgent: | |||
|         """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" | ||||
|         full_task = self.managed_agent_prompt.format(name=self.name, task=task) | ||||
|         if self.additional_prompting: | ||||
|             full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip() | ||||
|             full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip() | ||||
|         else: | ||||
|             full_task = full_task.replace("\n{{additional_prompting}}", "").strip() | ||||
|             full_task = full_task.replace("\n{additional_prompting}", "").strip() | ||||
|         return full_task | ||||
| 
 | ||||
|     def __call__(self, request, **kwargs): | ||||
|  |  | |||
|  | @ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool): | |||
| 
 | ||||
|     def forward(self, query: str) -> str: | ||||
|         results = self.ddgs.text(query, max_results=self.max_results) | ||||
|         if len(results) == 0: | ||||
|             raise Exception("No results found! Try a less restrictive/shorter query.") | ||||
|         postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results] | ||||
|         return "## Search Results\n\n" + "\n\n".join(postprocessed_results) | ||||
| 
 | ||||
|  |  | |||
|  | @ -104,6 +104,22 @@ class ChatMessage: | |||
|         return cls(role=message.role, content=message.content, tool_calls=tool_calls) | ||||
| 
 | ||||
| 
 | ||||
| def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]: | ||||
|     if isinstance(arguments, dict): | ||||
|         return arguments | ||||
|     else: | ||||
|         try: | ||||
|             return json.loads(arguments) | ||||
|         except Exception: | ||||
|             return arguments | ||||
| 
 | ||||
| 
 | ||||
| def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage: | ||||
|     for tool_call in message.tool_calls: | ||||
|         tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments) | ||||
|     return message | ||||
| 
 | ||||
| 
 | ||||
| class MessageRole(str, Enum): | ||||
|     USER = "user" | ||||
|     ASSISTANT = "assistant" | ||||
|  | @ -181,17 +197,6 @@ def get_clean_message_list( | |||
|     return final_message_list | ||||
| 
 | ||||
| 
 | ||||
| def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: | ||||
|     try: | ||||
|         start, end = ( | ||||
|             possible_dictionary.find("{"), | ||||
|             possible_dictionary.rfind("}") + 1, | ||||
|         ) | ||||
|         return json.loads(possible_dictionary[start:end]) | ||||
|     except Exception: | ||||
|         return possible_dictionary | ||||
| 
 | ||||
| 
 | ||||
| class Model: | ||||
|     def __init__(self): | ||||
|         self.last_input_token_count = None | ||||
|  | @ -304,7 +309,10 @@ class HfApiModel(Model): | |||
|             ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return ChatMessage.from_hf_api(response.choices[0].message) | ||||
|         message = ChatMessage.from_hf_api(response.choices[0].message) | ||||
|         if tools_to_call_from is not None: | ||||
|             return parse_tool_args_if_needed(message) | ||||
|         return message | ||||
| 
 | ||||
| 
 | ||||
| class TransformersModel(Model): | ||||
|  | @ -523,7 +531,10 @@ class LiteLLMModel(Model): | |||
|             ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return response.choices[0].message | ||||
|         message = response.choices[0].message | ||||
|         if tools_to_call_from is not None: | ||||
|             return parse_tool_args_if_needed(message) | ||||
|         return message | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIServerModel(Model): | ||||
|  | @ -582,7 +593,7 @@ class OpenAIServerModel(Model): | |||
|                 model=self.model_id, | ||||
|                 messages=messages, | ||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||
|                 tool_choice="auto", | ||||
|                 tool_choice="required", | ||||
|                 stop=stop_sequences, | ||||
|                 **self.kwargs, | ||||
|             ) | ||||
|  | @ -595,7 +606,10 @@ class OpenAIServerModel(Model): | |||
|             ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return response.choices[0].message | ||||
|         message = response.choices[0].message | ||||
|         if tools_to_call_from is not None: | ||||
|             return parse_tool_args_if_needed(message) | ||||
|         return message | ||||
| 
 | ||||
| 
 | ||||
| __all__ = [ | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ import unittest | |||
| from typing import Optional | ||||
| 
 | ||||
| from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool | ||||
| from smolagents.models import parse_json_if_needed | ||||
| 
 | ||||
| 
 | ||||
| class ModelTests(unittest.TestCase): | ||||
|  | @ -55,3 +56,20 @@ class ModelTests(unittest.TestCase): | |||
|         messages = [{"role": "user", "content": "Hello!"}] | ||||
|         output = model(messages, stop_sequences=["great"]).content | ||||
|         assert output == "assistant\nHello" | ||||
| 
 | ||||
|     def test_parse_json_if_needed(self): | ||||
|         args = "abc" | ||||
|         parsed_args = parse_json_if_needed(args) | ||||
|         assert parsed_args == "abc" | ||||
| 
 | ||||
|         args = '{"a": 3}' | ||||
|         parsed_args = parse_json_if_needed(args) | ||||
|         assert parsed_args == {"a": 3} | ||||
| 
 | ||||
|         args = "3" | ||||
|         parsed_args = parse_json_if_needed(args) | ||||
|         assert parsed_args == 3 | ||||
| 
 | ||||
|         args = 3 | ||||
|         parsed_args = parse_json_if_needed(args) | ||||
|         assert parsed_args == 3 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue