Merge MultiStepAgent and BaseAgent
This commit is contained in:
		
							parent
							
								
									a3cd9158a7
								
							
						
					
					
						commit
						77428c8e9c
					
				|  | @ -7,7 +7,7 @@ name = "smolagents" | ||||||
| version = "0.1.0" | version = "0.1.0" | ||||||
| description = "Add your description here" | description = "Add your description here" | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
| requires-python = ">=3.9" | requires-python = ">=3.10" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|     "transformers>=4.0.0", |     "transformers>=4.0.0", | ||||||
|     "pytest>=8.1.0", |     "pytest>=8.1.0", | ||||||
|  |  | ||||||
|  | @ -172,14 +172,17 @@ def format_prompt_with_managed_agents_descriptions( | ||||||
|         return prompt_template.replace(agent_descriptions_placeholder, "") |         return prompt_template.replace(agent_descriptions_placeholder, "") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class BaseAgent: | class MultiStepAgent: | ||||||
|  |     """ | ||||||
|  |     Agent class that solves the given task step by step, using the ReAct framework: | ||||||
|  |     While the objective is not reached, the agent will perform a cycle of action (given by the LLM) and observation (obtained from the environment). | ||||||
|  |     """ | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         tools: Union[List[Tool], Toolbox], |         tools: Union[List[Tool], Toolbox], | ||||||
|         llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None, |         llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None, | ||||||
|         system_prompt: Optional[str] = None, |         system_prompt: Optional[str] = None, | ||||||
|         tool_description_template: Optional[str] = None, |         tool_description_template: Optional[str] = None, | ||||||
|         additional_args: Dict = {}, |  | ||||||
|         max_iterations: int = 6, |         max_iterations: int = 6, | ||||||
|         tool_parser: Optional[Callable] = None, |         tool_parser: Optional[Callable] = None, | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|  | @ -188,6 +191,7 @@ class BaseAgent: | ||||||
|         managed_agents: Optional[Dict] = None, |         managed_agents: Optional[Dict] = None, | ||||||
|         step_callbacks: Optional[List[Callable]] = None, |         step_callbacks: Optional[List[Callable]] = None, | ||||||
|         monitor_metrics: bool = True, |         monitor_metrics: bool = True, | ||||||
|  |         planning_interval: Optional[int] = None, | ||||||
|     ): |     ): | ||||||
|         if llm_engine is None: |         if llm_engine is None: | ||||||
|             llm_engine = HfApiEngine() |             llm_engine = HfApiEngine() | ||||||
|  | @ -203,10 +207,10 @@ class BaseAgent: | ||||||
|             if tool_description_template |             if tool_description_template | ||||||
|             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE |             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE | ||||||
|         ) |         ) | ||||||
|         self.additional_args = additional_args |  | ||||||
|         self.max_iterations = max_iterations |         self.max_iterations = max_iterations | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|         self.grammar = grammar |         self.grammar = grammar | ||||||
|  |         self.planning_interval = planning_interval | ||||||
| 
 | 
 | ||||||
|         self.managed_agents = {} |         self.managed_agents = {} | ||||||
|         if managed_agents is not None: |         if managed_agents is not None: | ||||||
|  | @ -374,38 +378,6 @@ class BaseAgent: | ||||||
|             ) |             ) | ||||||
|         return rationale.strip(), action.strip() |         return rationale.strip(), action.strip() | ||||||
| 
 | 
 | ||||||
|     def run(self, **kwargs): |  | ||||||
|         """To be implemented in the child class""" |  | ||||||
|         raise NotImplementedError |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MultiStepAgent(BaseAgent): |  | ||||||
|     """ |  | ||||||
|     This agent that solves the given task step by step, using the ReAct framework: |  | ||||||
|     While the objective is not reached, the agent will perform a cycle of thinking and acting. |  | ||||||
|     The action will be parsed from the LLM output: it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine. |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         tools: List[Tool], |  | ||||||
|         llm_engine: Optional[Callable] = None, |  | ||||||
|         system_prompt: Optional[str] = None, |  | ||||||
|         grammar: Optional[Dict[str, str]] = None, |  | ||||||
|         planning_interval: Optional[int] = None, |  | ||||||
|         **kwargs, |  | ||||||
|     ): |  | ||||||
|         if system_prompt is None: |  | ||||||
|             system_prompt = CODE_SYSTEM_PROMPT |  | ||||||
| 
 |  | ||||||
|         super().__init__( |  | ||||||
|             tools=tools, |  | ||||||
|             llm_engine=llm_engine, |  | ||||||
|             system_prompt=system_prompt, |  | ||||||
|             grammar=grammar, |  | ||||||
|             **kwargs, |  | ||||||
|         ) |  | ||||||
|         self.planning_interval = planning_interval |  | ||||||
| 
 | 
 | ||||||
|     def provide_final_answer(self, task) -> str: |     def provide_final_answer(self, task) -> str: | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue