Refactor and test final answer prompts (#595)
This commit is contained in:
		
							parent
							
								
									02b2b7ebb9
								
							
						
					
					
						commit
						833aec9198
					
				|  | @ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n | |||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||
|  |  | |||
|  | @ -163,3 +163,5 @@ model = OpenAIServerModel( | |||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||
|  |  | |||
|  | @ -155,3 +155,5 @@ print(model(messages)) | |||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||
|  |  | |||
|  | @ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict): | |||
|     report: str | ||||
| 
 | ||||
| 
 | ||||
| class FinalAnswerPromptTemplate(TypedDict): | ||||
|     """ | ||||
|     Prompt templates for the final answer. | ||||
| 
 | ||||
|     Args: | ||||
|         pre_messages (`str`): Pre-messages prompt. | ||||
|         post_messages (`str`): Post-messages prompt. | ||||
|     """ | ||||
| 
 | ||||
|     pre_messages: str | ||||
|     post_messages: str | ||||
| 
 | ||||
| 
 | ||||
| class PromptTemplates(TypedDict): | ||||
|     """ | ||||
|     Prompt templates for the agent. | ||||
| 
 | ||||
|     Args: | ||||
|         system_prompt (`str`): System prompt. | ||||
|         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template. | ||||
|         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template. | ||||
|         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates. | ||||
|         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates. | ||||
|         final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates. | ||||
|     """ | ||||
| 
 | ||||
|     system_prompt: str | ||||
|     planning: PlanningPromptTemplate | ||||
|     managed_agent: ManagedAgentPromptTemplate | ||||
|     final_answer: FinalAnswerPromptTemplate | ||||
| 
 | ||||
| 
 | ||||
| EMPTY_PROMPT_TEMPLATES = PromptTemplates( | ||||
|  | @ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates( | |||
|         update_plan_post_messages="", | ||||
|     ), | ||||
|     managed_agent=ManagedAgentPromptTemplate(task="", report=""), | ||||
|     final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -290,46 +306,33 @@ class MultiStepAgent: | |||
|         Returns: | ||||
|             `str`: Final answer to the task. | ||||
|         """ | ||||
|         messages = [{"role": MessageRole.SYSTEM, "content": []}] | ||||
|         messages = [ | ||||
|             { | ||||
|                 "role": MessageRole.SYSTEM, | ||||
|                 "content": [ | ||||
|                     { | ||||
|                         "type": "text", | ||||
|                         "text": self.prompt_templates["final_answer"]["pre_messages"], | ||||
|                     } | ||||
|                 ], | ||||
|             } | ||||
|         ] | ||||
|         if images: | ||||
|             messages[0]["content"] = [ | ||||
|                 { | ||||
|                     "type": "text", | ||||
|                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", | ||||
|                 } | ||||
|             ] | ||||
|             messages[0]["content"].append({"type": "image"}) | ||||
|             messages += self.write_memory_to_messages()[1:] | ||||
|             messages += [ | ||||
|                 { | ||||
|                     "role": MessageRole.USER, | ||||
|                     "content": [ | ||||
|                         { | ||||
|                             "type": "text", | ||||
|                             "text": f"Based on the above, please provide an answer to the following user request:\n{task}", | ||||
|                         } | ||||
|                     ], | ||||
|                 } | ||||
|             ] | ||||
|         else: | ||||
|             messages[0]["content"] = [ | ||||
|                 { | ||||
|                     "type": "text", | ||||
|                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", | ||||
|                 } | ||||
|             ] | ||||
|             messages += self.write_memory_to_messages()[1:] | ||||
|             messages += [ | ||||
|                 { | ||||
|                     "role": MessageRole.USER, | ||||
|                     "content": [ | ||||
|                         { | ||||
|                             "type": "text", | ||||
|                             "text": f"Based on the above, please provide an answer to the following user request:\n{task}", | ||||
|                         } | ||||
|                     ], | ||||
|                 } | ||||
|             ] | ||||
|         messages += self.write_memory_to_messages()[1:] | ||||
|         messages += [ | ||||
|             { | ||||
|                 "role": MessageRole.USER, | ||||
|                 "content": [ | ||||
|                     { | ||||
|                         "type": "text", | ||||
|                         "text": populate_template( | ||||
|                             self.prompt_templates["final_answer"]["post_messages"], variables={"task": task} | ||||
|                         ), | ||||
|                     } | ||||
|                 ], | ||||
|             } | ||||
|         ] | ||||
|         try: | ||||
|             chat_message: ChatMessage = self.model(messages) | ||||
|             return chat_message.content | ||||
|  |  | |||
|  | @ -318,4 +318,10 @@ managed_agent: | |||
|       And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. | ||||
|   report: |- | ||||
|       Here is the final answer from your managed agent '{{name}}': | ||||
|       {{final_answer}} | ||||
|       {{final_answer}} | ||||
| final_answer: | ||||
|   pre_messages: |- | ||||
|     An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory: | ||||
|   post_messages: |- | ||||
|     Based on the above, please provide an answer to the following user request: | ||||
|     {{task}} | ||||
|  |  | |||
|  | @ -261,4 +261,10 @@ managed_agent: | |||
|       And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. | ||||
|   report: |- | ||||
|       Here is the final answer from your managed agent '{{name}}': | ||||
|       {{final_answer}} | ||||
|       {{final_answer}} | ||||
| final_answer: | ||||
|   pre_messages: |- | ||||
|     An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory: | ||||
|   post_messages: |- | ||||
|     Based on the above, please provide an answer to the following user request: | ||||
|     {{task}} | ||||
|  |  | |||
|  | @ -29,6 +29,7 @@ from smolagents.agents import ( | |||
|     MultiStepAgent, | ||||
|     ToolCall, | ||||
|     ToolCallingAgent, | ||||
|     populate_template, | ||||
| ) | ||||
| from smolagents.default_tools import PythonInterpreterTool | ||||
| from smolagents.memory import PlanningStep | ||||
|  | @ -771,6 +772,74 @@ class TestMultiStepAgent: | |||
|                     assert "type" in content | ||||
|                     assert "text" in content | ||||
| 
 | ||||
|     @pytest.mark.parametrize( | ||||
|         "images, expected_messages_list", | ||||
|         [ | ||||
|             ( | ||||
|                 None, | ||||
|                 [ | ||||
|                     [ | ||||
|                         { | ||||
|                             "role": MessageRole.SYSTEM, | ||||
|                             "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}], | ||||
|                         }, | ||||
|                         {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]}, | ||||
|                     ] | ||||
|                 ], | ||||
|             ), | ||||
|             ( | ||||
|                 ["image1.png"], | ||||
|                 [ | ||||
|                     [ | ||||
|                         { | ||||
|                             "role": MessageRole.SYSTEM, | ||||
|                             "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}], | ||||
|                         }, | ||||
|                         {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]}, | ||||
|                     ] | ||||
|                 ], | ||||
|             ), | ||||
|         ], | ||||
|     ) | ||||
|     def test_provide_final_answer(self, images, expected_messages_list): | ||||
|         fake_model = MagicMock() | ||||
|         fake_model.return_value.content = "Final answer." | ||||
|         agent = CodeAgent( | ||||
|             tools=[], | ||||
|             model=fake_model, | ||||
|         ) | ||||
|         task = "Test task" | ||||
|         final_answer = agent.provide_final_answer(task, images=images) | ||||
|         expected_message_texts = { | ||||
|             "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"], | ||||
|             "FINAL_ANSWER_USER_PROMPT": populate_template( | ||||
|                 agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task) | ||||
|             ), | ||||
|         } | ||||
|         for expected_messages in expected_messages_list: | ||||
|             for expected_message in expected_messages: | ||||
|                 for expected_content in expected_message["content"]: | ||||
|                     if "text" in expected_content: | ||||
|                         expected_content["text"] = expected_message_texts[expected_content["text"]] | ||||
|         assert final_answer == "Final answer." | ||||
|         # Test calls to model | ||||
|         assert len(fake_model.call_args_list) == 1 | ||||
|         for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list): | ||||
|             assert len(call_args.args) == 1 | ||||
|             messages = call_args.args[0] | ||||
|             assert isinstance(messages, list) | ||||
|             assert len(messages) == len(expected_messages) | ||||
|             for message, expected_message in zip(messages, expected_messages): | ||||
|                 assert isinstance(message, dict) | ||||
|                 assert "role" in message | ||||
|                 assert "content" in message | ||||
|                 assert message["role"] in MessageRole.__members__.values() | ||||
|                 assert message["role"] == expected_message["role"] | ||||
|                 assert isinstance(message["content"], list) | ||||
|                 assert len(message["content"]) == len(expected_message["content"]) | ||||
|                 for content, expected_content in zip(message["content"], expected_message["content"]): | ||||
|                     assert content == expected_content | ||||
| 
 | ||||
| 
 | ||||
| class TestCodeAgent: | ||||
|     @pytest.mark.parametrize("provide_run_summary", [False, True]) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue