Move plan user prompt to YAML and test text of plan prompts (#591)

This commit is contained in:
Albert Villanova del Moral 2025-02-13 10:13:25 +01:00 committed by GitHub
parent 2797f2fb3b
commit 1516ce8d74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 75 additions and 42 deletions

View File

@ -88,7 +88,8 @@ class PlanningPromptTemplate(TypedDict):
Prompt templates for the planning step.
Args:
initial_facts (`str`): Initial facts prompt.
initial_facts_pre_task (`str`): Initial facts pre-task prompt.
initial_facts_task (`str`): Initial facts task prompt.
initial_plan (`str`): Initial plan prompt.
update_facts_pre_messages (`str`): Update facts pre-messages prompt.
update_facts_post_messages (`str`): Update facts post-messages prompt.
@ -96,7 +97,8 @@ class PlanningPromptTemplate(TypedDict):
update_plan_post_messages (`str`): Update plan post-messages prompt.
"""
initial_facts: str
initial_facts_pre_task: str
initial_facts_task: str
initial_plan: str
update_facts_pre_messages: str
update_facts_post_messages: str
@ -524,26 +526,19 @@ You have been provided with these additional arguments, that you can access usin
step (`int`): The number of the current step, used as an indication for the LLM.
"""
if is_first_step:
message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}],
}
message_prompt_task = {
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": textwrap.dedent(
f"""Here is the task:
```
{task}
```
Now begin!"""
),
},
],
}
input_messages = [message_prompt_facts, message_prompt_task]
input_messages = [
{
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["initial_facts"], variables={"task": task}
),
}
],
},
]
chat_message_facts: ChatMessage = self.model(input_messages)
answer_facts = chat_message_facts.content

View File

@ -196,6 +196,12 @@ planning:
### 2. Facts to look up
### 3. Facts to derive
Do not add anything else.
Here is the task:
```
{{task}}
```
Now begin!
initial_plan : |-
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.

View File

@ -139,6 +139,12 @@ planning:
### 2. Facts to look up
### 3. Facts to derive
Do not add anything else.
Here is the task:
```
{{task}}
```
Now begin!
initial_plan : |-
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.

View File

@ -697,11 +697,8 @@ class TestMultiStepAgent:
(
1,
[
[
{"role": MessageRole.SYSTEM, "content": [{"type": "text", "text": "FACTS_SYSTEM_PROMPT"}]},
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_USER_PROMPT"}]},
],
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_USER_PROMPT"}]}],
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_FACTS_USER_PROMPT"}]}],
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_PLAN_USER_PROMPT"}]}],
],
),
(
@ -710,22 +707,22 @@ class TestMultiStepAgent:
[
{
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": "FACTS_UPDATE_SYSTEM_PROMPT"}],
"content": [{"type": "text", "text": "UPDATE_FACTS_SYSTEM_PROMPT"}],
},
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_UPDATE_USER_PROMPT"}]},
{"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_FACTS_USER_PROMPT"}]},
],
[
{
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": "PLAN_UPDATE_SYSTEM_PROMPT"}],
"content": [{"type": "text", "text": "UPDATE_PLAN_SYSTEM_PROMPT"}],
},
{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_UPDATE_USER_PROMPT"}]},
{"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_PLAN_USER_PROMPT"}]},
],
],
),
],
)
def test_planning_step_first_step(self, step, expected_messages_list):
def test_planning_step(self, step, expected_messages_list):
fake_model = MagicMock()
agent = CodeAgent(
tools=[],
@ -733,6 +730,39 @@ class TestMultiStepAgent:
)
task = "Test task"
agent.planning_step(task, is_first_step=(step == 1), step=step)
expected_message_texts = {
"INITIAL_FACTS_USER_PROMPT": populate_template(
agent.prompt_templates["planning"]["initial_facts"], variables=dict(task=task)
),
"INITIAL_PLAN_USER_PROMPT": populate_template(
agent.prompt_templates["planning"]["initial_plan"],
variables=dict(
task=task,
tools=agent.tools,
managed_agents=agent.managed_agents,
answer_facts=agent.memory.steps[0].model_output_message_facts.content,
),
),
"UPDATE_FACTS_SYSTEM_PROMPT": agent.prompt_templates["planning"]["update_facts_pre_messages"],
"UPDATE_FACTS_USER_PROMPT": agent.prompt_templates["planning"]["update_facts_post_messages"],
"UPDATE_PLAN_SYSTEM_PROMPT": populate_template(
agent.prompt_templates["planning"]["update_plan_pre_messages"], variables=dict(task=task)
),
"UPDATE_PLAN_USER_PROMPT": populate_template(
agent.prompt_templates["planning"]["update_plan_post_messages"],
variables=dict(
task=task,
tools=agent.tools,
managed_agents=agent.managed_agents,
facts_update=agent.memory.steps[0].model_output_message_facts.content,
remaining_steps=agent.max_steps - step,
),
),
}
for expected_messages in expected_messages_list:
for expected_message in expected_messages:
for expected_content in expected_message["content"]:
expected_content["text"] = expected_message_texts[expected_content["text"]]
assert len(agent.memory.steps) == 1
planning_step = agent.memory.steps[0]
assert isinstance(planning_step, PlanningStep)
@ -744,14 +774,12 @@ class TestMultiStepAgent:
assert isinstance(message, dict)
assert "role" in message
assert "content" in message
assert isinstance(message["role"], MessageRole)
assert message["role"] in MessageRole.__members__.values()
assert message["role"] == expected_message["role"]
assert isinstance(message["content"], list)
assert len(message["content"]) == 1
for content in message["content"]:
assert isinstance(content, dict)
assert "type" in content
assert "text" in content
for content, expected_content in zip(message["content"], expected_message["content"]):
assert content == expected_content
# Test calls to model
assert len(fake_model.call_args_list) == 2
for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
@ -763,14 +791,12 @@ class TestMultiStepAgent:
assert isinstance(message, dict)
assert "role" in message
assert "content" in message
assert isinstance(message["role"], MessageRole)
assert message["role"] in MessageRole.__members__.values()
assert message["role"] == expected_message["role"]
assert isinstance(message["content"], list)
assert len(message["content"]) == 1
for content in message["content"]:
assert isinstance(content, dict)
assert "type" in content
assert "text" in content
for content, expected_content in zip(message["content"], expected_message["content"]):
assert content == expected_content
@pytest.mark.parametrize(
"images, expected_messages_list",