From f8b9cb34f93ac6c8ad4c366fed9989834cb9913a Mon Sep 17 00:00:00 2001 From: Aymeric Date: Thu, 26 Dec 2024 17:59:15 +0100 Subject: [PATCH] Fix additional args sent to e2b server --- README.md | 2 +- docs/source/guided_tour.md | 15 +++++++--- examples/e2b_example.py | 8 ++--- src/smolagents/agents.py | 40 +++++++------------------ src/smolagents/default_tools.py | 2 +- src/smolagents/e2b_executor.py | 35 ++++++++++++++++++---- src/smolagents/local_python_executor.py | 3 +- src/smolagents/tools.py | 26 ++++++++-------- 8 files changed, 73 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index f8142cd..cc0f7b0 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/ Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: - a secure python interpreter to run code more safely in your environment - - a sandboxed environment. + - a sandboxed environment using [E2B](https://e2b.dev/). ## How lightweight is it? diff --git a/docs/source/guided_tour.md b/docs/source/guided_tour.md index 00e43f7..f965483 100644 --- a/docs/source/guided_tour.md +++ b/docs/source/guided_tour.md @@ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient login("") -model_id = "Qwen/Qwen2.5-72B-Instruct" +model_id = "meta-llama/Llama-3.3-70B-Instruct" client = InferenceClient(model=model_id) @@ -71,12 +71,19 @@ agent.run( Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text. -You can use this to indicate the path to local or remote files for the model to use: +You can use this to pass files in various formats: ```py -agent = CodeAgent(tools=[], model=model, add_base_tools=True) +from smolagents import CodeAgent, HfApiModel -agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") +model_id = "meta-llama/Llama-3.3-70B-Instruct" + +agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True) + +agent.run( + "Why does Mike not know many people in New York?", + additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'} +) ``` It's important to explain as clearly as possible the task you want to perform. diff --git a/examples/e2b_example.py b/examples/e2b_example.py index b288d2f..0f3b0e8 100644 --- a/examples/e2b_example.py +++ b/examples/e2b_example.py @@ -27,12 +27,11 @@ LAUNCH_GRADIO = False get_cat_image = GetCatImageTool() - agent = CodeAgent( tools = [get_cat_image, VisitWebpageTool()], model=HfApiModel(), additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search", - use_e2b_executor=False + use_e2b_executor=True ) if LAUNCH_GRADIO: @@ -41,6 +40,5 @@ if LAUNCH_GRADIO: GradioUI(agent).launch() else: agent.run( - "Return me an image of Lincoln's preferred pet", - additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" - ) + "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()} + ) # Asking to directly return the image from state tests that additional_args are properly sent to server. diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index f5f881e..ab4d256 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -188,6 +188,7 @@ class MultiStepAgent: self.tool_parser = tool_parser self.grammar = grammar self.planning_interval = planning_interval + self.state = {} self.managed_agents = {} if managed_agents is not None: @@ -370,8 +371,7 @@ class MultiStepAgent: return self.model(self.input_messages) except Exception as e: error_msg = f"Error in generating final LLM output:\n{e}" - console.print(f"[bold red]{error_msg}[/bold red]") - return error_msg + raise AgentGenerationError(error_msg) def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: """ @@ -385,7 +385,6 @@ class MultiStepAgent: available_tools = {**self.toolbox.tools, **self.managed_agents} if tool_name not in available_tools: error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." - console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) try: @@ -398,7 +397,6 @@ class MultiStepAgent: observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) else: error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." - console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) return observation except Exception as e: @@ -410,14 +408,12 @@ class MultiStepAgent: f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"As a reminder, this tool's description is the following:\n{tool_description}" ) - console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) elif tool_name in self.managed_agents: error_msg = ( f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" ) - console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) def step(self, log_entry: ActionStep) -> Union[None, Any]: @@ -430,7 +426,7 @@ class MultiStepAgent: stream: bool = False, reset: bool = True, single_step: bool = False, - **kwargs, + additional_args: Optional[Dict] = None, ): """ Runs the agent for the given task. @@ -440,6 +436,7 @@ class MultiStepAgent: stream (`bool`): Wether to run in a streaming way. reset (`bool`): Wether to reset the conversation or keep it going from previous run. single_step (`bool`): Should the agent run in one shot or multi-step fashion? + additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! Example: ```py @@ -449,11 +446,11 @@ class MultiStepAgent: ``` """ self.task = task - if len(kwargs) > 0: - self.task += ( - f"\nYou have been provided with these initial arguments: {str(kwargs)}." - ) - self.state = kwargs.copy() + if additional_args is not None: + self.state.update(additional_args) + self.task += f""" +You have been provided with these additional arguments, that you can access as variables in your python code using the keys: +{str(additional_args)}.""" self.initialize_system_prompt() system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) @@ -468,14 +465,9 @@ class MultiStepAgent: else: self.logs.append(system_prompt_step) - # console.print( - # Group( - # Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task) - # ) - # ) console.print( Panel( - f"\n[bold]{task.strip()}\n", + f"\n[bold]{self.task.strip()}\n", title="[bold]New run", subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", border_style=YELLOW_HEX, @@ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent): console.print_exception() raise AgentGenerationError(f"Error in generating model output:\n{e}") - # from rich.live import Live - # from rich.markdown import Markdown - # import time - - # with Live(console=console, vertical_overflow="visible") as live: - # message = "" - # for i in range(100): - # time.sleep(0.02) - # message += str(i) - # live.update(Markdown(message)) - if self.verbose: console.print( Group( @@ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent): try: output, execution_logs = self.python_executor( code_action, + self.state, ) execution_outputs_console = [] if len(execution_logs) > 0: diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 6339836..a9451f6 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool): pre_processor_class = WhisperProcessor model_class = WhisperForConditionalGeneration - inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} + inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}} output_type = "string" def encode(self, audio): diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py index 65afbdc..7cce059 100644 --- a/src/smolagents/e2b_executor.py +++ b/src/smolagents/e2b_executor.py @@ -17,13 +17,14 @@ from dotenv import load_dotenv import textwrap import base64 +import pickle from io import BytesIO from PIL import Image from e2b_code_interpreter import Sandbox from typing import List, Tuple, Any from .tool_validation import validate_tool_attributes -from .utils import instance_to_source, BASE_BUILTIN_MODULES +from .utils import instance_to_source, BASE_BUILTIN_MODULES, console from .tools import Tool load_dotenv() @@ -40,6 +41,7 @@ class E2BExecutor: # timeout=300 # ) # print("Installation of agents package finished.") + additional_imports = additional_imports + ["pickle5"] if len(additional_imports) > 0: execution = self.sbx.commands.run( "pip install " + " ".join(additional_imports) @@ -47,7 +49,7 @@ class E2BExecutor: if execution.error: raise Exception(f"Error installing dependencies: {execution.error}") else: - print("Installation succeeded!") + console.print(f"Installation of {additional_imports} succeeded!") tool_codes = [] for tool in tools: @@ -71,21 +73,44 @@ class E2BExecutor: tool_definition_code += "\n\n".join(tool_codes) tool_definition_execution = self.run_code_raise_errors(tool_definition_code) - print(tool_definition_execution.logs) + console.print(tool_definition_execution.logs) def run_code_raise_errors(self, code: str): execution = self.sbx.run_code( code, ) if execution.error: - logs = "Executing code yielded an error:" + execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) + logs = execution_logs + logs += "Executing code yielded an error:" logs += execution.error.name logs += execution.error.value logs += execution.error.traceback raise ValueError(logs) return execution - def __call__(self, code_action: str) -> Tuple[Any, Any]: + def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]: + if len(additional_args) > 0: + # Pickle additional_args to server + import tempfile + + with tempfile.NamedTemporaryFile() as f: + pickle.dump(additional_args, f) + f.flush() + with open(f.name, "rb") as file: + self.sbx.files.write("/home/state.pkl", file) + remote_unloading_code = """import pickle +import os +print("File path", os.path.getsize('/home/state.pkl')) +with open('/home/state.pkl', 'rb') as f: + pickle_dict = pickle.load(f) +locals().update({key: value for key, value in pickle_dict.items()}) +""" + execution = self.run_code_raise_errors(remote_unloading_code) + execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) + console.print(execution_logs) + + execution = self.run_code_raise_errors(code_action) execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) if not execution.results: diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index d2e30a4..3cb04c0 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -1058,7 +1058,8 @@ class LocalPythonInterpreter: } # TODO: assert self.authorized imports are all installed locally - def __call__(self, code_action: str) -> Tuple[Any, str]: + def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]: + self.state.update(additional_variables) output = evaluate_python_code( code_action, static_tools=self.static_tools, diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index fd39933..50abe4a 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -201,20 +201,21 @@ class Tool: assert getattr(self, "output_type", None) in AUTHORIZED_TYPES - # Validate forward function signature - signature = inspect.signature(self.forward) + # Validate forward function signature, except for PipelineTool + if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True): + signature = inspect.signature(self.forward) - if not set(signature.parameters.keys()) == set(self.inputs.keys()): - raise Exception( - "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." - ) + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) - json_schema = _convert_type_hints_to_json_schema(self.forward) - for key, value in self.inputs.items(): - if "nullable" in value: - assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." - if key in json_schema and "nullable" in json_schema[key]: - assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." + json_schema = _convert_type_hints_to_json_schema(self.forward) + for key, value in self.inputs.items(): + if "nullable" in value: + assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." + if key in json_schema and "nullable" in json_schema[key]: + assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." def forward(self, *args, **kwargs): return NotImplementedError("Write this method in your subclass of `Tool`.") @@ -1074,6 +1075,7 @@ class PipelineTool(Tool): name = "pipeline" inputs = {"prompt": str} output_type = str + is_pipeline_tool = True def __init__( self,