Simplify step logs
This commit is contained in:
		
							parent
							
								
									1606b9a80c
								
							
						
					
					
						commit
						0a0402d090
					
				
							
								
								
									
										119
									
								
								agents/agents.py
								
								
								
								
							
							
						
						
									
										119
									
								
								agents/agents.py
								
								
								
								
							|  | @ -79,6 +79,11 @@ class AgentGenerationError(AgentError): | ||||||
| 
 | 
 | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | @dataclass | ||||||
|  | class ToolCall(): | ||||||
|  |     tool_name: str | ||||||
|  |     tool_arguments: Any | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class AgentStep: | class AgentStep: | ||||||
|     pass |     pass | ||||||
|  | @ -86,17 +91,16 @@ class AgentStep: | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ActionStep(AgentStep): | class ActionStep(AgentStep): | ||||||
|     tool_call: Dict[str, str] | None = None |  | ||||||
|     start_time: float | None = None |  | ||||||
|     step_end_time: float | None = None |  | ||||||
|     iteration: int | None = None |  | ||||||
|     final_answer: Any = None |  | ||||||
|     error: AgentError | None = None |  | ||||||
|     step_duration: float | None = None |  | ||||||
|     llm_output: str | None = None |  | ||||||
|     observation: str | None = None |  | ||||||
|     agent_memory: List[Dict[str, str]] | None = None |     agent_memory: List[Dict[str, str]] | None = None | ||||||
|     rationale: str | None = None |     tool_call: ToolCall | None = None | ||||||
|  |     start_time: float | None = None | ||||||
|  |     end_time: float | None = None | ||||||
|  |     iteration: int | None = None | ||||||
|  |     error: AgentError | None = None | ||||||
|  |     duration: float | None = None | ||||||
|  |     llm_output: str | None = None | ||||||
|  |     observations: str | None = None | ||||||
|  |     action_output: Any = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
|  | @ -222,7 +226,6 @@ class BaseAgent: | ||||||
|         self._toolbox.add_tool(FinalAnswerTool()) |         self._toolbox.add_tool(FinalAnswerTool()) | ||||||
| 
 | 
 | ||||||
|         self.system_prompt = self.initialize_system_prompt() |         self.system_prompt = self.initialize_system_prompt() | ||||||
|         print("SYS0:", self.system_prompt) |  | ||||||
|         self.prompt_messages = None |         self.prompt_messages = None | ||||||
|         self.logs = [] |         self.logs = [] | ||||||
|         self.task = None |         self.task = None | ||||||
|  | @ -313,15 +316,15 @@ class BaseAgent: | ||||||
|                     } |                     } | ||||||
|                     memory.append(tool_call_message) |                     memory.append(tool_call_message) | ||||||
| 
 | 
 | ||||||
|                 if step_log.error is not None or step_log.observation is not None: |                 if step_log.error is not None or step_log.observations is not None: | ||||||
|                     if step_log.error is not None: |                     if step_log.error is not None: | ||||||
|                         message_content = ( |                         message_content = ( | ||||||
|                             f"[OUTPUT OF STEP {i}] -> Error:\n" |                             f"[OUTPUT OF STEP {i}] -> Error:\n" | ||||||
|                             + str(step_log.error) |                             + str(step_log.error) | ||||||
|                             + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" |                             + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" | ||||||
|                         ) |                         ) | ||||||
|                     elif step_log.observation is not None: |                     elif step_log.observations is not None: | ||||||
|                         message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" |                         message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observations}" | ||||||
|                     tool_response_message = { |                     tool_response_message = { | ||||||
|                         "role": MessageRole.TOOL_RESPONSE, |                         "role": MessageRole.TOOL_RESPONSE, | ||||||
|                         "content": message_content, |                         "content": message_content, | ||||||
|  | @ -466,8 +469,8 @@ class ReactAgent(BaseAgent): | ||||||
|                 console.print(f"[bold red]{error_msg}") |                 console.print(f"[bold red]{error_msg}") | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg) | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: ActionStep): |     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||||
|         """To be implemented in children classes""" |         """To be implemented in children classes. Should return either None if the step is not final.""" | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     def run( |     def run( | ||||||
|  | @ -521,8 +524,8 @@ class ReactAgent(BaseAgent): | ||||||
|         if oneshot: |         if oneshot: | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(start_time=step_start_time) |             step_log = ActionStep(start_time=step_start_time) | ||||||
|             step_log.step_end_time = time.time() |             step_log.end_time = time.time() | ||||||
|             step_log.step_duration = step_log.step_end_time - step_start_time |             step_log.duration = step_log.end_time - step_start_time | ||||||
| 
 | 
 | ||||||
|             # Run the agent's step |             # Run the agent's step | ||||||
|             result = self.step(step_log) |             result = self.step(step_log) | ||||||
|  | @ -551,14 +554,14 @@ class ReactAgent(BaseAgent): | ||||||
|                         task, is_first_step=(iteration == 0), iteration=iteration |                         task, is_first_step=(iteration == 0), iteration=iteration | ||||||
|                     ) |                     ) | ||||||
|                 console.rule("[bold]New step") |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) | 
 | ||||||
|                 if step_log.final_answer is not None: |                 # Run one step! | ||||||
|                     final_answer = step_log.final_answer |                 final_answer = self.step(step_log) | ||||||
|             except AgentError as e: |             except AgentError as e: | ||||||
|                 step_log.error = e |                 step_log.error = e | ||||||
|             finally: |             finally: | ||||||
|                 step_log.step_end_time = time.time() |                 step_log.end_time = time.time() | ||||||
|                 step_log.step_duration = step_log.step_end_time - step_start_time |                 step_log.duration = step_log.end_time - step_start_time | ||||||
|                 self.logs.append(step_log) |                 self.logs.append(step_log) | ||||||
|                 for callback in self.step_callbacks: |                 for callback in self.step_callbacks: | ||||||
|                     callback(step_log) |                     callback(step_log) | ||||||
|  | @ -570,9 +573,9 @@ class ReactAgent(BaseAgent): | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             final_step_log.final_answer = final_answer |             final_step_log.action_output = final_answer | ||||||
|             final_step_log.step_end_time = time.time() |             final_step_log.end_time = time.time() | ||||||
|             final_step_log.step_duration = step_log.step_end_time - step_start_time |             final_step_log.duration = step_log.end_time - step_start_time | ||||||
|             for callback in self.step_callbacks: |             for callback in self.step_callbacks: | ||||||
|                 callback(final_step_log) |                 callback(final_step_log) | ||||||
|             yield final_step_log |             yield final_step_log | ||||||
|  | @ -597,15 +600,16 @@ class ReactAgent(BaseAgent): | ||||||
|                         task, is_first_step=(iteration == 0), iteration=iteration |                         task, is_first_step=(iteration == 0), iteration=iteration | ||||||
|                     ) |                     ) | ||||||
|                 console.rule("[bold]New step") |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) | 
 | ||||||
|                 if step_log.final_answer is not None: |                 # Run one step! | ||||||
|                     final_answer = step_log.final_answer |                 final_answer = self.step(step_log) | ||||||
|  | 
 | ||||||
|             except AgentError as e: |             except AgentError as e: | ||||||
|                 step_log.error = e |                 step_log.error = e | ||||||
|             finally: |             finally: | ||||||
|                 step_end_time = time.time() |                 step_end_time = time.time() | ||||||
|                 step_log.step_end_time = step_end_time |                 step_log.end_time = step_end_time | ||||||
|                 step_log.step_duration = step_end_time - step_start_time |                 step_log.duration = step_end_time - step_start_time | ||||||
|                 self.logs.append(step_log) |                 self.logs.append(step_log) | ||||||
|                 for callback in self.step_callbacks: |                 for callback in self.step_callbacks: | ||||||
|                     callback(step_log) |                     callback(step_log) | ||||||
|  | @ -616,8 +620,8 @@ class ReactAgent(BaseAgent): | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             final_step_log.final_answer = final_answer |             final_step_log.action_output = final_answer | ||||||
|             final_step_log.step_duration = 0 |             final_step_log.duration = 0 | ||||||
|             for callback in self.step_callbacks: |             for callback in self.step_callbacks: | ||||||
|                 callback(final_step_log) |                 callback(final_step_log) | ||||||
| 
 | 
 | ||||||
|  | @ -777,10 +781,10 @@ class JsonAgent(ReactAgent): | ||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: ActionStep): |     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||||
|         """ |         """ | ||||||
|         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. |         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. | ||||||
|         The errors are raised here, they are caught and logged in the run() method. |         Returns None if the step is not final. | ||||||
|         """ |         """ | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|  | @ -823,8 +827,7 @@ class JsonAgent(ReactAgent): | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentParsingError(f"Could not parse the given action: {e}.") |             raise AgentParsingError(f"Could not parse the given action: {e}.") | ||||||
| 
 | 
 | ||||||
|         log_entry.rationale = rationale |         log_entry.tool_call = ToolCall(tool_name=tool_name, tool_arguments=arguments) | ||||||
|         log_entry.tool_call = {"tool_name": tool_name, "tool_arguments": arguments} |  | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         console.rule("Agent thoughts:") |         console.rule("Agent thoughts:") | ||||||
|  | @ -835,15 +838,15 @@ class JsonAgent(ReactAgent): | ||||||
|             if isinstance(arguments, dict): |             if isinstance(arguments, dict): | ||||||
|                 if "answer" in arguments: |                 if "answer" in arguments: | ||||||
|                     answer = arguments["answer"] |                     answer = arguments["answer"] | ||||||
|  |                 else: | ||||||
|  |                     answer = arguments | ||||||
|  |             else: | ||||||
|  |                 answer = arguments | ||||||
|             if ( |             if ( | ||||||
|                 isinstance(answer, str) and answer in self.state.keys() |                 isinstance(answer, str) and answer in self.state.keys() | ||||||
|             ):  # if the answer is a state variable, return the value |             ):  # if the answer is a state variable, return the value | ||||||
|                 answer = self.state[answer] |                 answer = self.state[answer] | ||||||
|                 else: |             log_entry.action_output = answer | ||||||
|                     answer = arguments |  | ||||||
|             else: |  | ||||||
|                 answer = arguments |  | ||||||
|             log_entry.final_answer = answer |  | ||||||
|             return answer |             return answer | ||||||
|         else: |         else: | ||||||
|             if arguments is None: |             if arguments is None: | ||||||
|  | @ -861,8 +864,8 @@ class JsonAgent(ReactAgent): | ||||||
|                 updated_information = f"Stored '{observation_name}' in memory." |                 updated_information = f"Stored '{observation_name}' in memory." | ||||||
|             else: |             else: | ||||||
|                 updated_information = str(observation).strip() |                 updated_information = str(observation).strip() | ||||||
|             log_entry.observation = updated_information |             log_entry.observations = updated_information | ||||||
|             return log_entry |             return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CodeAgent(ReactAgent): | class CodeAgent(ReactAgent): | ||||||
|  | @ -906,16 +909,15 @@ class CodeAgent(ReactAgent): | ||||||
|         self.authorized_imports = list( |         self.authorized_imports = list( | ||||||
|             set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) |             set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) | ||||||
|         ) |         ) | ||||||
|         print("SYSS:", self.system_prompt) |  | ||||||
|         self.system_prompt = self.system_prompt.replace( |         self.system_prompt = self.system_prompt.replace( | ||||||
|             "{{authorized_imports}}", str(self.authorized_imports) |             "{{authorized_imports}}", str(self.authorized_imports) | ||||||
|         ) |         ) | ||||||
|         self.custom_tools = {} |         self.custom_tools = {} | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: ActionStep): |     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||||
|         """ |         """ | ||||||
|         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. |         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. | ||||||
|         The errors are raised here, they are caught and logged in the run() method. |         Returns None if the step is not final. | ||||||
|         """ |         """ | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|  | @ -967,11 +969,7 @@ class CodeAgent(ReactAgent): | ||||||
|             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" |             error_msg = f"Error in code parsing: {e}. Make sure to provide correct code" | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg) | ||||||
| 
 | 
 | ||||||
|         log_entry.rationale = rationale |         log_entry.tool_call = ToolCall(tool_name="python_interpreter", tool_arguments=code_action) | ||||||
|         log_entry.tool_call = { |  | ||||||
|             "tool_name": "code interpreter", |  | ||||||
|             "tool_arguments": code_action, |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|  | @ -988,7 +986,7 @@ class CodeAgent(ReactAgent): | ||||||
|             } |             } | ||||||
|             if self.managed_agents is not None: |             if self.managed_agents is not None: | ||||||
|                 static_tools = {**static_tools, **self.managed_agents} |                 static_tools = {**static_tools, **self.managed_agents} | ||||||
|             result = self.python_evaluator( |             output = self.python_evaluator( | ||||||
|                 code_action, |                 code_action, | ||||||
|                 static_tools=static_tools, |                 static_tools=static_tools, | ||||||
|                 custom_tools=self.custom_tools, |                 custom_tools=self.custom_tools, | ||||||
|  | @ -998,13 +996,13 @@ class CodeAgent(ReactAgent): | ||||||
|             console.print("Print outputs:") |             console.print("Print outputs:") | ||||||
|             console.print(self.state["print_outputs"]) |             console.print(self.state["print_outputs"]) | ||||||
|             observation = "Print outputs:\n" + self.state["print_outputs"] |             observation = "Print outputs:\n" + self.state["print_outputs"] | ||||||
|             if result is not None: |             if output is not None: | ||||||
|                 console.rule("Last output from code snippet:", align="left") |                 console.rule("Last output from code snippet:", align="left") | ||||||
|                 console.print(str(result)) |                 console.print(str(output)) | ||||||
|                 observation += "Last output from code snippet:\n" + truncate_content( |                 observation += "Last output from code snippet:\n" + truncate_content( | ||||||
|                     str(result) |                     str(output) | ||||||
|                 ) |                 ) | ||||||
|             log_entry.observation = observation |             log_entry.observations = observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Code execution failed due to the following error:\n{str(e)}" |             error_msg = f"Code execution failed due to the following error:\n{str(e)}" | ||||||
|             if "'dict' object has no attribute 'read'" in str(e): |             if "'dict' object has no attribute 'read'" in str(e): | ||||||
|  | @ -1013,9 +1011,10 @@ class CodeAgent(ReactAgent): | ||||||
|         for line in code_action.split("\n"): |         for line in code_action.split("\n"): | ||||||
|             if line[: len("final_answer")] == "final_answer": |             if line[: len("final_answer")] == "final_answer": | ||||||
|                 console.print("Final answer:") |                 console.print("Final answer:") | ||||||
|                 console.print(f"[bold]{result}") |                 console.print(f"[bold]{output}") | ||||||
|                 log_entry.final_answer = result |                 log_entry.action_output = output | ||||||
|         return result |                 return output | ||||||
|  |         return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ManagedAgent: | class ManagedAgent: | ||||||
|  |  | ||||||
|  | @ -126,7 +126,9 @@ def get_remote_tools(logger, organization="huggingface-tools"): | ||||||
| class PythonInterpreterTool(Tool): | class PythonInterpreterTool(Tool): | ||||||
|     name = "python_interpreter" |     name = "python_interpreter" | ||||||
|     description = "This is a tool that evaluates python code. It can be used to perform calculations." |     description = "This is a tool that evaluates python code. It can be used to perform calculations." | ||||||
| 
 |     inputs = { | ||||||
|  |         "code": {"type": "string", "description": "The python code to run in interpreter"} | ||||||
|  |     } | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, authorized_imports=None, **kwargs): |     def __init__(self, *args, authorized_imports=None, **kwargs): | ||||||
|  | @ -147,7 +149,7 @@ class PythonInterpreterTool(Tool): | ||||||
|         } |         } | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| 
 | 
 | ||||||
|     def forward(self, code): |     def forward(self, code: str) -> str: | ||||||
|         output = str( |         output = str( | ||||||
|             evaluate_python_code( |             evaluate_python_code( | ||||||
|                 code, |                 code, | ||||||
|  |  | ||||||
|  | @ -22,20 +22,20 @@ import gradio as gr | ||||||
| def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||||
|     """Extract ChatMessage objects from agent steps""" |     """Extract ChatMessage objects from agent steps""" | ||||||
|     if isinstance(step_log, ActionStep): |     if isinstance(step_log, ActionStep): | ||||||
|         yield gr.ChatMessage(role="assistant", content=step_log.rationale) |         yield gr.ChatMessage(role="assistant", content=step_log.llm_output) | ||||||
|         if step_log.tool_call is not None: |         if step_log.tool_call is not None: | ||||||
|             used_code = step_log.tool_call["tool_name"] == "code interpreter" |             used_code = step_log.tool_call.tool_name == "code interpreter" | ||||||
|             content = step_log.tool_call["tool_arguments"] |             content = step_log.tool_call.tool_arguments | ||||||
|             if used_code: |             if used_code: | ||||||
|                 content = f"```py\n{content}\n```" |                 content = f"```py\n{content}\n```" | ||||||
|             yield gr.ChatMessage( |             yield gr.ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|                 metadata={"title": f"🛠️ Used tool {step_log.tool_call['tool_name']}"}, |                 metadata={"title": f"🛠️ Used tool {step_log.tool_call.tool_name}"}, | ||||||
|                 content=str(content), |                 content=str(content), | ||||||
|             ) |             ) | ||||||
|         if step_log.observation is not None: |         if step_log.observations is not None: | ||||||
|             yield gr.ChatMessage( |             yield gr.ChatMessage( | ||||||
|                 role="assistant", content=f"```\n{step_log.observation}\n```" |                 role="assistant", content=f"```\n{step_log.observations}\n```" | ||||||
|             ) |             ) | ||||||
|         if step_log.error is not None: |         if step_log.error is not None: | ||||||
|             yield gr.ChatMessage( |             yield gr.ChatMessage( | ||||||
|  |  | ||||||
|  | @ -29,7 +29,7 @@ class Monitor: | ||||||
|             self.total_output_token_count = 0 |             self.total_output_token_count = 0 | ||||||
| 
 | 
 | ||||||
|     def update_metrics(self, step_log): |     def update_metrics(self, step_log): | ||||||
|         step_duration = step_log.step_duration |         step_duration = step_log.duration | ||||||
|         self.step_durations.append(step_duration) |         self.step_durations.append(step_duration) | ||||||
|         console.print(f"Step {len(self.step_durations)}:") |         console.print(f"Step {len(self.step_durations)}:") | ||||||
|         console.print(f"- Time taken: {step_duration:.2f} seconds") |         console.print(f"- Time taken: {step_duration:.2f} seconds") | ||||||
|  |  | ||||||
|  | @ -150,7 +150,7 @@ class Tool: | ||||||
|     name: str |     name: str | ||||||
|     description: str |     description: str | ||||||
|     inputs: Dict[str, Dict[str, Union[str, type]]] |     inputs: Dict[str, Dict[str, Union[str, type]]] | ||||||
|     output_type: type |     output_type: str | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         self.is_initialized = False |         self.is_initialized = False | ||||||
|  |  | ||||||
|  | @ -16,9 +16,10 @@ import os | ||||||
| import tempfile | import tempfile | ||||||
| import unittest | import unittest | ||||||
| import uuid | import uuid | ||||||
| 
 |  | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
|  | from pathlib import Path | ||||||
|  | 
 | ||||||
| from agents.agent_types import AgentText | from agents.agent_types import AgentText | ||||||
| from agents.agents import ( | from agents.agents import ( | ||||||
|     AgentMaxIterationsError, |     AgentMaxIterationsError, | ||||||
|  | @ -26,16 +27,18 @@ from agents.agents import ( | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     JsonAgent, |     JsonAgent, | ||||||
|     Toolbox, |     Toolbox, | ||||||
|  |     ToolCall | ||||||
| ) | ) | ||||||
|  | from agents.tools import tool | ||||||
| from agents.default_tools import PythonInterpreterTool | from agents.default_tools import PythonInterpreterTool | ||||||
| 
 | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|     directory = tempfile.mkdtemp() |     directory = tempfile.mkdtemp() | ||||||
|     return os.path.join(directory, str(uuid.uuid4()) + suffix) |     return os.path.join(directory, str(uuid.uuid4()) + suffix) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str: | def fake_json_llm(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
| 
 | 
 | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|  | @ -57,8 +60,29 @@ Action: | ||||||
| } | } | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
|  | def fake_json_llm_image(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|  |     prompt = str(messages) | ||||||
| 
 | 
 | ||||||
| def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str: |     if "special_marker" not in prompt: | ||||||
|  |         return """ | ||||||
|  | Thought: I should generate an image. special_marker | ||||||
|  | Action: | ||||||
|  | { | ||||||
|  |     "action": "fake_image_generation_tool", | ||||||
|  |     "action_input": {"prompt": "An image of a cat"} | ||||||
|  | } | ||||||
|  | """ | ||||||
|  |     else:  # We're at step 2 | ||||||
|  |         return """ | ||||||
|  | Thought: I can now answer the initial question | ||||||
|  | Action: | ||||||
|  | { | ||||||
|  |     "action": "final_answer", | ||||||
|  |     "action_input": "image.png" | ||||||
|  | } | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | def fake_code_llm(messages, stop_sequences=None, grammar=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return """ |         return """ | ||||||
|  | @ -78,7 +102,7 @@ final_answer(7.2904) | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_react_code_llm_error(messages, stop_sequences=None) -> str: | def fake_code_llm_error(messages, stop_sequences=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return """ |         return """ | ||||||
|  | @ -98,7 +122,7 @@ final_answer("got an error") | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def fake_react_code_functiondef(messages, stop_sequences=None) -> str: | def fake_code_functiondef(messages, stop_sequences=None) -> str: | ||||||
|     prompt = str(messages) |     prompt = str(messages) | ||||||
|     if "special_marker" not in prompt: |     if "special_marker" not in prompt: | ||||||
|         return """ |         return """ | ||||||
|  | @ -146,27 +170,23 @@ print(result) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentTests(unittest.TestCase): | class AgentTests(unittest.TestCase): | ||||||
|     def test_fake_code_agent(self): |     def test_fake_oneshot_code_agent(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot |             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot | ||||||
|         ) |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?", oneshot=True) | ||||||
|         assert isinstance(output, str) |         assert isinstance(output, str) | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
| 
 | 
 | ||||||
|     def test_fake_react_json_agent(self): |     def test_fake_json_agent(self): | ||||||
|         agent = JsonAgent( |         agent = JsonAgent( | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm |             tools=[PythonInterpreterTool()], llm_engine=fake_json_llm | ||||||
|         ) |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, str) |         assert isinstance(output, str) | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
|         assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" |         assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" | ||||||
|         assert agent.logs[2].observation == "7.2904" |         assert agent.logs[2].observations == "7.2904" | ||||||
|         assert ( |  | ||||||
|             agent.logs[2].rationale.strip() |  | ||||||
|             == "Thought: I should multiply 2 by 3.6452. special_marker" |  | ||||||
|         ) |  | ||||||
|         assert ( |         assert ( | ||||||
|             agent.logs[3].llm_output |             agent.logs[3].llm_output | ||||||
|             == """ |             == """ | ||||||
|  | @ -179,22 +199,43 @@ Action: | ||||||
| """ | """ | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def test_fake_react_code_agent(self): |     def test_json_agent_handles_image_tool_outputs(self): | ||||||
|  |         from PIL import Image | ||||||
|  | 
 | ||||||
|  |         @tool | ||||||
|  |         def fake_image_generation_tool(prompt: str) -> Image.Image: | ||||||
|  |             """Tool that generates an image. | ||||||
|  | 
 | ||||||
|  |             Args: | ||||||
|  |                 prompt: The prompt | ||||||
|  |             """ | ||||||
|  |             return Image.open( | ||||||
|  |                 Path(get_tests_dir("fixtures")) / "000000039769.png" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         agent = JsonAgent( | ||||||
|  |             tools=[fake_image_generation_tool], llm_engine=fake_json_llm_image | ||||||
|  |         ) | ||||||
|  |         output = agent.run("Make me an image.") | ||||||
|  |         assert isinstance(output, Image.Image) | ||||||
|  |         assert isinstance(agent.state["image.png"], Image.Image) | ||||||
|  | 
 | ||||||
|  |     def test_fake_code_agent(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm |             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm | ||||||
|         ) |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, float) |         assert isinstance(output, float) | ||||||
|         assert output == 7.2904 |         assert output == 7.2904 | ||||||
|         assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" |         assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" | ||||||
|         assert agent.logs[3].tool_call == { |         assert agent.logs[3].tool_call == ToolCall( | ||||||
|             "tool_arguments": "final_answer(7.2904)", |             tool_name="python_interpreter", | ||||||
|             "tool_name": "code interpreter", |             tool_arguments="final_answer(7.2904)", | ||||||
|         } |         ) | ||||||
| 
 | 
 | ||||||
|     def test_react_code_agent_code_errors_show_offending_lines(self): |     def test_code_agent_code_errors_show_offending_lines(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error |             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_error | ||||||
|         ) |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, AgentText) |         assert isinstance(output, AgentText) | ||||||
|  | @ -202,9 +243,9 @@ Action: | ||||||
|         assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) |         assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) | ||||||
| 
 | 
 | ||||||
|     def test_setup_agent_with_empty_toolbox(self): |     def test_setup_agent_with_empty_toolbox(self): | ||||||
|         JsonAgent(llm_engine=fake_react_json_llm, tools=[]) |         JsonAgent(llm_engine=fake_json_llm, tools=[]) | ||||||
| 
 | 
 | ||||||
|     def test_react_fails_max_iterations(self): |     def test_fails_max_iterations(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], |             tools=[PythonInterpreterTool()], | ||||||
|             llm_engine=fake_code_llm_no_return,  # use this callable because it never ends |             llm_engine=fake_code_llm_no_return,  # use this callable because it never ends | ||||||
|  | @ -216,19 +257,19 @@ Action: | ||||||
| 
 | 
 | ||||||
|     def test_init_agent_with_different_toolsets(self): |     def test_init_agent_with_different_toolsets(self): | ||||||
|         toolset_1 = [] |         toolset_1 = [] | ||||||
|         agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_1, llm_engine=fake_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 1 |             len(agent.toolbox.tools) == 1 | ||||||
|         )  # when no tools are provided, only the final_answer tool is added by default |         )  # when no tools are provided, only the final_answer tool is added by default | ||||||
| 
 | 
 | ||||||
|         toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] |         toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] | ||||||
|         agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_2, llm_engine=fake_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 2 |             len(agent.toolbox.tools) == 2 | ||||||
|         )  # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer |         )  # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer | ||||||
| 
 | 
 | ||||||
|         toolset_3 = Toolbox(toolset_2) |         toolset_3 = Toolbox(toolset_2) | ||||||
|         agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_3, llm_engine=fake_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 2 |             len(agent.toolbox.tools) == 2 | ||||||
|         )  # same as previous one, where toolset_3 is an instantiation of previous one |         )  # same as previous one, where toolset_3 is an instantiation of previous one | ||||||
|  | @ -236,12 +277,12 @@ Action: | ||||||
|         # check that add_base_tools will not interfere with existing tools |         # check that add_base_tools will not interfere with existing tools | ||||||
|         with pytest.raises(KeyError) as e: |         with pytest.raises(KeyError) as e: | ||||||
|             agent = JsonAgent( |             agent = JsonAgent( | ||||||
|                 tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True |                 tools=toolset_3, llm_engine=fake_json_llm, add_base_tools=True | ||||||
|             ) |             ) | ||||||
|         assert "already exists in the toolbox" in str(e) |         assert "already exists in the toolbox" in str(e) | ||||||
| 
 | 
 | ||||||
|         # check that python_interpreter base tool does not get added to code agents |         # check that python_interpreter base tool does not get added to code agents | ||||||
|         agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) |         agent = CodeAgent(tools=[], llm_engine=fake_code_llm, add_base_tools=True) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 2 |             len(agent.toolbox.tools) == 2 | ||||||
|         )  # added final_answer tool + search |         )  # added final_answer tool + search | ||||||
|  | @ -249,7 +290,7 @@ Action: | ||||||
|     def test_function_persistence_across_steps(self): |     def test_function_persistence_across_steps(self): | ||||||
|         agent = CodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             llm_engine=fake_react_code_functiondef, |             llm_engine=fake_code_functiondef, | ||||||
|             max_iterations=2, |             max_iterations=2, | ||||||
|             additional_authorized_imports=["numpy"], |             additional_authorized_imports=["numpy"], | ||||||
|         ) |         ) | ||||||
|  | @ -257,17 +298,17 @@ Action: | ||||||
|         assert res[0] == 0.5 |         assert res[0] == 0.5 | ||||||
| 
 | 
 | ||||||
|     def test_init_managed_agent(self): |     def test_init_managed_agent(self): | ||||||
|         agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) |         agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef) | ||||||
|         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") |         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") | ||||||
|         assert managed_agent.name == "managed_agent" |         assert managed_agent.name == "managed_agent" | ||||||
|         assert managed_agent.description == "Empty" |         assert managed_agent.description == "Empty" | ||||||
| 
 | 
 | ||||||
|     def test_agent_description_gets_correctly_inserted_in_system_prompt(self): |     def test_agent_description_gets_correctly_inserted_in_system_prompt(self): | ||||||
|         agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) |         agent = CodeAgent(tools=[], llm_engine=fake_code_functiondef) | ||||||
|         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") |         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") | ||||||
|         manager_agent = CodeAgent( |         manager_agent = CodeAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|             llm_engine=fake_react_code_functiondef, |             llm_engine=fake_code_functiondef, | ||||||
|             managed_agents=[managed_agent], |             managed_agents=[managed_agent], | ||||||
|         ) |         ) | ||||||
|         assert "You can also give requests to team members." not in agent.system_prompt |         assert "You can also give requests to team members." not in agent.system_prompt | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue