Delete prompts_path argument and use prompt_templates (#541)
This commit is contained in:
		
							parent
							
								
									fd9eec8433
								
							
						
					
					
						commit
						ecabb9ea4f
					
				|  | @ -88,7 +88,7 @@ class MultiStepAgent: | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. | ||||
|         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. | ||||
|         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. | ||||
|  | @ -107,7 +107,7 @@ class MultiStepAgent: | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompts_path: Optional[str] = None, | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         max_steps: int = 6, | ||||
|         tool_parser: Optional[Callable] = None, | ||||
|         add_base_tools: bool = False, | ||||
|  | @ -125,6 +125,7 @@ class MultiStepAgent: | |||
|             tool_parser = parse_json_tool_call | ||||
|         self.agent_name = self.__class__.__name__ | ||||
|         self.model = model | ||||
|         self.prompt_templates = prompt_templates or {} | ||||
|         self.max_steps = max_steps | ||||
|         self.step_number: int = 0 | ||||
|         self.tool_parser = tool_parser | ||||
|  | @ -633,7 +634,7 @@ class ToolCallingAgent(MultiStepAgent): | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||
|         **kwargs: Additional keyword arguments. | ||||
|     """ | ||||
|  | @ -642,17 +643,17 @@ class ToolCallingAgent(MultiStepAgent): | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompts_path: Optional[str] = None, | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         planning_interval: Optional[int] = None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         self.prompt_templates = yaml.safe_load( | ||||
|         prompt_templates = prompt_templates or yaml.safe_load( | ||||
|             importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml") | ||||
|         ) | ||||
|         super().__init__( | ||||
|             tools=tools, | ||||
|             model=model, | ||||
|             prompts_path=prompts_path, | ||||
|             prompt_templates=prompt_templates, | ||||
|             planning_interval=planning_interval, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | @ -755,7 +756,7 @@ class CodeAgent(MultiStepAgent): | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. | ||||
|         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. | ||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||
|  | @ -769,7 +770,7 @@ class CodeAgent(MultiStepAgent): | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompts_path: Optional[str] = None, | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         grammar: Optional[Dict[str, str]] = None, | ||||
|         additional_authorized_imports: Optional[List[str]] = None, | ||||
|         planning_interval: Optional[int] = None, | ||||
|  | @ -779,10 +780,13 @@ class CodeAgent(MultiStepAgent): | |||
|     ): | ||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] | ||||
|         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) | ||||
|         self.prompt_templates = yaml.safe_load(importlib.resources.read_text("smolagents.prompts", "code_agent.yaml")) | ||||
|         prompt_templates = prompt_templates or yaml.safe_load( | ||||
|             importlib.resources.read_text("smolagents.prompts", "code_agent.yaml") | ||||
|         ) | ||||
|         super().__init__( | ||||
|             tools=tools, | ||||
|             model=model, | ||||
|             prompt_templates=prompt_templates, | ||||
|             grammar=grammar, | ||||
|             planning_interval=planning_interval, | ||||
|             **kwargs, | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ import uuid | |||
| from pathlib import Path | ||||
| from unittest.mock import MagicMock | ||||
| 
 | ||||
| import pytest | ||||
| from transformers.testing_utils import get_tests_dir | ||||
| 
 | ||||
| from smolagents.agent_types import AgentImage, AgentText | ||||
|  | @ -664,11 +665,19 @@ nested_answer() | |||
| 
 | ||||
| 
 | ||||
| class TestMultiStepAgent: | ||||
|     def test_logging_to_terminal_is_disabled(self): | ||||
|     def test_instantiation_disables_logging_to_terminal(self): | ||||
|         fake_model = MagicMock() | ||||
|         agent = MultiStepAgent(tools=[], model=fake_model) | ||||
|         assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture" | ||||
| 
 | ||||
|     def test_instantiation_with_prompt_templates(self, prompt_templates): | ||||
|         agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates) | ||||
|         assert agent.prompt_templates == prompt_templates | ||||
|         assert agent.prompt_templates["system_prompt"] == "This is a test system prompt." | ||||
|         assert "managed_agent" in agent.prompt_templates | ||||
|         assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}" | ||||
|         assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}" | ||||
| 
 | ||||
|     def test_step_number(self): | ||||
|         fake_model = MagicMock() | ||||
|         fake_model.last_input_token_count = 10 | ||||
|  | @ -724,3 +733,11 @@ class TestMultiStepAgent: | |||
|                     assert isinstance(content, dict) | ||||
|                     assert "type" in content | ||||
|                     assert "text" in content | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def prompt_templates(): | ||||
|     return { | ||||
|         "system_prompt": "This is a test system prompt.", | ||||
|         "managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"}, | ||||
|     } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue