From ecabb9ea4f611781c1ee75cf12c2b055dfff6f7a Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Fri, 7 Feb 2025 15:29:03 +0100
Subject: [PATCH] Delete prompts_path argument and use prompt_templates (#541)

---
 src/smolagents/agents.py | 22 +++++++++++++---------
 tests/test_agents.py     | 19 ++++++++++++++++++-
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 9366950..cbd1a39 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -88,7 +88,7 @@ class MultiStepAgent:
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
+        prompt_templates (`dict`, *optional*): Prompt templates.
         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
@@ -107,7 +107,7 @@ class MultiStepAgent:
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompts_path: Optional[str] = None,
+        prompt_templates: Optional[dict] = None,
         max_steps: int = 6,
         tool_parser: Optional[Callable] = None,
         add_base_tools: bool = False,
@@ -125,6 +125,7 @@ class MultiStepAgent:
             tool_parser = parse_json_tool_call
         self.agent_name = self.__class__.__name__
         self.model = model
+        self.prompt_templates = prompt_templates or {}
         self.max_steps = max_steps
         self.step_number: int = 0
         self.tool_parser = tool_parser
@@ -633,7 +634,7 @@ class ToolCallingAgent(MultiStepAgent):
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
+        prompt_templates (`dict`, *optional*): Prompt templates.
         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
         **kwargs: Additional keyword arguments.
     """
@@ -642,17 +643,17 @@ class ToolCallingAgent(MultiStepAgent):
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompts_path: Optional[str] = None,
+        prompt_templates: Optional[dict] = None,
         planning_interval: Optional[int] = None,
         **kwargs,
     ):
-        self.prompt_templates = yaml.safe_load(
+        prompt_templates = prompt_templates or yaml.safe_load(
             importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml")
         )
         super().__init__(
             tools=tools,
             model=model,
-            prompts_path=prompts_path,
+            prompt_templates=prompt_templates,
             planning_interval=planning_interval,
             **kwargs,
         )
@@ -755,7 +756,7 @@ class CodeAgent(MultiStepAgent):
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
+        prompt_templates (`dict`, *optional*): Prompt templates.
         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
@@ -769,7 +770,7 @@ class CodeAgent(MultiStepAgent):
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompts_path: Optional[str] = None,
+        prompt_templates: Optional[dict] = None,
         grammar: Optional[Dict[str, str]] = None,
         additional_authorized_imports: Optional[List[str]] = None,
         planning_interval: Optional[int] = None,
@@ -779,10 +780,13 @@ class CodeAgent(MultiStepAgent):
     ):
         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
         self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
-        self.prompt_templates = yaml.safe_load(importlib.resources.read_text("smolagents.prompts", "code_agent.yaml"))
+        prompt_templates = prompt_templates or yaml.safe_load(
+            importlib.resources.read_text("smolagents.prompts", "code_agent.yaml")
+        )
         super().__init__(
             tools=tools,
             model=model,
+            prompt_templates=prompt_templates,
             grammar=grammar,
             planning_interval=planning_interval,
             **kwargs,
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 3a06747..2b9adf2 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -19,6 +19,7 @@ import uuid
 from pathlib import Path
 from unittest.mock import MagicMock
 
+import pytest
 from transformers.testing_utils import get_tests_dir
 
 from smolagents.agent_types import AgentImage, AgentText
@@ -664,11 +665,19 @@ nested_answer()
 
 
 class TestMultiStepAgent:
-    def test_logging_to_terminal_is_disabled(self):
+    def test_instantiation_disables_logging_to_terminal(self):
         fake_model = MagicMock()
         agent = MultiStepAgent(tools=[], model=fake_model)
         assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
 
+    def test_instantiation_with_prompt_templates(self, prompt_templates):
+        agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
+        assert agent.prompt_templates == prompt_templates
+        assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."
+        assert "managed_agent" in agent.prompt_templates
+        assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}"
+        assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}"
+
     def test_step_number(self):
         fake_model = MagicMock()
         fake_model.last_input_token_count = 10
@@ -724,3 +733,11 @@ class TestMultiStepAgent:
         assert isinstance(content, dict)
         assert "type" in content
         assert "text" in content
+
+
+@pytest.fixture
+def prompt_templates():
+    return {
+        "system_prompt": "This is a test system prompt.",
+        "managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"},
+    }
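
Note on usage (not part of the patch): after this change, prompt templates are passed directly as a dict instead of being loaded from a `prompts_path`. The minimal sketch below mirrors the `prompt_templates` test fixture added above; it assumes the agent classes are importable from `smolagents.agents` (the module modified here) and, like the tests, uses `unittest.mock.MagicMock` in place of a real model.

from unittest.mock import MagicMock

from smolagents.agents import CodeAgent, MultiStepAgent

# Custom templates are stored on the agent as-is (the `prompt_templates or {}` line above).
prompt_templates = {
    "system_prompt": "This is a test system prompt.",
    "managed_agent": {
        "task": "Task for {{name}}: {{task}}",
        "report": "Report for {{name}}: {{final_answer}}",
    },
}
agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."

# Omitting the argument keeps the previous behaviour: CodeAgent falls back to the packaged
# smolagents/prompts/code_agent.yaml defaults via `prompt_templates or yaml.safe_load(...)`.
default_agent = CodeAgent(tools=[], model=MagicMock())
assert "system_prompt" in default_agent.prompt_templates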