Fix additional args sent to e2b server
This commit is contained in:
		
							parent
							
								
									1abaf69b67
								
							
						
					
					
						commit
						f8b9cb34f9
					
				|  | @ -62,7 +62,7 @@ and [reaches higher performance on difficult benchmarks](https://huggingface.co/ | ||||||
| 
 | 
 | ||||||
| Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: | Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime: | ||||||
|   - a secure python interpreter to run code more safely in your environment |   - a secure python interpreter to run code more safely in your environment | ||||||
|   - a sandboxed environment. |   - a sandboxed environment using [E2B](https://e2b.dev/). | ||||||
| 
 | 
 | ||||||
| ## How lightweight is it? | ## How lightweight is it? | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -33,7 +33,7 @@ from huggingface_hub import login, InferenceClient | ||||||
| 
 | 
 | ||||||
| login("<YOUR_HUGGINGFACEHUB_API_TOKEN>") | login("<YOUR_HUGGINGFACEHUB_API_TOKEN>") | ||||||
| 
 | 
 | ||||||
| model_id = "Qwen/Qwen2.5-72B-Instruct" | model_id = "meta-llama/Llama-3.3-70B-Instruct" | ||||||
| 
 | 
 | ||||||
| client = InferenceClient(model=model_id) | client = InferenceClient(model=model_id) | ||||||
| 
 | 
 | ||||||
|  | @ -71,12 +71,19 @@ agent.run( | ||||||
| 
 | 
 | ||||||
| Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, and they will be baked into the prompt as text. | Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, and they will be baked into the prompt as text. | ||||||
| 
 | 
 | ||||||
| You can use this to indicate the path to local or remote files for the model to use: | You can use this to pass files in various formats: | ||||||
| 
 | 
 | ||||||
| ```py | ```py | ||||||
| agent = CodeAgent(tools=[], model=model, add_base_tools=True) | from smolagents import CodeAgent, HfApiModel | ||||||
| 
 | 
 | ||||||
| agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3") | model_id = "meta-llama/Llama-3.3-70B-Instruct" | ||||||
|  | 
 | ||||||
|  | agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True) | ||||||
|  | 
 | ||||||
|  | agent.run( | ||||||
|  |     "Why does Mike not know many people in New York?", | ||||||
|  |     additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'} | ||||||
|  | ) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| It's important to explain as clearly as possible the task you want to perform. | It's important to explain as clearly as possible the task you want to perform. | ||||||
|  |  | ||||||
|  | @ -27,12 +27,11 @@ LAUNCH_GRADIO = False | ||||||
| 
 | 
 | ||||||
| get_cat_image = GetCatImageTool() | get_cat_image = GetCatImageTool() | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| agent = CodeAgent( | agent = CodeAgent( | ||||||
|     tools = [get_cat_image, VisitWebpageTool()], |     tools = [get_cat_image, VisitWebpageTool()], | ||||||
|     model=HfApiModel(), |     model=HfApiModel(), | ||||||
|     additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",  |     additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",  | ||||||
|     use_e2b_executor=False |     use_e2b_executor=True | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| if LAUNCH_GRADIO: | if LAUNCH_GRADIO: | ||||||
|  | @ -41,6 +40,5 @@ if LAUNCH_GRADIO: | ||||||
|     GradioUI(agent).launch() |     GradioUI(agent).launch() | ||||||
| else: | else: | ||||||
|     agent.run( |     agent.run( | ||||||
|         "Return me an image of Lincoln's preferred pet", |         "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()} | ||||||
|         additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" |     ) # Asking to directly return the image from state tests that additional_args are properly sent to server. | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  | @ -188,6 +188,7 @@ class MultiStepAgent: | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|         self.grammar = grammar |         self.grammar = grammar | ||||||
|         self.planning_interval = planning_interval |         self.planning_interval = planning_interval | ||||||
|  |         self.state = {} | ||||||
| 
 | 
 | ||||||
|         self.managed_agents = {} |         self.managed_agents = {} | ||||||
|         if managed_agents is not None: |         if managed_agents is not None: | ||||||
|  | @ -370,8 +371,7 @@ class MultiStepAgent: | ||||||
|             return self.model(self.input_messages) |             return self.model(self.input_messages) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Error in generating final LLM output:\n{e}" |             error_msg = f"Error in generating final LLM output:\n{e}" | ||||||
|             console.print(f"[bold red]{error_msg}[/bold red]") |             raise AgentGenerationError(error_msg) | ||||||
|             return error_msg |  | ||||||
| 
 | 
 | ||||||
|     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: |     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: | ||||||
|         """ |         """ | ||||||
|  | @ -385,7 +385,6 @@ class MultiStepAgent: | ||||||
|         available_tools = {**self.toolbox.tools, **self.managed_agents} |         available_tools = {**self.toolbox.tools, **self.managed_agents} | ||||||
|         if tool_name not in available_tools: |         if tool_name not in available_tools: | ||||||
|             error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." |             error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." | ||||||
|             console.print(f"[bold red]{error_msg}") |  | ||||||
|             raise AgentExecutionError(error_msg) |             raise AgentExecutionError(error_msg) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -398,7 +397,6 @@ class MultiStepAgent: | ||||||
|                 observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) |                 observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) | ||||||
|             else: |             else: | ||||||
|                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." |                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg) | ||||||
|             return observation |             return observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|  | @ -410,14 +408,12 @@ class MultiStepAgent: | ||||||
|                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" |                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" | ||||||
|                     f"As a reminder, this tool's description is the following:\n{tool_description}" |                     f"As a reminder, this tool's description is the following:\n{tool_description}" | ||||||
|                 ) |                 ) | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg) | ||||||
|             elif tool_name in self.managed_agents: |             elif tool_name in self.managed_agents: | ||||||
|                 error_msg = ( |                 error_msg = ( | ||||||
|                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" |                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" | ||||||
|                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" |                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" | ||||||
|                 ) |                 ) | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |                 raise AgentExecutionError(error_msg) | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: ActionStep) -> Union[None, Any]: |     def step(self, log_entry: ActionStep) -> Union[None, Any]: | ||||||
|  | @ -430,7 +426,7 @@ class MultiStepAgent: | ||||||
|         stream: bool = False, |         stream: bool = False, | ||||||
|         reset: bool = True, |         reset: bool = True, | ||||||
|         single_step: bool = False, |         single_step: bool = False, | ||||||
|         **kwargs, |         additional_args: Optional[Dict] = None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Runs the agent for the given task. |         Runs the agent for the given task. | ||||||
|  | @ -440,6 +436,7 @@ class MultiStepAgent: | ||||||
|             stream (`bool`): Whether to run in a streaming way. |             stream (`bool`): Whether to run in a streaming way. | ||||||
|             reset (`bool`): Whether to reset the conversation or keep it going from previous run. |             reset (`bool`): Whether to reset the conversation or keep it going from previous run. | ||||||
|             single_step (`bool`): Should the agent run in one shot or multi-step fashion? |             single_step (`bool`): Should the agent run in one shot or multi-step fashion? | ||||||
|  |             additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! | ||||||
| 
 | 
 | ||||||
