Improve tool call argument parsing (#267)
* Improve tool call argument parsing
This commit is contained in:
		
							parent
							
								
									89a6350fe2
								
							
						
					
					
						commit
						0abd91cf72
					
				|  | @ -755,6 +755,8 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|                 tools_to_call_from=list(self.tools.values()), |                 tools_to_call_from=list(self.tools.values()), | ||||||
|                 stop_sequences=["Observation:"], |                 stop_sequences=["Observation:"], | ||||||
|             ) |             ) | ||||||
|  |             if model_message.tool_calls is None or len(model_message.tool_calls) == 0: | ||||||
|  |                 raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.") | ||||||
|             tool_call = model_message.tool_calls[0] |             tool_call = model_message.tool_calls[0] | ||||||
|             tool_name, tool_call_id = tool_call.function.name, tool_call.id |             tool_name, tool_call_id = tool_call.function.name, tool_call.id | ||||||
|             tool_arguments = tool_call.function.arguments |             tool_arguments = tool_call.function.arguments | ||||||
|  | @ -1022,9 +1024,9 @@ class ManagedAgent: | ||||||
|         """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" |         """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" | ||||||
|         full_task = self.managed_agent_prompt.format(name=self.name, task=task) |         full_task = self.managed_agent_prompt.format(name=self.name, task=task) | ||||||
|         if self.additional_prompting: |         if self.additional_prompting: | ||||||
|             full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip() |             full_task = full_task.replace("\n{additional_prompting}", self.additional_prompting).strip() | ||||||
|         else: |         else: | ||||||
|             full_task = full_task.replace("\n{{additional_prompting}}", "").strip() |             full_task = full_task.replace("\n{additional_prompting}", "").strip() | ||||||
|         return full_task |         return full_task | ||||||
| 
 | 
 | ||||||
|     def __call__(self, request, **kwargs): |     def __call__(self, request, **kwargs): | ||||||
|  |  | ||||||
|  | @ -158,6 +158,8 @@ class DuckDuckGoSearchTool(Tool): | ||||||
| 
 | 
 | ||||||
def forward(self, query: str) -> str:
    """Run a DuckDuckGo text search and return the results as markdown.

    Args:
        query: The search query string.

    Returns:
        A markdown string starting with a "## Search Results" header,
        followed by one "[title](href)\nbody" entry per result, separated
        by blank lines.

    Raises:
        Exception: If the search returned no results, with a hint to retry
            with a less restrictive query.
    """
    results = self.ddgs.text(query, max_results=self.max_results)
    # Surface an actionable error instead of silently returning an empty section.
    if not results:
        raise Exception("No results found! Try a less restrictive/shorter query.")
    postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
    return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -104,6 +104,22 @@ class ChatMessage: | ||||||
|         return cls(role=message.role, content=message.content, tool_calls=tool_calls) |         return cls(role=message.role, content=message.content, tool_calls=tool_calls) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
    """Decode *arguments* from JSON when it is a string; otherwise pass it through.

    Tool-call arguments may arrive either already parsed (a dict) or as a raw
    JSON string, depending on the model backend; this normalizes both forms.

    Args:
        arguments: The raw tool-call arguments, as a dict or a (possibly JSON)
            string. Other types are returned unchanged.

    Returns:
        The decoded JSON value when *arguments* is a valid JSON string,
        otherwise *arguments* itself, untouched.
    """
    if isinstance(arguments, dict):
        return arguments
    try:
        return json.loads(arguments)
    except (TypeError, ValueError):
        # TypeError: not a str/bytes at all; ValueError (incl. JSONDecodeError):
        # not valid JSON. Either way, keep the raw value.
        return arguments
|  | 
 | ||||||
|  | 
 | ||||||
def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
    """Normalize the arguments of every tool call on *message*, in place.

    Each ``tool_call.function.arguments`` that arrived as a JSON string is
    replaced by its decoded value via ``parse_json_if_needed``; already-parsed
    values are left untouched.

    Args:
        message: The chat message whose tool calls should be normalized.

    Returns:
        The same ``message`` object, mutated in place.
    """
    # Some models return no tool calls at all (tool_calls is None) — iterate
    # safely instead of raising TypeError on None.
    for tool_call in message.tool_calls or []:
        tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
    return message
|  | 
 | ||||||
|  | 
 | ||||||
| class MessageRole(str, Enum): | class MessageRole(str, Enum): | ||||||
|     USER = "user" |     USER = "user" | ||||||
|     ASSISTANT = "assistant" |     ASSISTANT = "assistant" | ||||||
|  | @ -181,17 +197,6 @@ def get_clean_message_list( | ||||||
|     return final_message_list |     return final_message_list | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]: |  | ||||||
|     try: |  | ||||||
|         start, end = ( |  | ||||||
|             possible_dictionary.find("{"), |  | ||||||
|             possible_dictionary.rfind("}") + 1, |  | ||||||
|         ) |  | ||||||
|         return json.loads(possible_dictionary[start:end]) |  | ||||||
|     except Exception: |  | ||||||
|         return possible_dictionary |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class Model: | class Model: | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.last_input_token_count = None |         self.last_input_token_count = None | ||||||
|  | @ -304,7 +309,10 @@ class HfApiModel(Model): | ||||||
|             ) |             ) | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|         return ChatMessage.from_hf_api(response.choices[0].message) |         message = ChatMessage.from_hf_api(response.choices[0].message) | ||||||
|  |         if tools_to_call_from is not None: | ||||||
|  |             return parse_tool_args_if_needed(message) | ||||||
|  |         return message | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TransformersModel(Model): | class TransformersModel(Model): | ||||||
|  | @ -523,7 +531,10 @@ class LiteLLMModel(Model): | ||||||
|             ) |             ) | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|         return response.choices[0].message |         message = response.choices[0].message | ||||||
|  |         if tools_to_call_from is not None: | ||||||
|  |             return parse_tool_args_if_needed(message) | ||||||
|  |         return message | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class OpenAIServerModel(Model): | class OpenAIServerModel(Model): | ||||||
|  | @ -582,7 +593,7 @@ class OpenAIServerModel(Model): | ||||||
|                 model=self.model_id, |                 model=self.model_id, | ||||||
|                 messages=messages, |                 messages=messages, | ||||||
|                 tools=[get_json_schema(tool) for tool in tools_to_call_from], |                 tools=[get_json_schema(tool) for tool in tools_to_call_from], | ||||||
|                 tool_choice="auto", |                 tool_choice="required", | ||||||
|                 stop=stop_sequences, |                 stop=stop_sequences, | ||||||
|                 **self.kwargs, |                 **self.kwargs, | ||||||
|             ) |             ) | ||||||
|  | @ -595,7 +606,10 @@ class OpenAIServerModel(Model): | ||||||
|             ) |             ) | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|         return response.choices[0].message |         message = response.choices[0].message | ||||||
|  |         if tools_to_call_from is not None: | ||||||
|  |             return parse_tool_args_if_needed(message) | ||||||
|  |         return message | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ import unittest | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
| from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool | from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool | ||||||
|  | from smolagents.models import parse_json_if_needed | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ModelTests(unittest.TestCase): | class ModelTests(unittest.TestCase): | ||||||
|  | @ -55,3 +56,20 @@ class ModelTests(unittest.TestCase): | ||||||
|         messages = [{"role": "user", "content": "Hello!"}] |         messages = [{"role": "user", "content": "Hello!"}] | ||||||
|         output = model(messages, stop_sequences=["great"]).content |         output = model(messages, stop_sequences=["great"]).content | ||||||
|         assert output == "assistant\nHello" |         assert output == "assistant\nHello" | ||||||
|  | 
 | ||||||
def test_parse_json_if_needed(self):
    """parse_json_if_needed decodes JSON strings and passes other values through."""
    cases = [
        ("abc", "abc"),          # non-JSON string stays a string
        ('{"a": 3}', {"a": 3}),  # JSON object string becomes a dict
        ("3", 3),                # JSON scalar string becomes an int
        (3, 3),                  # non-string input is returned unchanged
    ]
    for raw, expected in cases:
        parsed_args = parse_json_if_needed(raw)
        assert parsed_args == expected
|  |  | ||||||
		Loading…
	
		Reference in New Issue