Refactor and test final answer prompts (#595)
This commit is contained in:
		
							parent
							
								
									02b2b7ebb9
								
							
						
					
					
						commit
						833aec9198
					
				|  | @ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n | ||||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
| 
 | 
 | ||||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||||
|  |  | ||||||
|  | @ -163,3 +163,5 @@ model = OpenAIServerModel( | ||||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
| 
 | 
 | ||||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||||
|  |  | ||||||
|  | @ -155,3 +155,5 @@ print(model(messages)) | ||||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
| 
 | 
 | ||||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.FinalAnswerPromptTemplate | ||||||
|  |  | ||||||
|  | @ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict): | ||||||
|     report: str |     report: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class FinalAnswerPromptTemplate(TypedDict): | ||||||
|  |     """ | ||||||
|  |     Prompt templates for the final answer. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pre_messages (`str`): Pre-messages prompt. | ||||||
|  |         post_messages (`str`): Post-messages prompt. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     pre_messages: str | ||||||
|  |     post_messages: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class PromptTemplates(TypedDict): | class PromptTemplates(TypedDict): | ||||||
|     """ |     """ | ||||||
|     Prompt templates for the agent. |     Prompt templates for the agent. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         system_prompt (`str`): System prompt. |         system_prompt (`str`): System prompt. | ||||||
|         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template. |         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates. | ||||||
|         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template. |         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates. | ||||||
|  |         final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     system_prompt: str |     system_prompt: str | ||||||
|     planning: PlanningPromptTemplate |     planning: PlanningPromptTemplate | ||||||
|     managed_agent: ManagedAgentPromptTemplate |     managed_agent: ManagedAgentPromptTemplate | ||||||
|  |     final_answer: FinalAnswerPromptTemplate | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| EMPTY_PROMPT_TEMPLATES = PromptTemplates( | EMPTY_PROMPT_TEMPLATES = PromptTemplates( | ||||||
|  | @ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates( | ||||||
|         update_plan_post_messages="", |         update_plan_post_messages="", | ||||||
|     ), |     ), | ||||||
|     managed_agent=ManagedAgentPromptTemplate(task="", report=""), |     managed_agent=ManagedAgentPromptTemplate(task="", report=""), | ||||||
|  |     final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""), | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -290,46 +306,33 @@ class MultiStepAgent: | ||||||
|         Returns: |         Returns: | ||||||
|             `str`: Final answer to the task. |             `str`: Final answer to the task. | ||||||
|         """ |         """ | ||||||
|         messages = [{"role": MessageRole.SYSTEM, "content": []}] |         messages = [ | ||||||
|  |             { | ||||||
|  |                 "role": MessageRole.SYSTEM, | ||||||
|  |                 "content": [ | ||||||
|  |                     { | ||||||
|  |                         "type": "text", | ||||||
|  |                         "text": self.prompt_templates["final_answer"]["pre_messages"], | ||||||
|  |                     } | ||||||
|  |                 ], | ||||||
|  |             } | ||||||
|  |         ] | ||||||
|         if images: |         if images: | ||||||
|             messages[0]["content"] = [ |  | ||||||
|                 { |  | ||||||
|                     "type": "text", |  | ||||||
|                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", |  | ||||||
|                 } |  | ||||||
|             ] |  | ||||||
|             messages[0]["content"].append({"type": "image"}) |             messages[0]["content"].append({"type": "image"}) | ||||||
|             messages += self.write_memory_to_messages()[1:] |         messages += self.write_memory_to_messages()[1:] | ||||||
|             messages += [ |         messages += [ | ||||||
|                 { |             { | ||||||
|                     "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                     "content": [ |                 "content": [ | ||||||
|                         { |                     { | ||||||
|                             "type": "text", |                         "type": "text", | ||||||
|                             "text": f"Based on the above, please provide an answer to the following user request:\n{task}", |                         "text": populate_template( | ||||||
|                         } |                             self.prompt_templates["final_answer"]["post_messages"], variables={"task": task} | ||||||
|                     ], |                         ), | ||||||
|                 } |                     } | ||||||
|             ] |                 ], | ||||||
|         else: |             } | ||||||
|             messages[0]["content"] = [ |         ] | ||||||
|                 { |  | ||||||
|                     "type": "text", |  | ||||||
|                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", |  | ||||||
|                 } |  | ||||||
|             ] |  | ||||||
|             messages += self.write_memory_to_messages()[1:] |  | ||||||
|             messages += [ |  | ||||||
|                 { |  | ||||||
|                     "role": MessageRole.USER, |  | ||||||
|                     "content": [ |  | ||||||
|                         { |  | ||||||
|                             "type": "text", |  | ||||||
|                             "text": f"Based on the above, please provide an answer to the following user request:\n{task}", |  | ||||||
|                         } |  | ||||||
|                     ], |  | ||||||
|                 } |  | ||||||
|             ] |  | ||||||
|         try: |         try: | ||||||
|             chat_message: ChatMessage = self.model(messages) |             chat_message: ChatMessage = self.model(messages) | ||||||
|             return chat_message.content |             return chat_message.content | ||||||
|  |  | ||||||
|  | @ -319,3 +319,9 @@ managed_agent: | ||||||
|   report: |- |   report: |- | ||||||
|       Here is the final answer from your managed agent '{{name}}': |       Here is the final answer from your managed agent '{{name}}': | ||||||
|       {{final_answer}} |       {{final_answer}} | ||||||
|  | final_answer: | ||||||
|  |   pre_messages: |- | ||||||
|  |     An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory: | ||||||
|  |   post_messages: |- | ||||||
|  |     Based on the above, please provide an answer to the following user request: | ||||||
|  |     {{task}} | ||||||
|  |  | ||||||
|  | @ -262,3 +262,9 @@ managed_agent: | ||||||
|   report: |- |   report: |- | ||||||
|       Here is the final answer from your managed agent '{{name}}': |       Here is the final answer from your managed agent '{{name}}': | ||||||
|       {{final_answer}} |       {{final_answer}} | ||||||
|  | final_answer: | ||||||
|  |   pre_messages: |- | ||||||
|  |     An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory: | ||||||
|  |   post_messages: |- | ||||||
|  |     Based on the above, please provide an answer to the following user request: | ||||||
|  |     {{task}} | ||||||
|  |  | ||||||
|  | @ -29,6 +29,7 @@ from smolagents.agents import ( | ||||||
|     MultiStepAgent, |     MultiStepAgent, | ||||||
|     ToolCall, |     ToolCall, | ||||||
|     ToolCallingAgent, |     ToolCallingAgent, | ||||||
|  |     populate_template, | ||||||
| ) | ) | ||||||
| from smolagents.default_tools import PythonInterpreterTool | from smolagents.default_tools import PythonInterpreterTool | ||||||
| from smolagents.memory import PlanningStep | from smolagents.memory import PlanningStep | ||||||
|  | @ -771,6 +772,74 @@ class TestMultiStepAgent: | ||||||
|                     assert "type" in content |                     assert "type" in content | ||||||
|                     assert "text" in content |                     assert "text" in content | ||||||
| 
 | 
 | ||||||
|  |     @pytest.mark.parametrize( | ||||||
|  |         "images, expected_messages_list", | ||||||
|  |         [ | ||||||
|  |             ( | ||||||
|  |                 None, | ||||||
|  |                 [ | ||||||
|  |                     [ | ||||||
|  |                         { | ||||||
|  |                             "role": MessageRole.SYSTEM, | ||||||
|  |                             "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}], | ||||||
|  |                         }, | ||||||
|  |                         {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]}, | ||||||
|  |                     ] | ||||||
|  |                 ], | ||||||
|  |             ), | ||||||
|  |             ( | ||||||
|  |                 ["image1.png"], | ||||||
|  |                 [ | ||||||
|  |                     [ | ||||||
|  |                         { | ||||||
|  |                             "role": MessageRole.SYSTEM, | ||||||
|  |                             "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}], | ||||||
|  |                         }, | ||||||
|  |                         {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]}, | ||||||
|  |                     ] | ||||||
|  |                 ], | ||||||
|  |             ), | ||||||
|  |         ], | ||||||
|  |     ) | ||||||
|  |     def test_provide_final_answer(self, images, expected_messages_list): | ||||||
|  |         fake_model = MagicMock() | ||||||
|  |         fake_model.return_value.content = "Final answer." | ||||||
|  |         agent = CodeAgent( | ||||||
|  |             tools=[], | ||||||
|  |             model=fake_model, | ||||||
|  |         ) | ||||||
|  |         task = "Test task" | ||||||
|  |         final_answer = agent.provide_final_answer(task, images=images) | ||||||
|  |         expected_message_texts = { | ||||||
|  |             "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"], | ||||||
|  |             "FINAL_ANSWER_USER_PROMPT": populate_template( | ||||||
|  |                 agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task) | ||||||
|  |             ), | ||||||
|  |         } | ||||||
|  |         for expected_messages in expected_messages_list: | ||||||
|  |             for expected_message in expected_messages: | ||||||
|  |                 for expected_content in expected_message["content"]: | ||||||
|  |                     if "text" in expected_content: | ||||||
|  |                         expected_content["text"] = expected_message_texts[expected_content["text"]] | ||||||
|  |         assert final_answer == "Final answer." | ||||||
|  |         # Test calls to model | ||||||
|  |         assert len(fake_model.call_args_list) == 1 | ||||||
|  |         for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list): | ||||||
|  |             assert len(call_args.args) == 1 | ||||||
|  |             messages = call_args.args[0] | ||||||
|  |             assert isinstance(messages, list) | ||||||
|  |             assert len(messages) == len(expected_messages) | ||||||
|  |             for message, expected_message in zip(messages, expected_messages): | ||||||
|  |                 assert isinstance(message, dict) | ||||||
|  |                 assert "role" in message | ||||||
|  |                 assert "content" in message | ||||||
|  |                 assert message["role"] in MessageRole.__members__.values() | ||||||
|  |                 assert message["role"] == expected_message["role"] | ||||||
|  |                 assert isinstance(message["content"], list) | ||||||
|  |                 assert len(message["content"]) == len(expected_message["content"]) | ||||||
|  |                 for content, expected_content in zip(message["content"], expected_message["content"]): | ||||||
|  |                     assert content == expected_content | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class TestCodeAgent: | class TestCodeAgent: | ||||||
|     @pytest.mark.parametrize("provide_run_summary", [False, True]) |     @pytest.mark.parametrize("provide_run_summary", [False, True]) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue