Merge MultiStepAgent and BaseAgent

This commit is contained in:
Aymeric 2024-12-24 12:01:32 +01:00
parent a3cd9158a7
commit 77428c8e9c
2 changed files with 8 additions and 36 deletions

View File

@ -7,7 +7,7 @@ name = "smolagents"
version = "0.1.0" version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.10"
dependencies = [ dependencies = [
"transformers>=4.0.0", "transformers>=4.0.0",
"pytest>=8.1.0", "pytest>=8.1.0",

View File

@ -172,14 +172,17 @@ def format_prompt_with_managed_agents_descriptions(
return prompt_template.replace(agent_descriptions_placeholder, "") return prompt_template.replace(agent_descriptions_placeholder, "")
class BaseAgent: class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of action (given by the LLM) and observation (obtained from the environment).
"""
def __init__( def __init__(
self, self,
tools: Union[List[Tool], Toolbox], tools: Union[List[Tool], Toolbox],
llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None, llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None, tool_description_template: Optional[str] = None,
additional_args: Dict = {},
max_iterations: int = 6, max_iterations: int = 6,
tool_parser: Optional[Callable] = None, tool_parser: Optional[Callable] = None,
add_base_tools: bool = False, add_base_tools: bool = False,
@ -188,6 +191,7 @@ class BaseAgent:
managed_agents: Optional[Dict] = None, managed_agents: Optional[Dict] = None,
step_callbacks: Optional[List[Callable]] = None, step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True, monitor_metrics: bool = True,
planning_interval: Optional[int] = None,
): ):
if llm_engine is None: if llm_engine is None:
llm_engine = HfApiEngine() llm_engine = HfApiEngine()
@ -203,10 +207,10 @@ class BaseAgent:
if tool_description_template if tool_description_template
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) )
self.additional_args = additional_args
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tool_parser = tool_parser self.tool_parser = tool_parser
self.grammar = grammar self.grammar = grammar
self.planning_interval = planning_interval
self.managed_agents = {} self.managed_agents = {}
if managed_agents is not None: if managed_agents is not None:
@ -374,38 +378,6 @@ class BaseAgent:
) )
return rationale.strip(), action.strip() return rationale.strip(), action.strip()
def run(self, **kwargs):
"""To be implemented in the child class"""
raise NotImplementedError
class MultiStepAgent(BaseAgent):
"""
This agent that solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of thinking and acting.
The action will be parsed from the LLM output: it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine.
"""
def __init__(
self,
tools: List[Tool],
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
super().__init__(
tools=tools,
llm_engine=llm_engine,
system_prompt=system_prompt,
grammar=grammar,
**kwargs,
)
self.planning_interval = planning_interval
def provide_final_answer(self, task) -> str: def provide_final_answer(self, task) -> str:
""" """