diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 5e5fe41..73b2647 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -529,9 +529,9 @@ You have been provided with these additional arguments, that you can access usin level=LogLevel.INFO, ) else: # update plan - memory_messages = self.write_memory_to_messages( - summary_mode=False - ) # This will not log the plan but will log facts + # Do not take the system prompt message from the memory + # summary_mode=False: Do not take previous plan steps to avoid influencing the new plan + memory_messages = self.write_memory_to_messages()[1:] # Redact updated facts facts_update_pre_messages = { diff --git a/tests/test_agents.py b/tests/test_agents.py index 631ccf8..b097132 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -690,27 +690,61 @@ class TestMultiStepAgent: assert hasattr(agent, "step_number"), "step_number attribute should be defined" assert agent.step_number == max_steps + 1, "step_number should be max_steps + 1 after run method is called" - def test_planning_step_first_step(self): + @pytest.mark.parametrize( + "step, expected_messages_list", + [ + ( + 1, + [ + [ + {"role": MessageRole.SYSTEM, "content": [{"type": "text", "text": "FACTS_SYSTEM_PROMPT"}]}, + {"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_USER_PROMPT"}]}, + ], + [{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_USER_PROMPT"}]}], + ], + ), + ( + 2, + [ + [ + { + "role": MessageRole.SYSTEM, + "content": [{"type": "text", "text": "FACTS_UPDATE_SYSTEM_PROMPT"}], + }, + {"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_UPDATE_USER_PROMPT"}]}, + ], + [ + { + "role": MessageRole.SYSTEM, + "content": [{"type": "text", "text": "PLAN_UPDATE_SYSTEM_PROMPT"}], + }, + {"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_UPDATE_USER_PROMPT"}]}, + ], + ], + ), + ], + ) + def test_planning_step_first_step(self, step, expected_messages_list): fake_model = MagicMock() agent = CodeAgent( tools=[], model=fake_model, ) task = "Test task" - agent.planning_step(task, is_first_step=True, step=0) + agent.planning_step(task, is_first_step=(step == 1), step=step) assert len(agent.memory.steps) == 1 planning_step = agent.memory.steps[0] assert isinstance(planning_step, PlanningStep) - messages = planning_step.model_input_messages - assert isinstance(messages, list) - assert len(messages) == 2 - expected_roles = [MessageRole.SYSTEM, MessageRole.USER] - for i, message in enumerate(messages): + expected_model_input_messages = expected_messages_list[0] + model_input_messages = planning_step.model_input_messages + assert isinstance(model_input_messages, list) + assert len(model_input_messages) == len(expected_model_input_messages) # 2 + for message, expected_message in zip(model_input_messages, expected_model_input_messages): assert isinstance(message, dict) assert "role" in message assert "content" in message assert isinstance(message["role"], MessageRole) - assert message["role"] == expected_roles[i] + assert message["role"] == expected_message["role"] assert isinstance(message["content"], list) assert len(message["content"]) == 1 for content in message["content"]: @@ -719,16 +753,17 @@ class TestMultiStepAgent: assert "text" in content # Test calls to model assert len(fake_model.call_args_list) == 2 - for call_args in fake_model.call_args_list: + for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list): assert len(call_args.args) == 1 messages = call_args.args[0] assert isinstance(messages, list) - # assert len(messages) == 1 # TODO - for message in messages: + assert len(messages) == len(expected_messages) + for message, expected_message in zip(messages, expected_messages): assert isinstance(message, dict) assert "role" in message assert "content" in message assert isinstance(message["role"], MessageRole) + assert message["role"] == expected_message["role"] assert isinstance(message["content"], list) assert len(message["content"]) == 1 for content in message["content"]: