Fix additional args sent to e2b server
This commit is contained in:
		
							parent
							
								
									1abaf69b67
								
							
						
					
					
						commit
						f8b9cb34f9
					
				|  | @ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/ | |||
| 
 | ||||
| Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: | ||||
|   - a secure python interpreter to run code more safely in your environment | ||||
|   - a sandboxed environment. | ||||
|   - a sandboxed environment using [E2B](https://e2b.dev/). | ||||
| 
 | ||||
| ## How lightweight is it? | ||||
| 
 | ||||
|  |  | |||
|  | @ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient | |||
| 
 | ||||
| login("<YOUR_HUGGINGFACEHUB_API_TOKEN>") | ||||
| 
 | ||||
| model_id = "Qwen/Qwen2.5-72B-Instruct" | ||||
| model_id = "meta-llama/Llama-3.3-70B-Instruct" | ||||
| 
 | ||||
| client = InferenceClient(model=model_id) | ||||
| 
 | ||||
|  | @ -71,12 +71,19 @@ agent.run( | |||
| 
 | ||||
| Note that we used an additional `additional_detail` argument: you can additional kwargs to `agent.run()`, they will be baked into the prompt as text. | ||||
| 
 | ||||
| You can use this to indicate the path to local or remote files for the model to use: | ||||
| You can use this to pass files in various formats: | ||||
| 
 | ||||
| ```py | ||||
| agent = CodeAgent(tools=[], model=model, add_base_tools=True) | ||||
| from smolagents import CodeAgent, HfApiModel | ||||
| 
 | ||||
| agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") | ||||
| model_id = "meta-llama/Llama-3.3-70B-Instruct" | ||||
| 
 | ||||
| agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True) | ||||
| 
 | ||||
| agent.run( | ||||
|     "Why does Mike not know many people in New York?", | ||||
|     additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'} | ||||
| ) | ||||
| ``` | ||||
| 
 | ||||
| It's important to explain as clearly as possible the task you want to perform. | ||||
|  |  | |||
|  | @ -27,12 +27,11 @@ LAUNCH_GRADIO = False | |||
| 
 | ||||
| get_cat_image = GetCatImageTool() | ||||
| 
 | ||||
| 
 | ||||
| agent = CodeAgent( | ||||
|     tools = [get_cat_image, VisitWebpageTool()], | ||||
|     model=HfApiModel(), | ||||
|     additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",  | ||||
|     use_e2b_executor=False | ||||
|     use_e2b_executor=True | ||||
| ) | ||||
| 
 | ||||
| if LAUNCH_GRADIO: | ||||
|  | @ -41,6 +40,5 @@ if LAUNCH_GRADIO: | |||
|     GradioUI(agent).launch() | ||||
| else: | ||||
|     agent.run( | ||||
|         "Return me an image of Lincoln's preferred pet", | ||||
|         additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" | ||||
|     ) | ||||
|         "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()} | ||||
|     ) # Asking to directly return the image from state tests that additional_args are properly sent to server. | ||||
|  |  | |||
|  | @ -188,6 +188,7 @@ class MultiStepAgent: | |||
|         self.tool_parser = tool_parser | ||||
|         self.grammar = grammar | ||||
|         self.planning_interval = planning_interval | ||||
|         self.state = {} | ||||
| 
 | ||||
|         self.managed_agents = {} | ||||
|         if managed_agents is not None: | ||||
|  | @ -370,8 +371,7 @@ class MultiStepAgent: | |||
|             return self.model(self.input_messages) | ||||
|         except Exception as e: | ||||
|             error_msg = f"Error in generating final LLM output:\n{e}" | ||||
|             console.print(f"[bold red]{error_msg}[/bold red]") | ||||
|             return error_msg | ||||
|             raise AgentGenerationError(error_msg) | ||||
| 
 | ||||
|     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: | ||||
|         """ | ||||
|  | @ -385,7 +385,6 @@ class MultiStepAgent: | |||
|         available_tools = {**self.toolbox.tools, **self.managed_agents} | ||||
|         if tool_name not in available_tools: | ||||
|             error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." | ||||
|             console.print(f"[bold red]{error_msg}") | ||||
|             raise AgentExecutionError(error_msg) | ||||
| 
 | ||||
|         try: | ||||
|  | @ -398,7 +397,6 @@ class MultiStepAgent: | |||
|                 observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) | ||||
|             else: | ||||
|                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." | ||||
|                 console.print(f"[bold red]{error_msg}") | ||||
|                 raise AgentExecutionError(error_msg) | ||||
|             return observation | ||||
|         except Exception as e: | ||||
|  | @ -410,14 +408,12 @@ class MultiStepAgent: | |||
|                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" | ||||
|                     f"As a reminder, this tool's description is the following:\n{tool_description}" | ||||
|                 ) | ||||
|                 console.print(f"[bold red]{error_msg}") | ||||
|                 raise AgentExecutionError(error_msg) | ||||
|             elif tool_name in self.managed_agents: | ||||
|                 error_msg = ( | ||||
|                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" | ||||
|                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" | ||||
|                 ) | ||||
|                 console.print(f"[bold red]{error_msg}") | ||||
|                 raise AgentExecutionError(error_msg) | ||||
| 
 | ||||
|     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||
|  | @ -430,7 +426,7 @@ class MultiStepAgent: | |||
|         stream: bool = False, | ||||
|         reset: bool = True, | ||||
|         single_step: bool = False, | ||||
|         **kwargs, | ||||
|         additional_args: Optional[Dict] = None, | ||||
|     ): | ||||
|         """ | ||||
|         Runs the agent for the given task. | ||||
|  | @ -440,6 +436,7 @@ class MultiStepAgent: | |||
|             stream (`bool`): Wether to run in a streaming way. | ||||
|             reset (`bool`): Wether to reset the conversation or keep it going from previous run. | ||||
|             single_step (`bool`): Should the agent run in one shot or multi-step fashion? | ||||
|             additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! | ||||
| 
 | ||||
|         Example: | ||||
|         ```py | ||||
|  | @ -449,11 +446,11 @@ class MultiStepAgent: | |||
|         ``` | ||||
|         """ | ||||
|         self.task = task | ||||
|         if len(kwargs) > 0: | ||||
|             self.task += ( | ||||
|                 f"\nYou have been provided with these initial arguments: {str(kwargs)}." | ||||
|             ) | ||||
|         self.state = kwargs.copy() | ||||
|         if additional_args is not None: | ||||
|             self.state.update(additional_args) | ||||
|             self.task += f""" | ||||
| You have been provided with these additional arguments, that you can access as variables in your python code using the keys: | ||||
| {str(additional_args)}.""" | ||||
| 
 | ||||
|         self.initialize_system_prompt() | ||||
|         system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) | ||||
|  | @ -468,14 +465,9 @@ class MultiStepAgent: | |||
|             else: | ||||
|                 self.logs.append(system_prompt_step) | ||||
| 
 | ||||
|         # console.print( | ||||
|         #     Group( | ||||
|         #         Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task) | ||||
|         #     ) | ||||
|         # ) | ||||
|         console.print( | ||||
|             Panel( | ||||
|                 f"\n[bold]{task.strip()}\n", | ||||
|                 f"\n[bold]{self.task.strip()}\n", | ||||
|                 title="[bold]New run", | ||||
|                 subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", | ||||
|                 border_style=YELLOW_HEX, | ||||
|  | @ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent): | |||
|             console.print_exception() | ||||
|             raise AgentGenerationError(f"Error in generating model output:\n{e}") | ||||
| 
 | ||||
|         # from rich.live import Live | ||||
|         # from rich.markdown import Markdown | ||||
|         # import time | ||||
| 
 | ||||
|         # with Live(console=console, vertical_overflow="visible") as live: | ||||
|         #     message = "" | ||||
|         #     for i in range(100): | ||||
|         #         time.sleep(0.02) | ||||
|         #         message += str(i) | ||||
|         #         live.update(Markdown(message)) | ||||
| 
 | ||||
|         if self.verbose: | ||||
|             console.print( | ||||
|                 Group( | ||||
|  | @ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent): | |||
|         try: | ||||
|             output, execution_logs = self.python_executor( | ||||
|                 code_action, | ||||
|                 self.state, | ||||
|             ) | ||||
|             execution_outputs_console = [] | ||||
|             if len(execution_logs) > 0: | ||||
|  |  | |||
|  | @ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool): | |||
|     pre_processor_class = WhisperProcessor | ||||
|     model_class = WhisperForConditionalGeneration | ||||
| 
 | ||||
|     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} | ||||
|     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}} | ||||
|     output_type = "string" | ||||
| 
 | ||||
|     def encode(self, audio): | ||||
|  |  | |||
|  | @ -17,13 +17,14 @@ | |||
| from dotenv import load_dotenv | ||||
| import textwrap | ||||
| import base64 | ||||
| import pickle | ||||
| from io import BytesIO | ||||
| from PIL import Image | ||||
| 
 | ||||
| from e2b_code_interpreter import Sandbox | ||||
| from typing import List, Tuple, Any | ||||
| from .tool_validation import validate_tool_attributes | ||||
| from .utils import instance_to_source, BASE_BUILTIN_MODULES | ||||
| from .utils import instance_to_source, BASE_BUILTIN_MODULES, console | ||||
| from .tools import Tool | ||||
| 
 | ||||
| load_dotenv() | ||||
|  | @ -40,6 +41,7 @@ class E2BExecutor: | |||
|         #     timeout=300 | ||||
|         # ) | ||||
|         # print("Installation of agents package finished.") | ||||
|         additional_imports = additional_imports + ["pickle5"] | ||||
|         if len(additional_imports) > 0: | ||||
|             execution = self.sbx.commands.run( | ||||
|                 "pip install " + " ".join(additional_imports) | ||||
|  | @ -47,7 +49,7 @@ class E2BExecutor: | |||
|             if execution.error: | ||||
|                 raise Exception(f"Error installing dependencies: {execution.error}") | ||||
|             else: | ||||
|                 print("Installation succeeded!") | ||||
|                 console.print(f"Installation of {additional_imports} succeeded!") | ||||
| 
 | ||||
|         tool_codes = [] | ||||
|         for tool in tools: | ||||
|  | @ -71,21 +73,44 @@ class E2BExecutor: | |||
|         tool_definition_code += "\n\n".join(tool_codes) | ||||
| 
 | ||||
|         tool_definition_execution = self.run_code_raise_errors(tool_definition_code) | ||||
|         print(tool_definition_execution.logs) | ||||
|         console.print(tool_definition_execution.logs) | ||||
| 
 | ||||
|     def run_code_raise_errors(self, code: str): | ||||
|         execution = self.sbx.run_code( | ||||
|             code, | ||||
|         ) | ||||
|         if execution.error: | ||||
|             logs = "Executing code yielded an error:" | ||||
|             execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||
|             logs = execution_logs | ||||
|             logs += "Executing code yielded an error:" | ||||
|             logs += execution.error.name | ||||
|             logs += execution.error.value | ||||
|             logs += execution.error.traceback | ||||
|             raise ValueError(logs) | ||||
|         return execution | ||||
| 
 | ||||
|     def __call__(self, code_action: str) -> Tuple[Any, Any]: | ||||
|     def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]: | ||||
|         if len(additional_args) > 0: | ||||
|             # Pickle additional_args to server | ||||
|             import tempfile | ||||
| 
 | ||||
|             with tempfile.NamedTemporaryFile() as f: | ||||
|                 pickle.dump(additional_args, f) | ||||
|                 f.flush() | ||||
|                 with open(f.name, "rb") as file: | ||||
|                     self.sbx.files.write("/home/state.pkl", file) | ||||
|             remote_unloading_code = """import pickle | ||||
| import os | ||||
| print("File path", os.path.getsize('/home/state.pkl')) | ||||
| with open('/home/state.pkl', 'rb') as f: | ||||
|     pickle_dict = pickle.load(f) | ||||
| locals().update({key: value for key, value in pickle_dict.items()}) | ||||
| """ | ||||
|             execution = self.run_code_raise_errors(remote_unloading_code) | ||||
|             execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||
|             console.print(execution_logs) | ||||
| 
 | ||||
|          | ||||
|         execution = self.run_code_raise_errors(code_action) | ||||
|         execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||
|         if not execution.results: | ||||
|  |  | |||
|  | @ -1058,7 +1058,8 @@ class LocalPythonInterpreter: | |||
|         } | ||||
|         # TODO: assert self.authorized imports are all installed locally | ||||
| 
 | ||||
|     def __call__(self, code_action: str) -> Tuple[Any, str]: | ||||
|     def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]: | ||||
|         self.state.update(additional_variables) | ||||
|         output = evaluate_python_code( | ||||
|             code_action, | ||||
|             static_tools=self.static_tools, | ||||
|  |  | |||
|  | @ -201,20 +201,21 @@ class Tool: | |||
| 
 | ||||
|         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES | ||||
| 
 | ||||
|         # Validate forward function signature | ||||
|         signature = inspect.signature(self.forward) | ||||
|         # Validate forward function signature, except for PipelineTool | ||||
|         if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True): | ||||
|             signature = inspect.signature(self.forward) | ||||
| 
 | ||||
|         if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||
|             raise Exception( | ||||
|                 "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||
|             ) | ||||
|             if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||
|                 raise Exception( | ||||
|                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||
|                 ) | ||||
| 
 | ||||
|         json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||
|         for key, value in self.inputs.items(): | ||||
|             if "nullable" in value: | ||||
|                 assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||
|             if key in json_schema and "nullable" in json_schema[key]: | ||||
|                 assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||
|             json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||
|             for key, value in self.inputs.items(): | ||||
|                 if "nullable" in value: | ||||
|                     assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||
|                 if key in json_schema and "nullable" in json_schema[key]: | ||||
|                     assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||
| 
 | ||||
|     def forward(self, *args, **kwargs): | ||||
|         return NotImplementedError("Write this method in your subclass of `Tool`.") | ||||
|  | @ -1074,6 +1075,7 @@ class PipelineTool(Tool): | |||
|     name = "pipeline" | ||||
|     inputs = {"prompt": str} | ||||
|     output_type = str | ||||
|     is_pipeline_tool = True | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue