From dccef6248b0a50c53ad2b1f7e3879c9b2b16438d Mon Sep 17 00:00:00 2001 From: Aymeric Date: Sat, 21 Dec 2024 23:11:15 +0100 Subject: [PATCH] Multiple documentation improvements --- docs/source/_toctree.yml | 2 + docs/source/conceptual_guides/intro_agents.md | 22 +-- docs/source/conceptual_guides/react.md | 23 ++- docs/source/examples/text_to_sql.md | 28 +--- docs/source/index.md | 19 ++- docs/source/tutorials/building_good_agents.md | 1 - examples/dummytool.py | 14 -- src/agents/__init__.py | 5 +- src/agents/agents.py | 153 +++++++++++------- src/agents/default_tools/base.py | 6 +- src/agents/default_tools/search.py | 1 + src/agents/docker_alternative.py | 25 +-- src/agents/e2b_executor.py | 39 +++-- src/agents/gradio_ui.py | 3 +- src/agents/llm_engines.py | 8 +- src/agents/local_python_executor.py | 6 +- src/agents/monitoring.py | 5 +- src/agents/prompts.py | 18 +++ src/agents/tool_validation.py | 55 ++++--- src/agents/tools.py | 47 +++--- src/agents/utils.py | 83 ++++++---- tests/test_agents.py | 4 +- tests/test_all_docs.py | 4 +- tests/test_tools_common.py | 31 ++-- 24 files changed, 366 insertions(+), 236 deletions(-) delete mode 100644 examples/dummytool.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index b392929..1027139 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -10,6 +10,8 @@ title: ✨ Building good agents - local: tutorials/tools title: 🛠️ Tools - in-depth guide + - local: tutorials/secure_code_execution + title: 🛡️ Secure your code execution with E2B - title: Conceptual guides sections: - local: conceptual_guides/intro_agents diff --git a/docs/source/conceptual_guides/intro_agents.md b/docs/source/conceptual_guides/intro_agents.md index 6d5c38e..648fc03 100644 --- a/docs/source/conceptual_guides/intro_agents.md +++ b/docs/source/conceptual_guides/intro_agents.md @@ -37,6 +37,14 @@ Then it can get more agentic. - If you use an LLM output to determine which function is run and with which arguments, that's tool calling. - If you use an LLM output to determine if you should keep iterating in a while loop, you get a multi-step agent. +| Agency Level | Description | How that's called | Example Pattern | +|-------------|-------------|-------------|-----------------| +| No Agency | LLM output has no impact on program flow | Simple Processor | `process_llm_output(llm_response)` | +| Basic Agency | LLM output determines basic control flow | Router | `if llm_decision(): path_a() else: path_b()` | +| Higher Agency | LLM output determines function execution | Tool Caller | `run_function(llm_chosen_tool, llm_chosen_args)` | +| High Agency | LLM output controls iteration and program continuation | Multi-step Agent | `while llm_should_continue(): execute_next_step()` | +| High Agency | One agentic workflow can start another agentic workflow | Multi-Agent | `if llm_trigger(): execute_agent()` | + Since the system’s versatility goes in lockstep with the level of agency that you give to the LLM, agentic systems can perform much broader tasks than any classic program. Programs are not just tools anymore, confined to an ultra-specialized task : they are agents. @@ -83,11 +91,11 @@ But wait, since we give room to LLMs in decisions, surely they will make mistake These will not be that straightforward to implement correctly, especially not together. That's why we decided that we needed to build a few abstractions to help people use these. -### Most important feature: Code agent +### Code agents -[Multiple](https://huggingface.co/papers/2402.01030) [research](https://huggingface.co/papers/2411.01747) [papers](https://huggingface.co/papers/2401.00812) have shown that having the LLM write its actions (the tool calls) in code is much better than the current standard format JSON. +[Multiple](https://huggingface.co/papers/2402.01030) [research](https://huggingface.co/papers/2411.01747) [papers](https://huggingface.co/papers/2401.00812) have shown that having the LLM write its actions (the tool calls) in code is much better than the current standard format for tool calling, which is across the industry different shades of "writing actions as a JSON of tools names and arguments to use". -Why is that? Well, because we crafted our code languages specifically to be great at expressing actions performed by a computer. If JSON snippets was a better way, this package would have been written in JSON snippets and the devil would be having a great time laughing at us. +Why is code better? Well, because we crafted our code languages specifically to be great at expressing actions performed by a computer. If JSON snippets was a better way, this package would have been written in JSON snippets and the devil would be laughing at us. Code is just a better way to express actions on a computer. It has better: - **Composability:** could you nest JSON actions within each other, or define a set of JSON actions to re-use later, the same way you could just define a python function? @@ -95,10 +103,4 @@ Code is just a better way to express actions on a computer. It has better: - **Generality:** code is built to express simply anything you can do have a computer do. - **Representation in LLM training corpuses:** why not leverage this benediction of the sky that plenty of quality actions have already been included in LLM training corpuses? -So we shoul use code as the main expression type for agent actions. - -Few existing framework build on this idea to make code agents first-class citizens. We focused on it! - -Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: -- a secure python interpreter to run code more safely in your environment -- a sandboxed environment. \ No newline at end of file +This is why we put emphasis on proposing code agents, in this case python agents, which meant putting higher effort on building python interpreters. \ No newline at end of file diff --git a/docs/source/conceptual_guides/react.md b/docs/source/conceptual_guides/react.md index a4b6277..d2995ea 100644 --- a/docs/source/conceptual_guides/react.md +++ b/docs/source/conceptual_guides/react.md @@ -21,14 +21,15 @@ This agent has a planning step, then generates python code to execute all its ac ## React agents -This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations. +This is the go-to agent to solve reasoning tasks. -We implement two versions of JsonAgent: -- [`JsonAgent`] generates tool calls as a JSON in its output. -- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance. +The ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) is currently the main approach to building agents. -> [!TIP] -> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents. +The name is based on the concatenation of two words, "Reason" and "Act." Indeed, agents following this architecture will solve their task in as many steps as needed, each step consisting of a Reasoning step, then an Action step where it formulates tool calls that will bring it closer to solving the task at hand. + +React process involves keeping a memory of past steps. + +Here is a video overview of how that works:
-![Framework of a React Agent](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/open-source-llms-as-agents/ReAct.png) \ No newline at end of file +![Framework of a React Agent](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/open-source-llms-as-agents/ReAct.png) + +We implement two versions of JsonAgent: +- [`JsonAgent`] generates tool calls as a JSON in its output. +- [`CodeAgent`] is a new type of JsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance. + +> [!TIP] +> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about ReAct agents. + diff --git a/docs/source/examples/text_to_sql.md b/docs/source/examples/text_to_sql.md index 40f1a88..4d31475 100644 --- a/docs/source/examples/text_to_sql.md +++ b/docs/source/examples/text_to_sql.md @@ -17,7 +17,7 @@ rendered properly in your Markdown viewer. In this tutorial, we’ll see how to implement an agent that leverages SQL using `agents`. -What’s the advantage over a standard text-to-SQL pipeline? +> Let's start with the goldnen question: why not keep it simple and use a standard text-to-SQL pipeline? A standard text-to-sql pipeline is brittle, since the generated SQL query can be incorrect. Even worse, the query could be incorrect, but not raise an error, instead giving some incorrect/useless outputs without raising an alarm. @@ -69,22 +69,6 @@ for row in rows: cursor = connection.execute(stmt) ``` -Let’s check that our system works with a basic query: - -```py -with engine.connect() as con: - rows = con.execute(text("""SELECT * from receipts""")) - for row in rows: - print(row) -``` -Output: -```text -(1, 'Alan Payne', 12.06, 1.2) -(2, 'Alex Mason', 23.86, 0.24) -(3, 'Woodrow Wilson', 53.43, 5.43) -(4, 'Margaret James', 21.11, 1.0) -``` - ### Build our agent Now let’s make our SQL table retrievable by a tool. @@ -107,9 +91,9 @@ Columns: - tip: FLOAT ``` -Now let’s build our tool. It needs the following: (read the documentation for more detail) -- A docstring with an `Args:` part -- Type hints +Now let’s build our tool. It needs the following: (read [the tool doc](../tutorials/tools) for more detail) +- A docstring with an `Args:` part listing arguments. +- Type hints on both inputs and output. ```py from transformers.agents import tool @@ -179,7 +163,7 @@ for row in rows: with engine.begin() as connection: cursor = connection.execute(stmt) ``` -We need to update the `SQLExecutorTool` with this table’s description to let the LLM properly leverage information from this table. +Since we changed the table, we update the `SQLExecutorTool` with this table’s description to let the LLM properly leverage information from this table. ```py updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output. @@ -196,7 +180,7 @@ for table in ["receipts", "waiters"]: print(updated_description) ``` -Since this request is a bit harder than the previous one, we’ll switch the llm engine to use the more powerful [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)! +Since this request is a bit harder than the previous one, we’ll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct)! ```py sql_engine.description = updated_description diff --git a/docs/source/index.md b/docs/source/index.md index 116602e..04b3436 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -19,7 +19,24 @@ Agents is a library that enables you to run powerful agents in a few lines of co It is: - lightweight - understandable (we kept abstractions to the minimum) -- the only library with first-class support for Code Agents, i.e. agents that write their actions in code! Head to [./conceptual_guides/intro_agents.md] to learn more. +- the only library with first-class support for Code Agents, i.e. agents that write their actions in code! + +Here is a demo: + +## How lightweight is it? + +We strived to keep abstractions to a strict minimum. +You could go lower and code it all yourself, but some of this stuff is non-trivial. For instance, if you define a format for tool expression, you have to specify the same format in your system prompt, your parser, and your possibke error logging to let the LLM correct itself. + + +## Code agents? + +We can let LLMs powering agentic systems write their actions in code. This approach is demonstrated to work better than the current industry practice of letting the LLM output a dictionary of the tools it wants to calls: [uses 30% fewer steps](https://huggingface.co/papers/2402.01030) (thus 30% fewer LLM calls) +and [reaches higher performance on difficult benchmarks](https://huggingface.co/papers/2411.01747). Head to [./conceptual_guides/intro_agents.md] to learn more on that. + +Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: + - a secure python interpreter to run code more safely in your environment + - a sandboxed environment.
diff --git a/docs/source/tutorials/building_good_agents.md b/docs/source/tutorials/building_good_agents.md index 00996e0..ed01a71 100644 --- a/docs/source/tutorials/building_good_agents.md +++ b/docs/source/tutorials/building_good_agents.md @@ -182,5 +182,4 @@ agent = CodeAgent( result = agent.run( "How long would a cheetah at full speed take to run the length of Pont Alexandre III?", ) -print("RESULT:", result) ``` \ No newline at end of file diff --git a/examples/dummytool.py b/examples/dummytool.py deleted file mode 100644 index 75edfd9..0000000 --- a/examples/dummytool.py +++ /dev/null @@ -1,14 +0,0 @@ -from agents.tools import Tool - - -class DummyTool(Tool): - name = "echo" - description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. - Each result has keys 'title', 'href' and 'body'.""" - inputs = { - "cmd": {"type": "string", "description": "The search query to perform."} - } - output_type = "any" - - def forward(self, cmd: str) -> str: - return cmd \ No newline at end of file diff --git a/src/agents/__init__.py b/src/agents/__init__.py index ba8f8f1..cbdeab4 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -39,13 +39,14 @@ if TYPE_CHECKING: else: import sys + _file = globals()["__file__"] import_structure = define_import_structure(_file) - import_structure[""]= {"__version__": __version__} + import_structure[""] = {"__version__": __version__} sys.modules[__name__] = _LazyModule( __name__, _file, import_structure, module_spec=__spec__, - extra_objects={"__version__": __version__} + extra_objects={"__version__": __version__}, ) diff --git a/src/agents/agents.py b/src/agents/agents.py index 079444c..d76f70e 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -41,6 +41,7 @@ from .prompts import ( USER_PROMPT_PLAN, SYSTEM_PROMPT_PLAN_UPDATE, SYSTEM_PROMPT_PLAN, + MANAGED_AGENT_PROMPT, ) from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonInterpreter from .e2b_executor import E2BExecutor @@ -170,6 +171,7 @@ def format_prompt_with_managed_agents_descriptions( else: return prompt_template.replace(agent_descriptions_placeholder, "") + class BaseAgent: def __init__( self, @@ -677,7 +679,9 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) ) - console.print(Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction)) + console.print( + Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction) + ) else: # update plan agent_memory = self.write_inner_memory_from_logs( summary_mode=False @@ -731,15 +735,14 @@ Now begin!""", self.logs.append( PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) ) - console.print(Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction)) - + console.print( + Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction) + ) class JsonAgent(ReactAgent): """ - This agent that solves the given task step by step, using the ReAct framework: - While the objective is not reached, the agent will perform a cycle of thinking and acting. - The tool calls will be formulated by the LLM in JSON format, then parsed and executed. + In this agent, the tool calls will be formulated by the LLM in JSON format, then parsed and executed. """ def __init__( @@ -781,10 +784,16 @@ class JsonAgent(ReactAgent): log_entry.agent_memory = agent_memory.copy() if self.verbose: - console.print(Group( - Rule("[italic]Calling LLM engine with this last message:", align="left", style="orange"), - Text(str(self.prompt_messages[-1])) - )) + console.print( + Group( + Rule( + "[italic]Calling LLM engine with this last message:", + align="left", + style="orange", + ), + Text(str(self.prompt_messages[-1])), + ) + ) try: additional_args = ( @@ -800,10 +809,16 @@ class JsonAgent(ReactAgent): raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") if self.verbose: - console.print(Group( - Rule("[italic]Output message of the LLM:", align="left", style="orange"), - Text(llm_output) - )) + console.print( + Group( + Rule( + "[italic]Output message of the LLM:", + align="left", + style="orange", + ), + Text(llm_output), + ) + ) # Parse rationale, action = self.extract_action( @@ -819,7 +834,9 @@ class JsonAgent(ReactAgent): # Execute console.print(Rule("Agent thoughts:", align="left"), Text(rationale)) - console.print(Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}"))) + console.print( + Panel(Text(f"Calling tool: '{tool_name}' with arguments: {arguments}")) + ) if tool_name == "final_answer": if isinstance(arguments, dict): if "answer" in arguments: @@ -856,9 +873,7 @@ class JsonAgent(ReactAgent): class CodeAgent(ReactAgent): """ - This agent that solves the given task step by step, using the ReAct framework: - While the objective is not reached, the agent will perform a cycle of thinking and acting. - The tool calls will be formulated by the LLM in code format, then parsed and executed. + In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed. """ def __init__( @@ -893,13 +908,19 @@ class CodeAgent(ReactAgent): additional_authorized_imports if additional_authorized_imports else [] ) if use_e2b_executor and len(self.managed_agents) > 0: - raise Exception(f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution.") + raise Exception( + f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution." + ) all_tools = {**self.toolbox.tools, **self.managed_agents} if use_e2b_executor: - self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values())) + self.python_executor = E2BExecutor( + self.additional_authorized_imports, list(all_tools.values()) + ) else: - self.python_executor = LocalPythonInterpreter(self.additional_authorized_imports, all_tools) + self.python_executor = LocalPythonInterpreter( + self.additional_authorized_imports, all_tools + ) self.authorized_imports = list( set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) ) @@ -924,10 +945,16 @@ class CodeAgent(ReactAgent): log_entry.agent_memory = agent_memory.copy() if self.verbose: - console.print(Group( - Rule("[italic]Calling LLM engine with these last messages:", align="left", style="orange"), - Text(str(self.prompt_messages[-2:])) - )) + console.print( + Group( + Rule( + "[italic]Calling LLM engine with these last messages:", + align="left", + style="orange", + ), + Text(str(self.prompt_messages[-2:])), + ) + ) try: additional_args = ( @@ -943,10 +970,16 @@ class CodeAgent(ReactAgent): raise AgentGenerationError(f"Error in generating llm_engine output: {e}.") if self.verbose: - console.print(Group( - Rule("[italic]Output message of the LLM:", align="left", style="orange"), - Syntax(llm_output, lexer="markdown", theme="github-dark") - )) + console.print( + Group( + Rule( + "[italic]Output message of the LLM:", + align="left", + style="orange", + ), + Syntax(llm_output, lexer="markdown", theme="github-dark"), + ) + ) # Parse try: @@ -971,13 +1004,16 @@ class CodeAgent(ReactAgent): # Execute if self.verbose: - console.print(Group( - Rule("[italic]Agent thoughts", align="left"), - Text(rationale) - )) + console.print( + Group(Rule("[italic]Agent thoughts", align="left"), Text(rationale)) + ) - console.print(Panel( - Syntax(code_action, lexer="python", theme="github-dark"), title="[bold]Agent is executing the code below:", title_align="left") + console.print( + Panel( + Syntax(code_action, lexer="python", theme="github-dark"), + title="[bold]Agent is executing the code below:", + title_align="left", + ) ) try: @@ -985,13 +1021,18 @@ class CodeAgent(ReactAgent): code_action, ) if len(execution_logs) > 0: - console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs))) + console.print( + Group(Text("Execution logs:", style="bold"), Text(execution_logs)) + ) observation = "Execution logs:\n" + execution_logs if output is not None: - truncated_output = truncate_content( - str(output) + truncated_output = truncate_content(str(output)) + console.print( + Group( + Text("Last output from code snippet:", style="bold"), + Text(truncated_output), + ) ) - console.print(Group(Text("Last output from code snippet:", style="bold"), Text(truncated_output))) observation += "Last output from code snippet:\n" + truncate_content( str(output) ) @@ -1003,44 +1044,38 @@ class CodeAgent(ReactAgent): raise AgentExecutionError(error_msg) for line in code_action.split("\n"): if line[: len("final_answer")] == "final_answer": - console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green"))) + console.print( + Group( + Text("Final answer:", style="bold"), + Text(str(output), style="bold green"), + ) + ) log_entry.action_output = output return output - class ManagedAgent: def __init__( self, agent, name, description, - additional_prompting=None, - provide_run_summary=False, + additional_prompting: Optional[str] = None, + provide_run_summary: bool = False, + managed_agent_prompt: Optional[str] = None, ): self.agent = agent self.name = name self.description = description self.additional_prompting = additional_prompting self.provide_run_summary = provide_run_summary + self.managed_agent_prompt = ( + managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT + ) def write_full_task(self, task): - full_task = f"""You're a helpful agent named '{self.name}'. -You have been submitted this task by your manager. ---- -Task: -{task} ---- -You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer. - -Your final_answer WILL HAVE to contain these parts: -### 1. Task outcome (short version): -### 2. Task outcome (extremely detailed version): -### 3. Additional context (if relevant): - -Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. -And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. -{{additional_prompting}}""" + """Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" + full_task = self.managed_agent_prompt.format(name=self.name, task=task) if self.additional_prompting: full_task = full_task.replace( "\n{{additional_prompting}}", self.additional_prompting diff --git a/src/agents/default_tools/base.py b/src/agents/default_tools/base.py index 65d6a9e..ceedd3b 100644 --- a/src/agents/default_tools/base.py +++ b/src/agents/default_tools/base.py @@ -21,7 +21,11 @@ from typing import Dict from huggingface_hub import hf_hub_download, list_spaces from transformers.utils import is_offline_mode -from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code +from ..local_python_executor import ( + BASE_BUILTIN_MODULES, + BASE_PYTHON_TOOLS, + evaluate_python_code, +) from ..tools import TOOL_CONFIG_FILE, Tool diff --git a/src/agents/default_tools/search.py b/src/agents/default_tools/search.py index fad2d46..cd72dbd 100644 --- a/src/agents/default_tools/search.py +++ b/src/agents/default_tools/search.py @@ -18,6 +18,7 @@ import re from ..tools import Tool + class DuckDuckGoSearchTool(Tool): name = "web_search" description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. diff --git a/src/agents/docker_alternative.py b/src/agents/docker_alternative.py index b035c7e..4dd4aa9 100644 --- a/src/agents/docker_alternative.py +++ b/src/agents/docker_alternative.py @@ -1,12 +1,13 @@ import docker -from typing import List, Optional +from typing import List, Optional import warnings import socket from agents.tools import Tool + class DockerPythonInterpreter: - def __init__(self): + def __init__(self): self.container = None try: self.client = docker.from_env() @@ -15,14 +16,14 @@ class DockerPythonInterpreter: raise RuntimeError( "Could not connect to Docker daemon. Please ensure Docker is installed and running." ) - - try: + + try: self.container = self.client.containers.run( "pyrunner:latest", - ports={'65432/tcp': 65432}, + ports={"65432/tcp": 65432}, detach=True, remove=True, - ) + ) except docker.errors.DockerException as e: raise RuntimeError(f"Failed to create Docker container: {e}") @@ -30,7 +31,7 @@ class DockerPythonInterpreter: """Cleanup: Stop and remove container when object is destroyed""" if self.container: try: - self.container.kill() # can consider .stop(), but this is faster + self.container.kill() # can consider .stop(), but this is faster except Exception as e: warnings.warn(f"Failed to stop Docker container: {e}") @@ -39,7 +40,7 @@ class DockerPythonInterpreter: Execute Python code in the container and return stdout and stderr """ - if tools != None: + if tools != None: tool_instance = tools[0]() import_code = f""" @@ -56,13 +57,13 @@ web_search = getattr(module, class_name)() try: # Connect to the server running inside the container - with socket.create_connection(('localhost', 65432)) as sock: - sock.sendall(code.encode('utf-8')) + with socket.create_connection(("localhost", 65432)) as sock: + sock.sendall(code.encode("utf-8")) output = sock.recv(4096) - return output.decode('utf-8') + return output.decode("utf-8") except Exception as e: return f"Error executing code: {str(e)}" -__all__ = ["DockerPythonInterpreter"] \ No newline at end of file +__all__ = ["DockerPythonInterpreter"] diff --git a/src/agents/e2b_executor.py b/src/agents/e2b_executor.py index e7d199a..a3ab9b6 100644 --- a/src/agents/e2b_executor.py +++ b/src/agents/e2b_executor.py @@ -21,19 +21,18 @@ from io import BytesIO from PIL import Image from e2b_code_interpreter import Sandbox -from typing import Dict, List, Callable, Tuple, Any +from typing import List, Tuple, Any from .tool_validation import validate_tool_attributes from .utils import instance_to_source, BASE_BUILTIN_MODULES from .tools import Tool -from .types import AgentImage load_dotenv() -class E2BExecutor(): +class E2BExecutor: def __init__(self, additional_imports: List[str], tools: List[Tool]): self.custom_tools = {} - self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") + self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") # TODO: validate installing agents package or not # print("Installing agents package on remote executor...") # self.sbx.commands.run( @@ -42,7 +41,9 @@ class E2BExecutor(): # ) # print("Installation of agents package finished.") if len(additional_imports) > 0: - execution = self.sbx.commands.run("pip install " + " ".join(additional_imports)) + execution = self.sbx.commands.run( + "pip install " + " ".join(additional_imports) + ) if execution.error: raise Exception(f"Error installing dependencies: {execution.error}") else: @@ -56,7 +57,9 @@ class E2BExecutor(): tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n" tool_codes.append(tool_code) - tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES]) + tool_definition_code = "\n".join( + [f"import {module}" for module in BASE_BUILTIN_MODULES] + ) tool_definition_code += textwrap.dedent(""" class Tool: def __call__(self, *args, **kwargs): @@ -75,7 +78,7 @@ class Tool: code, ) if execution.error: - logs = 'Executing code yielded an error:' + logs = "Executing code yielded an error:" logs += execution.error.name logs += execution.error.value logs += execution.error.traceback @@ -90,14 +93,28 @@ class Tool: else: for result in execution.results: if result.is_main_result: - for attribute_name in ['jpeg', 'png']: + for attribute_name in ["jpeg", "png"]: if getattr(result, attribute_name) is not None: image_output = getattr(result, attribute_name) - decoded_bytes = base64.b64decode(image_output.encode('utf-8')) + decoded_bytes = base64.b64decode( + image_output.encode("utf-8") + ) return Image.open(BytesIO(decoded_bytes)), execution_logs - for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']: + for attribute_name in [ + "chart", + "data", + "html", + "javascript", + "json", + "latex", + "markdown", + "pdf", + "svg", + "text", + ]: if getattr(result, attribute_name) is not None: return getattr(result, attribute_name), execution_logs raise ValueError("No main result returned by executor!") -__all__ = ["E2BExecutor"] \ No newline at end of file + +__all__ = ["E2BExecutor"] diff --git a/src/agents/gradio_ui.py b/src/agents/gradio_ui.py index e332451..8b2f6e2 100644 --- a/src/agents/gradio_ui.py +++ b/src/agents/gradio_ui.py @@ -58,7 +58,8 @@ def stream_to_gradio( for message in pull_messages_from_step(step_log, test_mode=test_mode): yield message - final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer + final_answer = step_log # Last log is the run's final_answer + final_answer = handle_agent_output_types(final_answer) if isinstance(final_answer, AgentText): yield gr.ChatMessage( diff --git a/src/agents/llm_engines.py b/src/agents/llm_engines.py index 4062222..7e24f75 100644 --- a/src/agents/llm_engines.py +++ b/src/agents/llm_engines.py @@ -58,12 +58,14 @@ llama_role_conversions = { MessageRole.TOOL_RESPONSE: MessageRole.USER, } + def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str: for stop_seq in stop_sequences: if content[-len(stop_seq) :] == stop_seq: content = content[: -len(stop_seq)] return content + def get_clean_message_list( message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} ) -> List[Dict[str, str]]: @@ -204,7 +206,6 @@ class HfApiEngine(HfEngine): grammar: Optional[str] = None, max_tokens: int = 1500, ) -> str: - # Get clean message list messages = get_clean_message_list( messages, role_conversions=llama_role_conversions ) @@ -235,7 +236,9 @@ class TransformersEngine(HfEngine): super().__init__() if model_id is None: model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" - logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'") + logger.warning( + f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'" + ) try: self.tokenizer = AutoTokenizer.from_pretrained(model_id) except Exception as e: @@ -254,7 +257,6 @@ class TransformersEngine(HfEngine): grammar: Optional[str] = None, max_tokens: int = 1500, ) -> str: - # Get clean message list messages = get_clean_message_list( messages, role_conversions=llama_role_conversions ) diff --git a/src/agents/local_python_executor.py b/src/agents/local_python_executor.py index 0546baf..53f67f5 100644 --- a/src/agents/local_python_executor.py +++ b/src/agents/local_python_executor.py @@ -46,6 +46,7 @@ ERRORS = { PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 + def custom_print(*args): return None @@ -103,6 +104,8 @@ BASE_PYTHON_TOOLS = { "issubclass": issubclass, "type": type, } + + class BreakException(Exception): pass @@ -1043,7 +1046,7 @@ def evaluate_python_code( raise InterpreterError(msg) -class LocalPythonInterpreter(): +class LocalPythonInterpreter: def __init__(self, additional_authorized_imports: List[str], tools: Dict): self.custom_tools = {} self.state = {} @@ -1069,4 +1072,5 @@ class LocalPythonInterpreter(): logs = self.state["print_outputs"] return output, logs + __all__ = ["evaluate_python_code", "LocalPythonInterpreter"] diff --git a/src/agents/monitoring.py b/src/agents/monitoring.py index a89cb5d..a636cf5 100644 --- a/src/agents/monitoring.py +++ b/src/agents/monitoring.py @@ -18,6 +18,7 @@ from .utils import console from rich.text import Text from rich.console import Group + class Monitor: def __init__(self, tracked_llm_engine): self.step_durations = [] @@ -34,7 +35,7 @@ class Monitor: self.step_durations.append(step_duration) console_outputs = [ Text(f"Step {len(self.step_durations)}:", style="bold"), - Text(f"- Time taken: {step_duration:.2f} seconds") + Text(f"- Time taken: {step_duration:.2f} seconds"), ] if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: @@ -46,7 +47,7 @@ class Monitor: ) console_outputs += [ Text(f"- Input tokens: {self.total_input_token_count:,}"), - Text(f"- Output tokens: {self.total_output_token_count:,}") + Text(f"- Output tokens: {self.total_output_token_count:,}"), ] console.print(Group(*console_outputs)) diff --git a/src/agents/prompts.py b/src/agents/prompts.py index 05721f2..5e9aeb6 100644 --- a/src/agents/prompts.py +++ b/src/agents/prompts.py @@ -491,10 +491,28 @@ Here is my new/updated plan of action to solve the task: {plan_update} ```""" +MANAGED_AGENT_PROMPT = """You're a helpful agent named '{name}'. +You have been submitted this task by your manager. +--- +Task: +{task} +--- +You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible to give them a clear understanding of the answer. + +Your final_answer WILL HAVE to contain these parts: +### 1. Task outcome (short version): +### 2. Task outcome (extremely detailed version): +### 3. Additional context (if relevant): + +Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. +And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. +{{additional_prompting}}""" + __all__ = [ "USER_PROMPT_PLAN_UPDATE", "PLAN_UPDATE_FINAL_PLAN_REDACTION", "ONESHOT_CODE_SYSTEM_PROMPT", "CODE_SYSTEM_PROMPT", "JSON_SYSTEM_PROMPT", + "MANAGED_AGENT_PROMPT", ] diff --git a/src/agents/tool_validation.py b/src/agents/tool_validation.py index cd71475..59079e0 100644 --- a/src/agents/tool_validation.py +++ b/src/agents/tool_validation.py @@ -2,8 +2,7 @@ import ast import inspect import importlib.util import builtins -from pathlib import Path -from typing import List, Set, Dict +from typing import Set import textwrap from .utils import BASE_BUILTIN_MODULES @@ -11,6 +10,7 @@ _BUILTIN_NAMES = set(vars(builtins)) IMPORTED_PACKAGES = BASE_BUILTIN_MODULES + def is_installed_package(module_name: str) -> bool: """ Check if an import is from an installed package. @@ -20,22 +20,24 @@ def is_installed_package(module_name: str) -> bool: spec = importlib.util.find_spec(module_name) if spec is None: return False # If we can't find the module, assume it's local - + # If the module is found and has a file path, check if it's in site-packages - if spec.origin and 'site-packages' not in spec.origin: + if spec.origin and "site-packages" not in spec.origin: # Check if it's a .py file in the current directory or subdirectories - return not spec.origin.endswith('.py') + return not spec.origin.endswith(".py") return False except ImportError: return False # If there's an import error, assume it's local + class MethodChecker(ast.NodeVisitor): """ Checks that a method - only uses defined names - contains no local imports (e.g. numpy is ok but local_script is not) """ + def __init__(self, class_attributes: Set[str], check_imports: bool = True): self.undefined_names = set() self.imports = {} @@ -53,22 +55,26 @@ class MethodChecker(ast.NodeVisitor): self.arg_names.add(node.kwarg.arg) if node.vararg: self.arg_names.add(node.vararg.arg) - + def visit_Import(self, node): for name in node.names: actual_name = name.asname or name.name if not is_installed_package(actual_name) and self.check_imports: - self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'") + self.errors.append( + f"Package not found in importlib, might be a local install: '{actual_name}'" + ) self.imports[actual_name] = name.name - + def visit_ImportFrom(self, node): module = node.module or "" for name in node.names: actual_name = name.asname or name.name if not is_installed_package(module) and self.check_imports: - self.errors.append(f"Package not found in importlib, might be a local install: '{module}'") + self.errors.append( + f"Package not found in importlib, might be a local install: '{module}'" + ) self.from_imports[actual_name] = (module, name.name) - + def visit_Assign(self, node): for target in node.targets: if isinstance(target, ast.Name): @@ -136,10 +142,11 @@ class MethodChecker(ast.NodeVisitor): or node.func.id in self.imports or node.func.id in self.from_imports or node.func.id in self.assigned_names - ): + ): self.errors.append(f"Name '{node.func.id}' is undefined.") self.generic_visit(node) + def validate_tool_attributes(cls, check_imports: bool = True) -> None: """ Validates that a Tool class follows the proper patterns: @@ -163,11 +170,15 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: raise ValueError("Source code must define a class") # Check that __init__ method takes no arguments - if not cls.__init__.__qualname__ == 'Tool.__init__': + if not cls.__init__.__qualname__ == "Tool.__init__": sig = inspect.signature(cls.__init__) - non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]) + non_self_params = list( + [arg_name for arg_name in sig.parameters.keys() if arg_name != "self"] + ) if len(non_self_params) > 0: - errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!") + errors.append( + f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!" + ) class_node = tree.body[0] @@ -193,15 +204,19 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: self.class_attributes.add(target.id) # Check if the assignment is more complex than simple literals - if not all(isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)) - for val in ast.walk(node.value)): + if not all( + isinstance( + val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set) + ) + for val in ast.walk(node.value) + ): for target in node.targets: if isinstance(target, ast.Name): self.complex_attributes.add(target.id) - + class_level_checker = ClassLevelChecker() class_level_checker.visit(class_node) - + if class_level_checker.complex_attributes: errors.append( f"Complex attributes should be defined in __init__, not as class attributes: " @@ -211,7 +226,9 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None: # Run checks on all methods for node in class_node.body: if isinstance(node, ast.FunctionDef): - method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports) + method_checker = MethodChecker( + class_level_checker.class_attributes, check_imports=check_imports + ) method_checker.visit(node) errors += [f"- {node.name}: {error}" for error in method_checker.errors] diff --git a/src/agents/tools.py b/src/agents/tools.py index 5640413..3acf9f2 100644 --- a/src/agents/tools.py +++ b/src/agents/tools.py @@ -15,20 +15,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast -import base64 -import builtins import importlib import inspect -import io import json import os -import re import tempfile import textwrap from functools import lru_cache, wraps from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, Set -import math +from typing import Callable, Dict, List, Optional, Union from huggingface_hub import ( create_repo, @@ -37,7 +32,7 @@ from huggingface_hub import ( metadata_update, upload_folder, ) -from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session +from huggingface_hub.utils import RepositoryNotFoundError from packaging import version from transformers.utils import ( @@ -46,7 +41,6 @@ from transformers.utils import ( get_json_schema, is_accelerate_available, is_torch_available, - is_vision_available, ) from transformers.dynamic_module_utils import get_imports from .types import ImageType, handle_agent_input_types, handle_agent_output_types @@ -67,6 +61,7 @@ if is_accelerate_available(): TOOL_CONFIG_FILE = "tool_config.json" + def get_repo_type(repo_id, repo_type=None, **hub_kwargs): if repo_type is not None: return repo_type @@ -240,7 +235,7 @@ class Tool: method_checker = MethodChecker(set()) method_checker.visit(forward_node) if len(method_checker.errors) > 0: - raise(ValueError("\n".join(method_checker.errors))) + raise (ValueError("\n".join(method_checker.errors))) forward_source_code = inspect.getsource(self.forward) tool_code = textwrap.dedent(f""" @@ -253,16 +248,17 @@ class Tool: output_type = "{self.output_type}" """).strip() import re + def add_self_argument(source_code: str) -> str: """Add 'self' as first argument to a function definition if not present.""" - pattern = r'def forward\(((?!self)[^)]*)\)' - + pattern = r"def forward\(((?!self)[^)]*)\)" + def replacement(match): args = match.group(1).strip() if args: # If there are other arguments - return f'def forward(self, {args})' - return 'def forward(self)' - + return f"def forward(self, {args})" + return "def forward(self)" + return re.sub(pattern, replacement, source_code) forward_source_code = forward_source_code.replace(self.name, "forward") @@ -270,10 +266,14 @@ class Tool: forward_source_code = forward_source_code.replace("@tool", "").strip() tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") - else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool - if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]: + else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool + if type(self).__name__ in [ + "SpaceToolWrapper", + "LangChainToolWrapper", + "GradioToolWrapper", + ]: raise ValueError( - f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors." + "Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors." ) validate_tool_attributes(self.__class__) @@ -286,14 +286,16 @@ class Tool: # Save app file app_file = os.path.join(output_dir, "app.py") with open(app_file, "w", encoding="utf-8") as f: - f.write(textwrap.dedent(f""" + f.write( + textwrap.dedent(f""" from agents import launch_gradio_demo from tool import {class_name} tool = {class_name}() launch_gradio_demo(tool) - """).lstrip()) + """).lstrip() + ) # Save requirements file requirements_file = os.path.join(output_dir, "requirements.txt") @@ -570,6 +572,7 @@ class Tool: def sanitize_argument_for_prediction(self, arg): from gradio_client.utils import is_http_url_like + if isinstance(arg, ImageType): temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) arg.save(temp_file.name) @@ -732,9 +735,7 @@ def launch_gradio_demo(tool: Tool): new_component = input_gradio_component_class(label=input_name) gradio_inputs.append(new_component) - output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[ - tool.output_type - ] + output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool.output_type] gradio_output = output_gradio_componentclass(label="Output") gr.Interface( @@ -893,7 +894,7 @@ def tool(tool_function: Callable) -> Tool: parameters["description"], parameters["parameters"]["properties"], parameters["return"]["type"], - function=tool_function + function=tool_function, ) original_signature = inspect.signature(tool_function) new_parameters = [ diff --git a/src/agents/utils.py b/src/agents/utils.py index cf3d324..1462c88 100644 --- a/src/agents/utils.py +++ b/src/agents/utils.py @@ -19,7 +19,6 @@ import re from typing import Tuple, Dict, Union import ast from rich.console import Console -import ast import inspect import types @@ -46,6 +45,8 @@ BASE_BUILTIN_MODULES = [ "statistics", "unicodedata", ] + + def parse_json_blob(json_blob: str) -> Dict[str, str]: try: first_accolade_index = json_blob.find("{") @@ -141,9 +142,9 @@ class ImportFinder(ast.NodeVisitor): base_package = node.module.split(".")[0] self.packages.add(base_package) + import ast -import builtins -from typing import Set, Dict, List +from typing import Dict def get_method_source(method): @@ -158,17 +159,20 @@ def is_same_method(method1, method2): try: source1 = get_method_source(method1) source2 = get_method_source(method2) - + # Remove method decorators if any - source1 = '\n'.join(line for line in source1.split('\n') - if not line.strip().startswith('@')) - source2 = '\n'.join(line for line in source2.split('\n') - if not line.strip().startswith('@')) - + source1 = "\n".join( + line for line in source1.split("\n") if not line.strip().startswith("@") + ) + source2 = "\n".join( + line for line in source2.split("\n") if not line.strip().startswith("@") + ) + return source1 == source2 except (TypeError, OSError): return False + def is_same_item(item1, item2): """Compare two class items (methods or attributes) for equality.""" if callable(item1) and callable(item2): @@ -176,29 +180,34 @@ def is_same_item(item1, item2): else: return item1 == item2 + def instance_to_source(instance, base_cls=None): """Convert an instance to its class source code representation.""" cls = instance.__class__ class_name = cls.__name__ - + # Start building class lines class_lines = [] if base_cls: class_lines.append(f"class {class_name}({base_cls.__name__}):") else: class_lines.append(f"class {class_name}:") - + # Add docstring if it exists and differs from base if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__): class_lines.append(f' """{cls.__doc__}"""') - + # Add class-level attributes class_attrs = { - name: value for name, value in cls.__dict__.items() - if not name.startswith('__') and not callable(value) and - not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value) + name: value + for name, value in cls.__dict__.items() + if not name.startswith("__") + and not callable(value) + and not ( + base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value + ) } - + for name, value in class_attrs.items(): if isinstance(value, str): if "\n" in value: @@ -206,39 +215,44 @@ def instance_to_source(instance, base_cls=None): else: class_lines.append(f' {name} = "{value}"') else: - class_lines.append(f' {name} = {repr(value)}') - + class_lines.append(f" {name} = {repr(value)}") + if class_attrs: class_lines.append("") - + # Add methods methods = { - name: func for name, func in cls.__dict__.items() - if callable(func) and - not (base_cls and hasattr(base_cls, name) and - getattr(base_cls, name).__code__.co_code == func.__code__.co_code) + name: func + for name, func in cls.__dict__.items() + if callable(func) + and not ( + base_cls + and hasattr(base_cls, name) + and getattr(base_cls, name).__code__.co_code == func.__code__.co_code + ) } - + for name, method in methods.items(): method_source = inspect.getsource(method) # Clean up the indentation - method_lines = method_source.split('\n') + method_lines = method_source.split("\n") first_line = method_lines[0] indent = len(first_line) - len(first_line.lstrip()) method_lines = [line[indent:] for line in method_lines] - method_source = '\n'.join([' ' + line if line.strip() else line - for line in method_lines]) + method_source = "\n".join( + [" " + line if line.strip() else line for line in method_lines] + ) class_lines.append(method_source) class_lines.append("") - + # Find required imports using ImportFinder import_finder = ImportFinder() - import_finder.visit(ast.parse('\n'.join(class_lines))) + import_finder.visit(ast.parse("\n".join(class_lines))) required_imports = import_finder.packages - + # Build final code with imports final_lines = [] - + # Add base class import if needed if base_cls: final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}") @@ -246,13 +260,14 @@ def instance_to_source(instance, base_cls=None): # Add discovered imports for package in required_imports: final_lines.append(f"import {package}") - + if final_lines: # Add empty line after imports final_lines.append("") # Add the class code final_lines.extend(class_lines) - - return '\n'.join(final_lines) + + return "\n".join(final_lines) + __all__ = [] diff --git a/tests/test_agents.py b/tests/test_agents.py index 539f4cf..41689dd 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -232,7 +232,9 @@ Action: def test_additional_args_added_to_task(self): agent = CodeAgent(tools=[], llm_engine=fake_code_llm) - output = agent.run("What is 2 multiplied by 3.6452?", additional_instruction="Remember this.") + output = agent.run( + "What is 2 multiplied by 3.6452?", additional_instruction="Remember this." + ) assert "Remember this" in agent.task assert "Remember this" in str(agent.prompt_messages) diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index a3feb98..d1467ac 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -125,7 +125,9 @@ class TestDocs: "from_langchain", ] code_blocks = [ - block.replace("", self.hf_token).replace("{your_username}", "m-ric") + block.replace("", self.hf_token).replace( + "{your_username}", "m-ric" + ) for block in code_blocks if not any( [snippet in block for snippet in excluded_snippets] diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index 4875f9f..dd30078 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -174,7 +174,7 @@ class ToolTests(unittest.TestCase): Gets the current time. """ return str(datetime.now()) - + get_current_time.save("output") assert "datetime" in str(e) @@ -190,7 +190,7 @@ class ToolTests(unittest.TestCase): def forward(self): return str(datetime.now()) - + get_current_time = GetCurrentTimeTool() get_current_time.save("output") @@ -214,14 +214,17 @@ class ToolTests(unittest.TestCase): def forward(self): from datetime import datetime + return str(datetime.now()) - + def test_saving_tool_allows_no_arg_in_init(self): # Test one cannot save tool with additional args in init class FailTool(Tool): name = "specific" description = "test description" - inputs = {"input_str": {"type": "string", "description": "input description"}} + inputs = { + "input_str": {"type": "string", "description": "input description"} + } output_type = "string" def __init__(self, url): @@ -233,16 +236,19 @@ class ToolTests(unittest.TestCase): fail_tool = FailTool("dummy_url") with pytest.raises(Exception) as e: - fail_tool.save('output') - assert '__init__' in str(e) + fail_tool.save("output") + assert "__init__" in str(e) def test_saving_tool_allows_no_imports_from_outside_methods(self): # Test that using imports from outside functions fails from numpy import random + class FailTool2(Tool): name = "specific" description = "test description" - inputs = {"input_str": {"type": "string", "description": "input description"}} + inputs = { + "input_str": {"type": "string", "description": "input description"} + } output_type = "string" def useless_method(self): @@ -254,19 +260,22 @@ class ToolTests(unittest.TestCase): fail_tool_2 = FailTool2() with pytest.raises(Exception) as e: - fail_tool_2.save('output') - assert 'random' in str(e) + fail_tool_2.save("output") + assert "random" in str(e) # Test that putting these imports inside functions works class FailTool3(Tool): name = "specific" description = "test description" - inputs = {"input_str": {"type": "string", "description": "input description"}} + inputs = { + "input_str": {"type": "string", "description": "input description"} + } output_type = "string" def useless_method(self): from numpy import random + self.client = random.random() return "" @@ -274,4 +283,4 @@ class ToolTests(unittest.TestCase): return self.useless_method() + string_input fail_tool_3 = FailTool3() - fail_tool_3.save('output') + fail_tool_3.save("output")