Create PromptTemplates typed dict (#547)
This commit is contained in:
		
							parent
							
								
									a17f915f61
								
							
						
					
					
						commit
						02b2b7ebb9
					
				|  | @ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n | ||||||
| > You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case. | > You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case. | ||||||
| 
 | 
 | ||||||
| [[autodoc]] GradioUI | [[autodoc]] GradioUI | ||||||
|  | 
 | ||||||
|  | ## Prompts | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PromptTemplates | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  |  | ||||||
|  | @ -155,3 +155,11 @@ model = OpenAIServerModel( | ||||||
|     api_key=os.environ["OPENAI_API_KEY"], |     api_key=os.environ["OPENAI_API_KEY"], | ||||||
| ) | ) | ||||||
| ``` | ``` | ||||||
|  | 
 | ||||||
|  | ## Prompts | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PromptTemplates | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  |  | ||||||
|  | @ -147,3 +147,11 @@ print(model(messages)) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| [[autodoc]] LiteLLMModel | [[autodoc]] LiteLLMModel | ||||||
|  | 
 | ||||||
|  | ## Prompts | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PromptTemplates | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||||
|  | 
 | ||||||
|  | [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||||
|  |  | ||||||
|  | @ -14,6 +14,9 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
|  | 
 | ||||||
|  | __all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"] | ||||||
|  | 
 | ||||||
| import importlib.resources | import importlib.resources | ||||||
| import inspect | import inspect | ||||||
| import re | import re | ||||||
|  | @ -21,7 +24,7 @@ import textwrap | ||||||
| import time | import time | ||||||
| from collections import deque | from collections import deque | ||||||
| from logging import getLogger | from logging import getLogger | ||||||
| from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union | from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union | ||||||
| 
 | 
 | ||||||
| import yaml | import yaml | ||||||
| from jinja2 import StrictUndefined, Template | from jinja2 import StrictUndefined, Template | ||||||
|  | @ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str: | ||||||
|         raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}") |         raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class PlanningPromptTemplate(TypedDict): | ||||||
|  |     """ | ||||||
|  |     Prompt templates for the planning step. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         initial_facts (`str`): Initial facts prompt. | ||||||
|  |         initial_plan (`str`): Initial plan prompt. | ||||||
|  |         update_facts_pre_messages (`str`): Update facts pre-messages prompt. | ||||||
|  |         update_facts_post_messages (`str`): Update facts post-messages prompt. | ||||||
|  |         update_plan_pre_messages (`str`): Update plan pre-messages prompt. | ||||||
|  |         update_plan_post_messages (`str`): Update plan post-messages prompt. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     initial_facts: str | ||||||
|  |     initial_plan: str | ||||||
|  |     update_facts_pre_messages: str | ||||||
|  |     update_facts_post_messages: str | ||||||
|  |     update_plan_pre_messages: str | ||||||
|  |     update_plan_post_messages: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ManagedAgentPromptTemplate(TypedDict): | ||||||
|  |     """ | ||||||
|  |     Prompt templates for the managed agent. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         task (`str`): Task prompt. | ||||||
|  |         report (`str`): Report prompt. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     task: str | ||||||
|  |     report: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PromptTemplates(TypedDict): | ||||||
|  |     """ | ||||||
|  |     Prompt templates for the agent. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         system_prompt (`str`): System prompt. | ||||||
|  |         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template. | ||||||
|  |         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     system_prompt: str | ||||||
|  |     planning: PlanningPromptTemplate | ||||||
|  |     managed_agent: ManagedAgentPromptTemplate | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | EMPTY_PROMPT_TEMPLATES = PromptTemplates( | ||||||
|  |     system_prompt="", | ||||||
|  |     planning=PlanningPromptTemplate( | ||||||
|  |         initial_facts="", | ||||||
|  |         initial_plan="", | ||||||
|  |         update_facts_pre_messages="", | ||||||
|  |         update_facts_post_messages="", | ||||||
|  |         update_plan_pre_messages="", | ||||||
|  |         update_plan_post_messages="", | ||||||
|  |     ), | ||||||
|  |     managed_agent=ManagedAgentPromptTemplate(task="", report=""), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class MultiStepAgent: | class MultiStepAgent: | ||||||
|     """ |     """ | ||||||
|     Agent class that solves the given task step by step, using the ReAct framework: |     Agent class that solves the given task step by step, using the ReAct framework: | ||||||
|  | @ -88,7 +154,7 @@ class MultiStepAgent: | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompt_templates (`dict`, *optional*): Prompt templates. |         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||||
|         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. |         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. | ||||||
|         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. |         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. | ||||||
|         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. |         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. | ||||||
|  | @ -107,7 +173,7 @@ class MultiStepAgent: | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompt_templates: Optional[dict] = None, |         prompt_templates: Optional[PromptTemplates] = None, | ||||||
|         max_steps: int = 6, |         max_steps: int = 6, | ||||||
|         tool_parser: Optional[Callable] = None, |         tool_parser: Optional[Callable] = None, | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|  | @ -125,7 +191,7 @@ class MultiStepAgent: | ||||||
|             tool_parser = parse_json_tool_call |             tool_parser = parse_json_tool_call | ||||||
|         self.agent_name = self.__class__.__name__ |         self.agent_name = self.__class__.__name__ | ||||||
|         self.model = model |         self.model = model | ||||||
|         self.prompt_templates = prompt_templates or {} |         self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES | ||||||
|         self.max_steps = max_steps |         self.max_steps = max_steps | ||||||
|         self.step_number: int = 0 |         self.step_number: int = 0 | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|  | @ -653,7 +719,7 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompt_templates (`dict`, *optional*): Prompt templates. |         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. |         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||||
|         **kwargs: Additional keyword arguments. |         **kwargs: Additional keyword arguments. | ||||||
|     """ |     """ | ||||||
|  | @ -662,7 +728,7 @@ class ToolCallingAgent(MultiStepAgent): | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompt_templates: Optional[dict] = None, |         prompt_templates: Optional[PromptTemplates] = None, | ||||||
|         planning_interval: Optional[int] = None, |         planning_interval: Optional[int] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|  | @ -775,7 +841,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|     Args: |     Args: | ||||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. |         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. |         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||||
|         prompt_templates (`dict`, *optional*): Prompt templates. |         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||||
|         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. |         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. | ||||||
|         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. |         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. | ||||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. |         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||||
|  | @ -789,7 +855,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|         self, |         self, | ||||||
|         tools: List[Tool], |         tools: List[Tool], | ||||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], |         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||||
|         prompt_templates: Optional[dict] = None, |         prompt_templates: Optional[PromptTemplates] = None, | ||||||
|         grammar: Optional[Dict[str, str]] = None, |         grammar: Optional[Dict[str, str]] = None, | ||||||
|         additional_authorized_imports: Optional[List[str]] = None, |         additional_authorized_imports: Optional[List[str]] = None, | ||||||
|         planning_interval: Optional[int] = None, |         planning_interval: Optional[int] = None, | ||||||
|  | @ -941,6 +1007,3 @@ class CodeAgent(MultiStepAgent): | ||||||
|         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) |         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) | ||||||
|         memory_step.action_output = output |         memory_step.action_output = output | ||||||
|         return output if is_final_answer else None |         return output if is_final_answer else None | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| __all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"] |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue