Delete prompts_path argument and use prompt_templates (#541)
This commit is contained in:
parent
fd9eec8433
commit
ecabb9ea4f
|
@ -88,7 +88,7 @@ class MultiStepAgent:
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
|
prompt_templates (`dict`, *optional*): Prompt templates. Defaults to an empty dict.
|
||||||
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
|
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
|
||||||
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
|
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
|
||||||
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
|
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
|
||||||
|
@ -107,7 +107,7 @@ class MultiStepAgent:
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompts_path: Optional[str] = None,
|
prompt_templates: Optional[dict] = None,
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
tool_parser: Optional[Callable] = None,
|
tool_parser: Optional[Callable] = None,
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
|
@ -125,6 +125,7 @@ class MultiStepAgent:
|
||||||
tool_parser = parse_json_tool_call
|
tool_parser = parse_json_tool_call
|
||||||
self.agent_name = self.__class__.__name__
|
self.agent_name = self.__class__.__name__
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.prompt_templates = prompt_templates or {}
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.step_number: int = 0
|
self.step_number: int = 0
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
|
@ -633,7 +634,7 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
|
prompt_templates (`dict`, *optional*): Prompt templates. Defaults to the templates loaded from `toolcalling_agent.yaml` in `smolagents.prompts`.
|
||||||
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
||||||
**kwargs: Additional keyword arguments.
|
**kwargs: Additional keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
@ -642,17 +643,17 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompts_path: Optional[str] = None,
|
prompt_templates: Optional[dict] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.prompt_templates = yaml.safe_load(
|
prompt_templates = prompt_templates or yaml.safe_load(
|
||||||
importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml")
|
importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml")
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=model,
|
model=model,
|
||||||
prompts_path=prompts_path,
|
prompt_templates=prompt_templates,
|
||||||
planning_interval=planning_interval,
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -755,7 +756,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
Args:
|
Args:
|
||||||
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
tools (`list[Tool]`): [`Tool`]s that the agent can use.
|
||||||
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
|
||||||
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
|
prompt_templates (`dict`, *optional*): Prompt templates. Defaults to the templates loaded from `code_agent.yaml` in `smolagents.prompts`.
|
||||||
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
|
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
|
||||||
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
|
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
|
||||||
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
|
||||||
|
@ -769,7 +770,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
self,
|
self,
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
model: Callable[[List[Dict[str, str]]], ChatMessage],
|
||||||
prompts_path: Optional[str] = None,
|
prompt_templates: Optional[dict] = None,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
|
@ -779,10 +780,13 @@ class CodeAgent(MultiStepAgent):
|
||||||
):
|
):
|
||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||||
self.prompt_templates = yaml.safe_load(importlib.resources.read_text("smolagents.prompts", "code_agent.yaml"))
|
prompt_templates = prompt_templates or yaml.safe_load(
|
||||||
|
importlib.resources.read_text("smolagents.prompts", "code_agent.yaml")
|
||||||
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=model,
|
model=model,
|
||||||
|
prompt_templates=prompt_templates,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
planning_interval=planning_interval,
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -19,6 +19,7 @@ import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents.agent_types import AgentImage, AgentText
|
from smolagents.agent_types import AgentImage, AgentText
|
||||||
|
@ -664,11 +665,19 @@ nested_answer()
|
||||||
|
|
||||||
|
|
||||||
class TestMultiStepAgent:
|
class TestMultiStepAgent:
|
||||||
def test_logging_to_terminal_is_disabled(self):
|
def test_instantiation_disables_logging_to_terminal(self):
|
||||||
fake_model = MagicMock()
|
fake_model = MagicMock()
|
||||||
agent = MultiStepAgent(tools=[], model=fake_model)
|
agent = MultiStepAgent(tools=[], model=fake_model)
|
||||||
assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
|
assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
|
||||||
|
|
||||||
|
def test_instantiation_with_prompt_templates(self, prompt_templates):
|
||||||
|
agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
|
||||||
|
assert agent.prompt_templates == prompt_templates
|
||||||
|
assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."
|
||||||
|
assert "managed_agent" in agent.prompt_templates
|
||||||
|
assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}"
|
||||||
|
assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}"
|
||||||
|
|
||||||
def test_step_number(self):
|
def test_step_number(self):
|
||||||
fake_model = MagicMock()
|
fake_model = MagicMock()
|
||||||
fake_model.last_input_token_count = 10
|
fake_model.last_input_token_count = 10
|
||||||
|
@ -724,3 +733,11 @@ class TestMultiStepAgent:
|
||||||
assert isinstance(content, dict)
|
assert isinstance(content, dict)
|
||||||
assert "type" in content
|
assert "type" in content
|
||||||
assert "text" in content
|
assert "text" in content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt_templates():
|
||||||
|
return {
|
||||||
|
"system_prompt": "This is a test system prompt.",
|
||||||
|
"managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"},
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue