Fix MultiStepAgent.planning_step message content (#437)
* Fix MultiStepAgent.planning_step message content
This commit is contained in:
		
							parent
							
								
									6d0e4e49fc
								
							
						
					
					
						commit
						42d97716fe
					
				|  | @ -494,15 +494,20 @@ You have been provided with these additional arguments, that you can access usin | ||||||
|         if is_first_step: |         if is_first_step: | ||||||
|             message_prompt_facts = { |             message_prompt_facts = { | ||||||
|                 "role": MessageRole.SYSTEM, |                 "role": MessageRole.SYSTEM, | ||||||
|                 "content": SYSTEM_PROMPT_FACTS, |                 "content": [{"type": "text", "text": SYSTEM_PROMPT_FACTS}], | ||||||
|             } |             } | ||||||
|             message_prompt_task = { |             message_prompt_task = { | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": f"""Here is the task: |                 "content": [ | ||||||
|  |                     { | ||||||
|  |                         "type": "text", | ||||||
|  |                         "text": f"""Here is the task: | ||||||
| ``` | ``` | ||||||
| {task} | {task} | ||||||
| ``` | ``` | ||||||
| Now begin!""", | Now begin!""", | ||||||
|  |                     } | ||||||
|  |                 ], | ||||||
|             } |             } | ||||||
|             input_messages = [message_prompt_facts, message_prompt_task] |             input_messages = [message_prompt_facts, message_prompt_task] | ||||||
| 
 | 
 | ||||||
|  | @ -511,16 +516,21 @@ Now begin!""", | ||||||
| 
 | 
 | ||||||
|             message_system_prompt_plan = { |             message_system_prompt_plan = { | ||||||
|                 "role": MessageRole.SYSTEM, |                 "role": MessageRole.SYSTEM, | ||||||
|                 "content": SYSTEM_PROMPT_PLAN, |                 "content": [{"type": "text", "text": SYSTEM_PROMPT_PLAN}], | ||||||
|             } |             } | ||||||
|             message_user_prompt_plan = { |             message_user_prompt_plan = { | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": USER_PROMPT_PLAN.format( |                 "content": [ | ||||||
|                     task=task, |                     { | ||||||
|                     tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), |                         "type": "text", | ||||||
|                     managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), |                         "text": USER_PROMPT_PLAN.format( | ||||||
|                     answer_facts=answer_facts, |                             task=task, | ||||||
|                 ), |                             tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), | ||||||
|  |                             managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), | ||||||
|  |                             answer_facts=answer_facts, | ||||||
|  |                         ), | ||||||
|  |                     } | ||||||
|  |                 ], | ||||||
|             } |             } | ||||||
|             chat_message_plan: ChatMessage = self.model( |             chat_message_plan: ChatMessage = self.model( | ||||||
|                 [message_system_prompt_plan, message_user_prompt_plan], |                 [message_system_prompt_plan, message_user_prompt_plan], | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ import tempfile | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from unittest.mock import MagicMock | ||||||
| 
 | 
 | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
|  | @ -25,11 +26,19 @@ from smolagents.agents import ( | ||||||
|     AgentMaxStepsError, |     AgentMaxStepsError, | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     ManagedAgent, |     ManagedAgent, | ||||||
|  |     MultiStepAgent, | ||||||
|     ToolCall, |     ToolCall, | ||||||
|     ToolCallingAgent, |     ToolCallingAgent, | ||||||
| ) | ) | ||||||
| from smolagents.default_tools import PythonInterpreterTool | from smolagents.default_tools import PythonInterpreterTool | ||||||
| from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel | from smolagents.memory import PlanningStep | ||||||
|  | from smolagents.models import ( | ||||||
|  |     ChatMessage, | ||||||
|  |     ChatMessageToolCall, | ||||||
|  |     ChatMessageToolCallDefinition, | ||||||
|  |     MessageRole, | ||||||
|  |     TransformersModel, | ||||||
|  | ) | ||||||
| from smolagents.tools import tool | from smolagents.tools import tool | ||||||
| from smolagents.utils import BASE_BUILTIN_MODULES | from smolagents.utils import BASE_BUILTIN_MODULES | ||||||
| 
 | 
 | ||||||
|  | @ -644,3 +653,49 @@ nested_answer() | ||||||
|         assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather" |         assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather" | ||||||
|         assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100 |         assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100 | ||||||
|         assert "model_input_messages" in agent.memory.get_full_steps()[1] |         assert "model_input_messages" in agent.memory.get_full_steps()[1] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TestMultiStepAgent: | ||||||
|  |     def test_planning_step_first_step(self): | ||||||
|  |         fake_model = MagicMock() | ||||||
|  |         agent = MultiStepAgent( | ||||||
|  |             tools=[], | ||||||
|  |             model=fake_model, | ||||||
|  |         ) | ||||||
|  |         task = "Test task" | ||||||
|  |         agent.planning_step(task, is_first_step=True, step=0) | ||||||
|  |         assert len(agent.memory.steps) == 1 | ||||||
|  |         planning_step = agent.memory.steps[0] | ||||||
|  |         assert isinstance(planning_step, PlanningStep) | ||||||
|  |         messages = planning_step.model_input_messages | ||||||
|  |         assert isinstance(messages, list) | ||||||
|  |         assert len(messages) == 2 | ||||||
|  |         for message in messages: | ||||||
|  |             assert isinstance(message, dict) | ||||||
|  |             assert "role" in message | ||||||
|  |             assert "content" in message | ||||||
|  |             assert isinstance(message["role"], MessageRole) | ||||||
|  |             assert isinstance(message["content"], list) | ||||||
|  |             assert len(message["content"]) == 1 | ||||||
|  |             for content in message["content"]: | ||||||
|  |                 assert isinstance(content, dict) | ||||||
|  |                 assert "type" in content | ||||||
|  |                 assert "text" in content | ||||||
|  |         # Test calls to model | ||||||
|  |         assert len(fake_model.call_args_list) == 2 | ||||||
|  |         for call_args in fake_model.call_args_list: | ||||||
|  |             assert len(call_args.args) == 1 | ||||||
|  |             messages = call_args.args[0] | ||||||
|  |             assert isinstance(messages, list) | ||||||
|  |             assert len(messages) == 2 | ||||||
|  |             for message in messages: | ||||||
|  |                 assert isinstance(message, dict) | ||||||
|  |                 assert "role" in message | ||||||
|  |                 assert "content" in message | ||||||
|  |                 assert isinstance(message["role"], MessageRole) | ||||||
|  |                 assert isinstance(message["content"], list) | ||||||
|  |                 assert len(message["content"]) == 1 | ||||||
|  |                 for content in message["content"]: | ||||||
|  |                     assert isinstance(content, dict) | ||||||
|  |                     assert "type" in content | ||||||
|  |                     assert "text" in content | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue