Add visualization method to display the agent' structure as a tree 🌳 (#470)
This commit is contained in:
		
							parent
							
								
									39908439d2
								
							
						
					
					
						commit
						42f95d8ee1
					
				|  | @ -225,6 +225,10 @@ class MultiStepAgent: | ||||||
|             messages.extend(memory_step.to_messages(summary_mode=summary_mode)) |             messages.extend(memory_step.to_messages(summary_mode=summary_mode)) | ||||||
|         return messages |         return messages | ||||||
| 
 | 
 | ||||||
|  |     def visualize(self): | ||||||
|  |         """Creates a rich tree visualization of the agent's structure.""" | ||||||
|  |         self.logger.visualize_agent_tree(self) | ||||||
|  | 
 | ||||||
|     def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]: |     def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]: | ||||||
|         """ |         """ | ||||||
|         Parse action from the LLM output |         Parse action from the LLM output | ||||||
|  |  | ||||||
|  | @ -241,8 +241,6 @@ class Model: | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
|         self.last_input_token_count = None |         self.last_input_token_count = None | ||||||
|         self.last_output_token_count = None |         self.last_output_token_count = None | ||||||
|         # Set default values for common parameters |  | ||||||
|         kwargs.setdefault("max_tokens", 4096) |  | ||||||
|         self.kwargs = kwargs |         self.kwargs = kwargs | ||||||
| 
 | 
 | ||||||
|     def _prepare_completion_kwargs( |     def _prepare_completion_kwargs( | ||||||
|  | @ -643,15 +641,19 @@ class LiteLLMModel(Model): | ||||||
|             The base URL of the OpenAI-compatible API server. |             The base URL of the OpenAI-compatible API server. | ||||||
|         api_key (`str`, *optional*): |         api_key (`str`, *optional*): | ||||||
|             The API key to use for authentication. |             The API key to use for authentication. | ||||||
|  |         custom_role_conversions (`dict[str, str]`, *optional*): | ||||||
|  |             Custom role conversion mapping to convert message roles in others. | ||||||
|  |             Useful for specific models that do not support specific message roles like "system". | ||||||
|         **kwargs: |         **kwargs: | ||||||
|             Additional keyword arguments to pass to the OpenAI API. |             Additional keyword arguments to pass to the OpenAI API. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         model_id="anthropic/claude-3-5-sonnet-20240620", |         model_id: str = "anthropic/claude-3-5-sonnet-20240620", | ||||||
|         api_base=None, |         api_base=None, | ||||||
|         api_key=None, |         api_key=None, | ||||||
|  |         custom_role_conversions: Optional[Dict[str, str]] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         try: |         try: | ||||||
|  | @ -667,6 +669,7 @@ class LiteLLMModel(Model): | ||||||
|         litellm.add_function_to_prompt = True |         litellm.add_function_to_prompt = True | ||||||
|         self.api_base = api_base |         self.api_base = api_base | ||||||
|         self.api_key = api_key |         self.api_key = api_key | ||||||
|  |         self.custom_role_conversions = custom_role_conversions | ||||||
| 
 | 
 | ||||||
|     def __call__( |     def __call__( | ||||||
|         self, |         self, | ||||||
|  | @ -687,6 +690,7 @@ class LiteLLMModel(Model): | ||||||
|             api_base=self.api_base, |             api_base=self.api_base, | ||||||
|             api_key=self.api_key, |             api_key=self.api_key, | ||||||
|             convert_images_to_image_urls=True, |             convert_images_to_image_urls=True, | ||||||
|  |             custom_role_conversions=self.custom_role_conversions, | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -23,7 +23,9 @@ from rich.console import Console, Group | ||||||
| from rich.panel import Panel | from rich.panel import Panel | ||||||
| from rich.rule import Rule | from rich.rule import Rule | ||||||
| from rich.syntax import Syntax | from rich.syntax import Syntax | ||||||
|  | from rich.table import Table | ||||||
| from rich.text import Text | from rich.text import Text | ||||||
|  | from rich.tree import Tree | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Monitor: | class Monitor: | ||||||
|  | @ -162,5 +164,42 @@ class AgentLogger: | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def visualize_agent_tree(self, agent): | ||||||
|  |         def create_tools_section(tools_dict): | ||||||
|  |             table = Table(show_header=True, header_style="bold") | ||||||
|  |             table.add_column("Name", style="blue") | ||||||
|  |             table.add_column("Description") | ||||||
|  |             table.add_column("Arguments") | ||||||
|  | 
 | ||||||
|  |             for name, tool in tools_dict.items(): | ||||||
|  |                 args = [ | ||||||
|  |                     f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}" | ||||||
|  |                     for arg_name, info in getattr(tool, "inputs", {}).items() | ||||||
|  |                 ] | ||||||
|  |                 table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args)) | ||||||
|  | 
 | ||||||
|  |             return Group(Text("🛠️ Tools", style="bold italic blue"), table) | ||||||
|  | 
 | ||||||
|  |         def build_agent_tree(parent_tree, agent_obj): | ||||||
|  |             """Recursively builds the agent tree.""" | ||||||
|  |             if agent_obj.tools: | ||||||
|  |                 parent_tree.add(create_tools_section(agent_obj.tools)) | ||||||
|  | 
 | ||||||
|  |             if agent_obj.managed_agents: | ||||||
|  |                 agents_branch = parent_tree.add("[bold italic blue]🤖 Managed agents") | ||||||
|  |                 for name, managed_agent in agent_obj.managed_agents.items(): | ||||||
|  |                     agent_node_text = f"[bold {YELLOW_HEX}]{name} - {managed_agent.agent.__class__.__name__}" | ||||||
|  |                     agent_tree = agents_branch.add(agent_node_text) | ||||||
|  |                     if hasattr(managed_agent, "description"): | ||||||
|  |                         agent_tree.add( | ||||||
|  |                             f"[bold italic blue]📝 Description:[/bold italic blue] {managed_agent.description}" | ||||||
|  |                         ) | ||||||
|  |                     if hasattr(managed_agent, "agent"): | ||||||
|  |                         build_agent_tree(agent_tree, managed_agent.agent) | ||||||
|  | 
 | ||||||
|  |         main_tree = Tree(f"[bold {YELLOW_HEX}]{agent.__class__.__name__}") | ||||||
|  |         build_agent_tree(main_tree, agent) | ||||||
|  |         self.console.print(main_tree) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = ["AgentLogger", "Monitor"] | __all__ = ["AgentLogger", "Monitor"] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue