Delete prompts_path argument and use prompt_templates (#541)
This commit is contained in:
		
							parent
							
								
									fd9eec8433
								
							
						
					
					
						commit
						ecabb9ea4f
					
				|  | @ -88,7 +88,7 @@ class MultiStepAgent: | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. |         prompt_templates (`dict`, *optional*): Prompt templates. | ||||||
|         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. |         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. | ||||||
|         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. |         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. | ||||||
|         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. |         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. | ||||||
|  | @ -107,7 +107,7 @@ class MultiStepAgent: | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompts_path: Optional[str] = None, |         prompt_templates: Optional[dict] = None, | ||||||
|         max_steps: int = 6, |         max_steps: int = 6, | ||||||
|         tool_parser: Optional[Callable] = None, |         tool_parser: Optional[Callable] = None, | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|  | @ -125,6 +125,7 @@ class MultiStepAgent: | ||||||
|             tool_parser = parse_json_tool_call |             tool_parser = parse_json_tool_call | ||||||
|         self.agent_name = self.__class__.__name__ |         self.agent_name = self.__class__.__name__ | ||||||
|         self.model = model |         self.model = model | ||||||
|  |         self.prompt_templates = prompt_templates or {} | ||||||
|         self.max_steps = max_steps |         self.max_steps = max_steps | ||||||
|         self.step_number: int = 0 |         self.step_number: int = 0 | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|  | @ -633,7 +634,7 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. |         prompt_templates (`dict`, *optional*): Prompt templates. | ||||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. |         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||||
|         **kwargs: Additional keyword arguments. |         **kwargs: Additional keyword arguments. | ||||||
|     """ |     """ | ||||||
|  | @ -642,17 +643,17 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompts_path: Optional[str] = None, |         prompt_templates: Optional[dict] = None, | ||||||
|         planning_interval: Optional[int] = None, |         planning_interval: Optional[int] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         self.prompt_templates = yaml.safe_load( |         prompt_templates = prompt_templates or yaml.safe_load( | ||||||
|             importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml") |             importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml") | ||||||
|         ) |         ) | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             tools=tools, |             tools=tools, | ||||||
|             model=model, |             model=model, | ||||||
|             prompts_path=prompts_path, |             prompt_templates=prompt_templates, | ||||||
|             planning_interval=planning_interval, |             planning_interval=planning_interval, | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  | @ -755,7 +756,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary. |         prompt_templates (`dict`, *optional*): Prompt templates. | ||||||
|         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. |         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. | ||||||
|         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. |         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. | ||||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. |         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||||
|  | @ -769,7 +770,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompts_path: Optional[str] = None, |         prompt_templates: Optional[dict] = None, | ||||||
|         grammar: Optional[Dict[str, str]] = None, |         grammar: Optional[Dict[str, str]] = None, | ||||||
|         additional_authorized_imports: Optional[List[str]] = None, |         additional_authorized_imports: Optional[List[str]] = None, | ||||||
|         planning_interval: Optional[int] = None, |         planning_interval: Optional[int] = None, | ||||||
|  | @ -779,10 +780,13 @@ class CodeAgent(MultiStepAgent): | ||||||
|     ): |     ): | ||||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] |         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] | ||||||
|         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) |         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) | ||||||
|         self.prompt_templates = yaml.safe_load(importlib.resources.read_text("smolagents.prompts", "code_agent.yaml")) |         prompt_templates = prompt_templates or yaml.safe_load( | ||||||
|  |             importlib.resources.read_text("smolagents.prompts", "code_agent.yaml") | ||||||
|  |         ) | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             tools=tools, |             tools=tools, | ||||||
|             model=model, |             model=model, | ||||||
|  |             prompt_templates=prompt_templates, | ||||||
|             grammar=grammar, |             grammar=grammar, | ||||||
|             planning_interval=planning_interval, |             planning_interval=planning_interval, | ||||||
|             **kwargs, |             **kwargs, | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ import uuid | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from unittest.mock import MagicMock | from unittest.mock import MagicMock | ||||||
| 
 | 
 | ||||||
|  | import pytest | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
| from smolagents.agent_types import AgentImage, AgentText | from smolagents.agent_types import AgentImage, AgentText | ||||||
|  | @ -664,11 +665,19 @@ nested_answer() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestMultiStepAgent: | class TestMultiStepAgent: | ||||||
|     def test_logging_to_terminal_is_disabled(self): |     def test_instantiation_disables_logging_to_terminal(self): | ||||||
|         fake_model = MagicMock() |         fake_model = MagicMock() | ||||||
|         agent = MultiStepAgent(tools=[], model=fake_model) |         agent = MultiStepAgent(tools=[], model=fake_model) | ||||||
|         assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture" |         assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture" | ||||||
| 
 | 
 | ||||||
|  |     def test_instantiation_with_prompt_templates(self, prompt_templates): | ||||||
|  |         agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates) | ||||||
|  |         assert agent.prompt_templates == prompt_templates | ||||||
|  |         assert agent.prompt_templates["system_prompt"] == "This is a test system prompt." | ||||||
|  |         assert "managed_agent" in agent.prompt_templates | ||||||
|  |         assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}" | ||||||
|  |         assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}" | ||||||
|  | 
 | ||||||
|     def test_step_number(self): |     def test_step_number(self): | ||||||
|         fake_model = MagicMock() |         fake_model = MagicMock() | ||||||
|         fake_model.last_input_token_count = 10 |         fake_model.last_input_token_count = 10 | ||||||
|  | @ -724,3 +733,11 @@ class TestMultiStepAgent: | ||||||
|                     assert isinstance(content, dict) |                     assert isinstance(content, dict) | ||||||
|                     assert "type" in content |                     assert "type" in content | ||||||
|                     assert "text" in content |                     assert "text" in content | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def prompt_templates(): | ||||||
|  |     return { | ||||||
|  |         "system_prompt": "This is a test system prompt.", | ||||||
|  |         "managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"}, | ||||||
|  |     } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue