Improve code logging
This commit is contained in:
		
							parent
							
								
									154d1e938e
								
							
						
					
					
						commit
						43a3f46835
					
				|  | @ -17,6 +17,7 @@ | ||||||
| import time | import time | ||||||
| from typing import Any, Callable, Dict, List, Optional, Union | from typing import Any, Callable, Dict, List, Optional, Union | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  | from rich.syntax import Syntax | ||||||
| 
 | 
 | ||||||
| from transformers.utils import is_torch_available | from transformers.utils import is_torch_available | ||||||
| from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content | from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content | ||||||
|  | @ -353,14 +354,6 @@ class BaseAgent: | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     def log_rationale_code_action(self, rationale: str, code_action: str) -> None: |  | ||||||
|         if self.verbose: |  | ||||||
|             console.rule("Agent thoughts") |  | ||||||
|             console.print(rationale) |  | ||||||
|         console.rule("Agent is executing the code below:", align="left") |  | ||||||
|         console.print(code_action) |  | ||||||
|         console.rule("", align="left") |  | ||||||
| 
 |  | ||||||
|     def run(self, **kwargs): |     def run(self, **kwargs): | ||||||
|         """To be implemented in the child class""" |         """To be implemented in the child class""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  | @ -447,20 +440,23 @@ class ReactAgent(BaseAgent): | ||||||
|         else: |         else: | ||||||
|             self.logs.append(TaskStep(task=task)) |             self.logs.append(TaskStep(task=task)) | ||||||
| 
 | 
 | ||||||
|         if oneshot: |         with console.status( | ||||||
|             step_start_time = time.time() |             "Agent is running...", spinner="aesthetic" | ||||||
|             step_log = ActionStep(start_time=step_start_time) |         ): | ||||||
|             step_log.step_end_time = time.time() |             if oneshot: | ||||||
|             step_log.step_duration = step_log.step_end_time - step_start_time |                 step_start_time = time.time() | ||||||
|  |                 step_log = ActionStep(start_time=step_start_time) | ||||||
|  |                 step_log.step_end_time = time.time() | ||||||
|  |                 step_log.step_duration = step_log.step_end_time - step_start_time | ||||||
| 
 | 
 | ||||||
|             # Run the agent's step |                 # Run the agent's step | ||||||
|             result = self.step(step_log) |                 result = self.step(step_log) | ||||||
|             return result |                 return result | ||||||
| 
 | 
 | ||||||
|         if stream: |             if stream: | ||||||
|             return self.stream_run(task) |                 return self.stream_run(task) | ||||||
|         else: |             else: | ||||||
|             return self.direct_run(task) |                 return self.direct_run(task) | ||||||
| 
 | 
 | ||||||
|     def stream_run(self, task: str): |     def stream_run(self, task: str): | ||||||
|         """ |         """ | ||||||
|  | @ -687,7 +683,7 @@ class JsonAgent(ReactAgent): | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("Calling LLM engine with this last message:", align="left") |             console.rule("[italic]Calling LLM engine with this last message:", align="left") | ||||||
|             console.print(self.prompt[-1]) |             console.print(self.prompt[-1]) | ||||||
|             console.rule() |             console.rule() | ||||||
| 
 | 
 | ||||||
|  | @ -806,7 +802,7 @@ class CodeAgent(ReactAgent): | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("Calling LLM engine with these last messages:", align="left") |             console.rule("[italic]Calling LLM engine with these last messages:", align="left") | ||||||
|             console.print(self.prompt[-2:]) |             console.print(self.prompt[-2:]) | ||||||
|             console.rule() |             console.rule() | ||||||
| 
 | 
 | ||||||
|  | @ -819,8 +815,8 @@ class CodeAgent(ReactAgent): | ||||||
|             raise AgentGenerationError(f"Error in generating llm output: {e}.") |             raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("Output message of the LLM:") |             console.rule("[italic]Output message of the LLM:") | ||||||
|             console.print(llm_output) |             console.print(Syntax(llm_output, lexer='markdown', background_color='default')) | ||||||
|         log_entry.llm_output = llm_output |         log_entry.llm_output = llm_output | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|  | @ -840,7 +836,13 @@ class CodeAgent(ReactAgent): | ||||||
|         log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action} |         log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action} | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         self.log_rationale_code_action(rationale, code_action) |         if self.verbose: | ||||||
|  |             console.rule("[italic]Agent thoughts") | ||||||
|  |             console.print(rationale) | ||||||
|  |         console.rule("[bold]Agent is executing the code below:", align="left") | ||||||
|  |         console.print(Syntax(code_action, lexer='python', background_color='default')) | ||||||
|  |         console.rule("", align="left") | ||||||
|  | 
 | ||||||
|         try: |         try: | ||||||
|             static_tools = { |             static_tools = { | ||||||
|                 **BASE_PYTHON_TOOLS.copy(), |                 **BASE_PYTHON_TOOLS.copy(), | ||||||
|  | @ -859,7 +861,7 @@ class CodeAgent(ReactAgent): | ||||||
|             console.print(self.state["print_outputs"]) |             console.print(self.state["print_outputs"]) | ||||||
|             observation = "Print outputs:\n" + self.state["print_outputs"] |             observation = "Print outputs:\n" + self.state["print_outputs"] | ||||||
|             if result is not None: |             if result is not None: | ||||||
|                 console.print("Last output from code snippet:") |                 console.rule("Last output from code snippet:", align="left") | ||||||
|                 console.print(str(result)) |                 console.print(str(result)) | ||||||
|                 observation += "Last output from code snippet:\n" + truncate_content(str(result)) |                 observation += "Last output from code snippet:\n" + truncate_content(str(result)) | ||||||
|             log_entry.observation = observation |             log_entry.observation = observation | ||||||
|  |  | ||||||
|  | @ -20,7 +20,7 @@ agent = CodeAgent( | ||||||
| 
 | 
 | ||||||
| # Run it! | # Run it! | ||||||
| result = agent.run( | result = agent.run( | ||||||
|     "When was Llama 3 first released?", oneshot=True |     "When was Llama 3 first released?" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| print(result) | print(result) | ||||||
		Loading…
	
		Reference in New Issue