Reviewer note (text before the first "diff --git" header is ignored by git):
  * The line structure of this patch was reconstructed from a whitespace-collapsed
    copy. The indentation of context lines is inferred - verify against the target
    tree, or apply with `git apply --ignore-whitespace`.
  * Two added lines were adjusted in review (the FinalAnswerPromptTemplate field
    docs and the `variables={"task": task}` literal in the test), so the post-image
    blob ids on the `index` lines of src/smolagents/agents.py and
    tests/test_agents.py are informational only.

diff --git a/docs/source/en/reference/agents.md b/docs/source/en/reference/agents.md
index 7cf44dd..a6f5718 100644
--- a/docs/source/en/reference/agents.md
+++ b/docs/source/en/reference/agents.md
@@ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
 [[autodoc]] smolagents.agents.PlanningPromptTemplate
 
 [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
+
+[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
diff --git a/docs/source/hi/reference/agents.md b/docs/source/hi/reference/agents.md
index dc3a18c..7d0f2d7 100644
--- a/docs/source/hi/reference/agents.md
+++ b/docs/source/hi/reference/agents.md
@@ -163,3 +163,5 @@ model = OpenAIServerModel(
 [[autodoc]] smolagents.agents.PlanningPromptTemplate
 
 [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
+
+[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
diff --git a/docs/source/zh/reference/agents.md b/docs/source/zh/reference/agents.md
index 471d245..970372f 100644
--- a/docs/source/zh/reference/agents.md
+++ b/docs/source/zh/reference/agents.md
@@ -155,3 +155,5 @@ print(model(messages))
 [[autodoc]] smolagents.agents.PlanningPromptTemplate
 
 [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
+
+[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 4158374..72db9d2 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict):
     report: str
 
 
+class FinalAnswerPromptTemplate(TypedDict):
+    """
+    Prompt templates for the final answer.
+
+    Args:
+        pre_messages (`str`): System prompt placed before the agent's memory messages.
+        post_messages (`str`): User prompt placed after the agent's memory; rendered with `task`.
+    """
+
+    pre_messages: str
+    post_messages: str
+
+
 class PromptTemplates(TypedDict):
     """
     Prompt templates for the agent.
 
     Args:
         system_prompt (`str`): System prompt.
-        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template.
-        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template.
+        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
+        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
+        final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
     """
 
     system_prompt: str
     planning: PlanningPromptTemplate
     managed_agent: ManagedAgentPromptTemplate
+    final_answer: FinalAnswerPromptTemplate
 
 
 EMPTY_PROMPT_TEMPLATES = PromptTemplates(
@@ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates(
         update_plan_post_messages="",
     ),
     managed_agent=ManagedAgentPromptTemplate(task="", report=""),
+    final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
 )
 
 
@@ -290,46 +306,33 @@ class MultiStepAgent:
         Returns:
             `str`: Final answer to the task.
         """
-        messages = [{"role": MessageRole.SYSTEM, "content": []}]
+        messages = [
+            {
+                "role": MessageRole.SYSTEM,
+                "content": [
+                    {
+                        "type": "text",
+                        "text": self.prompt_templates["final_answer"]["pre_messages"],
+                    }
+                ],
+            }
+        ]
         if images:
-            messages[0]["content"] = [
-                {
-                    "type": "text",
-                    "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
-                }
-            ]
             messages[0]["content"].append({"type": "image"})
-            messages += self.write_memory_to_messages()[1:]
-            messages += [
-                {
-                    "role": MessageRole.USER,
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": f"Based on the above, please provide an answer to the following user request:\n{task}",
-                        }
-                    ],
-                }
-            ]
-        else:
-            messages[0]["content"] = [
-                {
-                    "type": "text",
-                    "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
-                }
-            ]
-            messages += self.write_memory_to_messages()[1:]
-            messages += [
-                {
-                    "role": MessageRole.USER,
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": f"Based on the above, please provide an answer to the following user request:\n{task}",
-                        }
-                    ],
-                }
-            ]
+        messages += self.write_memory_to_messages()[1:]
+        messages += [
+            {
+                "role": MessageRole.USER,
+                "content": [
+                    {
+                        "type": "text",
+                        "text": populate_template(
+                            self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
+                        ),
+                    }
+                ],
+            }
+        ]
         try:
             chat_message: ChatMessage = self.model(messages)
             return chat_message.content
diff --git a/src/smolagents/prompts/code_agent.yaml b/src/smolagents/prompts/code_agent.yaml
index 2076bcc..852c4cf 100644
--- a/src/smolagents/prompts/code_agent.yaml
+++ b/src/smolagents/prompts/code_agent.yaml
@@ -318,4 +318,10 @@ managed_agent:
       And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
   report: |-
       Here is the final answer from your managed agent '{{name}}':
-      {{final_answer}}
\ No newline at end of file
+      {{final_answer}}
+final_answer:
+  pre_messages: |-
+    An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
+  post_messages: |-
+    Based on the above, please provide an answer to the following user request:
+    {{task}}
diff --git a/src/smolagents/prompts/toolcalling_agent.yaml b/src/smolagents/prompts/toolcalling_agent.yaml
index 8e11798..19d67da 100644
--- a/src/smolagents/prompts/toolcalling_agent.yaml
+++ b/src/smolagents/prompts/toolcalling_agent.yaml
@@ -261,4 +261,10 @@ managed_agent:
       And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
   report: |-
       Here is the final answer from your managed agent '{{name}}':
-      {{final_answer}}
\ No newline at end of file
+      {{final_answer}}
+final_answer:
+  pre_messages: |-
+    An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
+  post_messages: |-
+    Based on the above, please provide an answer to the following user request:
+    {{task}}
diff --git a/tests/test_agents.py b/tests/test_agents.py
index ed885a9..58c6315 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -29,6 +29,7 @@ from smolagents.agents import (
     MultiStepAgent,
     ToolCall,
     ToolCallingAgent,
+    populate_template,
 )
 from smolagents.default_tools import PythonInterpreterTool
 from smolagents.memory import PlanningStep
@@ -771,6 +772,74 @@ class TestMultiStepAgent:
                     assert "type" in content
                     assert "text" in content
 
+    @pytest.mark.parametrize(
+        "images, expected_messages_list",
+        [
+            (
+                None,
+                [
+                    [
+                        {
+                            "role": MessageRole.SYSTEM,
+                            "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
+                        },
+                        {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
+                    ]
+                ],
+            ),
+            (
+                ["image1.png"],
+                [
+                    [
+                        {
+                            "role": MessageRole.SYSTEM,
+                            "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}],
+                        },
+                        {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
+                    ]
+                ],
+            ),
+        ],
+    )
+    def test_provide_final_answer(self, images, expected_messages_list):
+        fake_model = MagicMock()
+        fake_model.return_value.content = "Final answer."
+        agent = CodeAgent(
+            tools=[],
+            model=fake_model,
+        )
+        task = "Test task"
+        final_answer = agent.provide_final_answer(task, images=images)
+        expected_message_texts = {
+            "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
+            "FINAL_ANSWER_USER_PROMPT": populate_template(
+                agent.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
+            ),
+        }
+        for expected_messages in expected_messages_list:
+            for expected_message in expected_messages:
+                for expected_content in expected_message["content"]:
+                    if "text" in expected_content:
+                        expected_content["text"] = expected_message_texts[expected_content["text"]]
+        assert final_answer == "Final answer."
+        # Test calls to model
+        assert len(fake_model.call_args_list) == 1
+        for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
+            assert len(call_args.args) == 1
+            messages = call_args.args[0]
+            assert isinstance(messages, list)
+            assert len(messages) == len(expected_messages)
+            for message, expected_message in zip(messages, expected_messages):
+                assert isinstance(message, dict)
+                assert "role" in message
+                assert "content" in message
+                assert message["role"] in MessageRole.__members__.values()
+                assert message["role"] == expected_message["role"]
+                assert isinstance(message["content"], list)
+                assert len(message["content"]) == len(expected_message["content"])
+                for content, expected_content in zip(message["content"], expected_message["content"]):
+                    assert content == expected_content
+
 
 class TestCodeAgent:
     @pytest.mark.parametrize("provide_run_summary", [False, True])