|         Example: |         Example: | ||||||
|         ```py |         ```py | ||||||
|  | @ -449,11 +446,11 @@ class MultiStepAgent: | ||||||
|         ``` |         ``` | ||||||
|         """ |         """ | ||||||
|         self.task = task |         self.task = task | ||||||
|         if len(kwargs) > 0: |         if additional_args is not None: | ||||||
|             self.task += ( |             self.state.update(additional_args) | ||||||
|                 f"\nYou have been provided with these initial arguments: {str(kwargs)}." |             self.task += f""" | ||||||
|             ) | You have been provided with these additional arguments, that you can access as variables in your python code using the keys: | ||||||
|         self.state = kwargs.copy() | {str(additional_args)}.""" | ||||||
| 
 | 
 | ||||||
|         self.initialize_system_prompt() |         self.initialize_system_prompt() | ||||||
|         system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) |         system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt) | ||||||
|  | @ -468,14 +465,9 @@ class MultiStepAgent: | ||||||
|             else: |             else: | ||||||
|                 self.logs.append(system_prompt_step) |                 self.logs.append(system_prompt_step) | ||||||
| 
 | 
 | ||||||
|         # console.print( |  | ||||||
|         #     Group( |  | ||||||
|         #         Rule("[bold]New run", characters="═", style=YELLOW_HEX), Text(self.task) |  | ||||||
|         #     ) |  | ||||||
|         # ) |  | ||||||
|         console.print( |         console.print( | ||||||
|             Panel( |             Panel( | ||||||
|                 f"\n[bold]{task.strip()}\n", |                 f"\n[bold]{self.task.strip()}\n", | ||||||
|                 title="[bold]New run", |                 title="[bold]New run", | ||||||
|                 subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", |                 subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, "model_id") else "")}", | ||||||
|                 border_style=YELLOW_HEX, |                 border_style=YELLOW_HEX, | ||||||
|  | @ -891,17 +883,6 @@ class CodeAgent(MultiStepAgent): | ||||||
|             console.print_exception() |             console.print_exception() | ||||||
|             raise AgentGenerationError(f"Error in generating model output:\n{e}") |             raise AgentGenerationError(f"Error in generating model output:\n{e}") | ||||||
| 
 | 
 | ||||||
|         # from rich.live import Live |  | ||||||
|         # from rich.markdown import Markdown |  | ||||||
|         # import time |  | ||||||
| 
 |  | ||||||
|         # with Live(console=console, vertical_overflow="visible") as live: |  | ||||||
|         #     message = "" |  | ||||||
|         #     for i in range(100): |  | ||||||
|         #         time.sleep(0.02) |  | ||||||
|         #         message += str(i) |  | ||||||
|         #         live.update(Markdown(message)) |  | ||||||
| 
 |  | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.print( |             console.print( | ||||||
|                 Group( |                 Group( | ||||||
|  | @ -946,6 +927,7 @@ class CodeAgent(MultiStepAgent): | ||||||
|         try: |         try: | ||||||
|             output, execution_logs = self.python_executor( |             output, execution_logs = self.python_executor( | ||||||
|                 code_action, |                 code_action, | ||||||
|  |                 self.state, | ||||||
|             ) |             ) | ||||||
|             execution_outputs_console = [] |             execution_outputs_console = [] | ||||||
|             if len(execution_logs) > 0: |             if len(execution_logs) > 0: | ||||||
|  |  | ||||||
|  | @ -295,7 +295,7 @@ class SpeechToTextTool(PipelineTool): | ||||||
|     pre_processor_class = WhisperProcessor |     pre_processor_class = WhisperProcessor | ||||||
|     model_class = WhisperForConditionalGeneration |     model_class = WhisperForConditionalGeneration | ||||||
| 
 | 
 | ||||||
|     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} |     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe. Can be a local path, an url, or a tensor."}} | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|     def encode(self, audio): |     def encode(self, audio): | ||||||
|  |  | ||||||
|  | @ -17,13 +17,14 @@ | ||||||
| from dotenv import load_dotenv | from dotenv import load_dotenv | ||||||
| import textwrap | import textwrap | ||||||
| import base64 | import base64 | ||||||
|  | import pickle | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from PIL import Image | from PIL import Image | ||||||
| 
 | 
 | ||||||
| from e2b_code_interpreter import Sandbox | from e2b_code_interpreter import Sandbox | ||||||
| from typing import List, Tuple, Any | from typing import List, Tuple, Any | ||||||
| from .tool_validation import validate_tool_attributes | from .tool_validation import validate_tool_attributes | ||||||
| from .utils import instance_to_source, BASE_BUILTIN_MODULES | from .utils import instance_to_source, BASE_BUILTIN_MODULES, console | ||||||
| from .tools import Tool | from .tools import Tool | ||||||
| 
 | 
 | ||||||
| load_dotenv() | load_dotenv() | ||||||
|  | @ -40,6 +41,7 @@ class E2BExecutor: | ||||||
|         #     timeout=300 |         #     timeout=300 | ||||||
|         # ) |         # ) | ||||||
|         # print("Installation of agents package finished.") |         # print("Installation of agents package finished.") | ||||||
|  |         additional_imports = additional_imports + ["pickle5"] | ||||||
|         if len(additional_imports) > 0: |         if len(additional_imports) > 0: | ||||||
|             execution = self.sbx.commands.run( |             execution = self.sbx.commands.run( | ||||||
|                 "pip install " + " ".join(additional_imports) |                 "pip install " + " ".join(additional_imports) | ||||||
|  | @ -47,7 +49,7 @@ class E2BExecutor: | ||||||
|             if execution.error: |             if execution.error: | ||||||
|                 raise Exception(f"Error installing dependencies: {execution.error}") |                 raise Exception(f"Error installing dependencies: {execution.error}") | ||||||
|             else: |             else: | ||||||
|                 print("Installation succeeded!") |                 console.print(f"Installation of {additional_imports} succeeded!") | ||||||
| 
 | 
 | ||||||
|         tool_codes = [] |         tool_codes = [] | ||||||
|         for tool in tools: |         for tool in tools: | ||||||
|  | @ -71,21 +73,44 @@ class E2BExecutor: | ||||||
|         tool_definition_code += "\n\n".join(tool_codes) |         tool_definition_code += "\n\n".join(tool_codes) | ||||||
| 
 | 
 | ||||||
|         tool_definition_execution = self.run_code_raise_errors(tool_definition_code) |         tool_definition_execution = self.run_code_raise_errors(tool_definition_code) | ||||||
|         print(tool_definition_execution.logs) |         console.print(tool_definition_execution.logs) | ||||||
| 
 | 
 | ||||||
|     def run_code_raise_errors(self, code: str): |     def run_code_raise_errors(self, code: str): | ||||||
|         execution = self.sbx.run_code( |         execution = self.sbx.run_code( | ||||||
|             code, |             code, | ||||||
|         ) |         ) | ||||||
|         if execution.error: |         if execution.error: | ||||||
|             logs = "Executing code yielded an error:" |             execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||||
|  |             logs = execution_logs | ||||||
|  |             logs += "Executing code yielded an error:" | ||||||
|             logs += execution.error.name |             logs += execution.error.name | ||||||
|             logs += execution.error.value |             logs += execution.error.value | ||||||
|             logs += execution.error.traceback |             logs += execution.error.traceback | ||||||
|             raise ValueError(logs) |             raise ValueError(logs) | ||||||
|         return execution |         return execution | ||||||
| 
 | 
 | ||||||
|     def __call__(self, code_action: str) -> Tuple[Any, Any]: |     def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]: | ||||||
|  |         if len(additional_args) > 0: | ||||||
|  |             # Pickle additional_args to server | ||||||
|  |             import tempfile | ||||||
|  | 
 | ||||||
|  |             with tempfile.NamedTemporaryFile() as f: | ||||||
|  |                 pickle.dump(additional_args, f) | ||||||
|  |                 f.flush() | ||||||
|  |                 with open(f.name, "rb") as file: | ||||||
|  |                     self.sbx.files.write("/home/state.pkl", file) | ||||||
|  |             remote_unloading_code = """import pickle | ||||||
|  | import os | ||||||
|  | print("File path", os.path.getsize('/home/state.pkl')) | ||||||
|  | with open('/home/state.pkl', 'rb') as f: | ||||||
|  |     pickle_dict = pickle.load(f) | ||||||
|  | locals().update({key: value for key, value in pickle_dict.items()}) | ||||||
|  | """ | ||||||
|  |             execution = self.run_code_raise_errors(remote_unloading_code) | ||||||
|  |             execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||||
|  |             console.print(execution_logs) | ||||||
|  | 
 | ||||||
|  |          | ||||||
|         execution = self.run_code_raise_errors(code_action) |         execution = self.run_code_raise_errors(code_action) | ||||||
|         execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) |         execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) | ||||||
|         if not execution.results: |         if not execution.results: | ||||||
|  |  | ||||||
|  | @ -1058,7 +1058,8 @@ class LocalPythonInterpreter: | ||||||
|         } |         } | ||||||
|         # TODO: assert self.authorized imports are all installed locally |         # TODO: assert self.authorized imports are all installed locally | ||||||
| 
 | 
 | ||||||
|     def __call__(self, code_action: str) -> Tuple[Any, str]: |     def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str]: | ||||||
|  |         self.state.update(additional_variables) | ||||||
|         output = evaluate_python_code( |         output = evaluate_python_code( | ||||||
|             code_action, |             code_action, | ||||||
|             static_tools=self.static_tools, |             static_tools=self.static_tools, | ||||||
|  |  | ||||||
|  | @ -201,20 +201,21 @@ class Tool: | ||||||
| 
 | 
 | ||||||
|         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES |         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES | ||||||
| 
 | 
 | ||||||
|         # Validate forward function signature |         # Validate forward function signature, except for PipelineTool | ||||||
|         signature = inspect.signature(self.forward) |         if not (hasattr(self, "is_pipeline_tool") and getattr(self, "is_pipeline_tool") is True): | ||||||
|  |             signature = inspect.signature(self.forward) | ||||||
| 
 | 
 | ||||||
|         if not set(signature.parameters.keys()) == set(self.inputs.keys()): |             if not set(signature.parameters.keys()) == set(self.inputs.keys()): | ||||||
|             raise Exception( |                 raise Exception( | ||||||
|                 "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." |                     "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." | ||||||
|             ) |                 ) | ||||||
| 
 | 
 | ||||||
|         json_schema = _convert_type_hints_to_json_schema(self.forward) |             json_schema = _convert_type_hints_to_json_schema(self.forward) | ||||||
|         for key, value in self.inputs.items(): |             for key, value in self.inputs.items(): | ||||||
|             if "nullable" in value: |                 if "nullable" in value: | ||||||
|                 assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." |                     assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." | ||||||
|             if key in json_schema and "nullable" in json_schema[key]: |                 if key in json_schema and "nullable" in json_schema[key]: | ||||||
|                 assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." |                     assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." | ||||||
| 
 | 
 | ||||||
|     def forward(self, *args, **kwargs): |     def forward(self, *args, **kwargs): | ||||||
|         return NotImplementedError("Write this method in your subclass of `Tool`.") |         return NotImplementedError("Write this method in your subclass of `Tool`.") | ||||||
|  | @ -1074,6 +1075,7 @@ class PipelineTool(Tool): | ||||||
|     name = "pipeline" |     name = "pipeline" | ||||||
|     inputs = {"prompt": str} |     inputs = {"prompt": str} | ||||||
|     output_type = str |     output_type = str | ||||||
|  |     is_pipeline_tool = True | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue