Create PromptTemplates typed dict (#547)
This commit is contained in:
		
							parent
							
								
									a17f915f61
								
							
						
					
					
						commit
						02b2b7ebb9
					
				|  | @ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n | |||
| > You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case. | ||||
| 
 | ||||
| [[autodoc]] GradioUI | ||||
| 
 | ||||
| ## Prompts | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PromptTemplates | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
|  |  | |||
|  | @ -154,4 +154,12 @@ model = OpenAIServerModel( | |||
|     api_base="https://api.openai.com/v1", | ||||
|     api_key=os.environ["OPENAI_API_KEY"], | ||||
| ) | ||||
| ``` | ||||
| ``` | ||||
| 
 | ||||
| ## Prompts | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PromptTemplates | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
|  |  | |||
|  | @ -146,4 +146,12 @@ model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_ | |||
| print(model(messages)) | ||||
| ``` | ||||
| 
 | ||||
| [[autodoc]] LiteLLMModel | ||||
| [[autodoc]] LiteLLMModel | ||||
| 
 | ||||
| ## Prompts | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PromptTemplates | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.PlanningPromptTemplate | ||||
| 
 | ||||
| [[autodoc]] smolagents.agents.ManagedAgentPromptTemplate | ||||
|  |  | |||
|  | @ -14,6 +14,9 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| __all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"] | ||||
| 
 | ||||
| import importlib.resources | ||||
| import inspect | ||||
| import re | ||||
|  | @ -21,7 +24,7 @@ import textwrap | |||
| import time | ||||
| from collections import deque | ||||
| from logging import getLogger | ||||
| from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union | ||||
| from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union | ||||
| 
 | ||||
| import yaml | ||||
| from jinja2 import StrictUndefined, Template | ||||
|  | @ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str: | |||
|         raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}") | ||||
| 
 | ||||
| 
 | ||||
| class PlanningPromptTemplate(TypedDict): | ||||
|     """ | ||||
|     Prompt templates for the planning step. | ||||
| 
 | ||||
|     Args: | ||||
|         initial_facts (`str`): Initial facts prompt. | ||||
|         initial_plan (`str`): Initial plan prompt. | ||||
|         update_facts_pre_messages (`str`): Update facts pre-messages prompt. | ||||
|         update_facts_post_messages (`str`): Update facts post-messages prompt. | ||||
|         update_plan_pre_messages (`str`): Update plan pre-messages prompt. | ||||
|         update_plan_post_messages (`str`): Update plan post-messages prompt. | ||||
|     """ | ||||
| 
 | ||||
|     initial_facts: str | ||||
|     initial_plan: str | ||||
|     update_facts_pre_messages: str | ||||
|     update_facts_post_messages: str | ||||
|     update_plan_pre_messages: str | ||||
|     update_plan_post_messages: str | ||||
| 
 | ||||
| 
 | ||||
| class ManagedAgentPromptTemplate(TypedDict): | ||||
|     """ | ||||
|     Prompt templates for the managed agent. | ||||
| 
 | ||||
|     Args: | ||||
|         task (`str`): Task prompt. | ||||
|         report (`str`): Report prompt. | ||||
|     """ | ||||
| 
 | ||||
|     task: str | ||||
|     report: str | ||||
| 
 | ||||
| 
 | ||||
| class PromptTemplates(TypedDict): | ||||
|     """ | ||||
|     Prompt templates for the agent. | ||||
| 
 | ||||
|     Args: | ||||
|         system_prompt (`str`): System prompt. | ||||
|         planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template. | ||||
|         managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template. | ||||
|     """ | ||||
| 
 | ||||
|     system_prompt: str | ||||
|     planning: PlanningPromptTemplate | ||||
|     managed_agent: ManagedAgentPromptTemplate | ||||
| 
 | ||||
| 
 | ||||
| EMPTY_PROMPT_TEMPLATES = PromptTemplates( | ||||
|     system_prompt="", | ||||
|     planning=PlanningPromptTemplate( | ||||
|         initial_facts="", | ||||
|         initial_plan="", | ||||
|         update_facts_pre_messages="", | ||||
|         update_facts_post_messages="", | ||||
|         update_plan_pre_messages="", | ||||
|         update_plan_post_messages="", | ||||
|     ), | ||||
|     managed_agent=ManagedAgentPromptTemplate(task="", report=""), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class MultiStepAgent: | ||||
|     """ | ||||
|     Agent class that solves the given task step by step, using the ReAct framework: | ||||
|  | @ -88,7 +154,7 @@ class MultiStepAgent: | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||
|         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task. | ||||
|         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output. | ||||
|         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools. | ||||
|  | @ -107,7 +173,7 @@ class MultiStepAgent: | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         prompt_templates: Optional[PromptTemplates] = None, | ||||
|         max_steps: int = 6, | ||||
|         tool_parser: Optional[Callable] = None, | ||||
|         add_base_tools: bool = False, | ||||
|  | @ -125,7 +191,7 @@ class MultiStepAgent: | |||
|             tool_parser = parse_json_tool_call | ||||
|         self.agent_name = self.__class__.__name__ | ||||
|         self.model = model | ||||
|         self.prompt_templates = prompt_templates or {} | ||||
|         self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES | ||||
|         self.max_steps = max_steps | ||||
|         self.step_number: int = 0 | ||||
|         self.tool_parser = tool_parser | ||||
|  | @ -653,7 +719,7 @@ class ToolCallingAgent(MultiStepAgent): | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||
|         **kwargs: Additional keyword arguments. | ||||
|     """ | ||||
|  | @ -662,7 +728,7 @@ class ToolCallingAgent(MultiStepAgent): | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         prompt_templates: Optional[PromptTemplates] = None, | ||||
|         planning_interval: Optional[int] = None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|  | @ -775,7 +841,7 @@ class CodeAgent(MultiStepAgent): | |||
|     Args: | ||||
|         tools (`list[Tool]`): [`Tool`]s that the agent can use. | ||||
|         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions. | ||||
|         prompt_templates (`dict`, *optional*): Prompt templates. | ||||
|         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates. | ||||
|         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output. | ||||
|         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent. | ||||
|         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step. | ||||
|  | @ -789,7 +855,7 @@ class CodeAgent(MultiStepAgent): | |||
|         self, | ||||
|         tools: List[Tool], | ||||
|         model: Callable[[List[Dict[str, str]]], ChatMessage], | ||||
|         prompt_templates: Optional[dict] = None, | ||||
|         prompt_templates: Optional[PromptTemplates] = None, | ||||
|         grammar: Optional[Dict[str, str]] = None, | ||||
|         additional_authorized_imports: Optional[List[str]] = None, | ||||
|         planning_interval: Optional[int] = None, | ||||
|  | @ -941,6 +1007,3 @@ class CodeAgent(MultiStepAgent): | |||
|         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) | ||||
|         memory_step.action_output = output | ||||
|         return output if is_final_answer else None | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue