Add E2B code interpreter 🥳
This commit is contained in:
parent
7b0b01d8f3
commit
c18bc9037d
|
@ -36,6 +36,7 @@ sdist/
|
||||||
var/
|
var/
|
||||||
wheels/
|
wheels/
|
||||||
share/python-wheels/
|
share/python-wheels/
|
||||||
|
node_modules/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
|
@ -157,3 +158,7 @@ cython_debug/
|
||||||
|
|
||||||
# PyCharm
|
# PyCharm
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
# Archive
|
||||||
|
archive/
|
||||||
|
savedir/
|
|
@ -1,5 +1,5 @@
|
||||||
# Base Python image
|
# Base Python image
|
||||||
FROM python:3.9-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
# Set working directory
|
# Set working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
@ -7,8 +7,6 @@ WORKDIR /app
|
||||||
# Install build dependencies
|
# Install build dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
build-essential \
|
build-essential \
|
||||||
gcc \
|
|
||||||
g++ \
|
|
||||||
zlib1g-dev \
|
zlib1g-dev \
|
||||||
libjpeg-dev \
|
libjpeg-dev \
|
||||||
libpng-dev \
|
libpng-dev \
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# You can use most Debian-based base images
|
||||||
|
FROM e2bdev/code-interpreter:latest
|
||||||
|
|
||||||
|
# Install dependencies and customize sandbox
|
||||||
|
RUN pip install git+https://github.com/huggingface/agents.git
|
|
@ -0,0 +1,16 @@
|
||||||
|
# This is a config for E2B sandbox template.
|
||||||
|
# You can use template ID (qywp2ctmu2q7jzprcf4j) to create a sandbox:
|
||||||
|
|
||||||
|
# Python SDK
|
||||||
|
# from e2b import Sandbox, AsyncSandbox
|
||||||
|
# sandbox = Sandbox("qywp2ctmu2q7jzprcf4j") # Sync sandbox
|
||||||
|
# sandbox = await AsyncSandbox.create("qywp2ctmu2q7jzprcf4j") # Async sandbox
|
||||||
|
|
||||||
|
# JS SDK
|
||||||
|
# import { Sandbox } from 'e2b'
|
||||||
|
# const sandbox = await Sandbox.create('qywp2ctmu2q7jzprcf4j')
|
||||||
|
|
||||||
|
team_id = "f8776d3a-df2f-4a1d-af48-68c2e13b3b87"
|
||||||
|
start_cmd = "/root/.jupyter/start-up.sh"
|
||||||
|
dockerfile = "e2b.Dockerfile"
|
||||||
|
template_id = "qywp2ctmu2q7jzprcf4j"
|
|
@ -1,8 +1,8 @@
|
||||||
from agents.tools.search import DuckDuckGoSearchTool
|
from agents.default_tools.search import DuckDuckGoSearchTool
|
||||||
from agents.docker_alternative import DockerPythonInterpreter
|
from agents.docker_alternative import DockerPythonInterpreter
|
||||||
|
|
||||||
|
|
||||||
from agents.tool import Tool
|
from agents.tools import Tool
|
||||||
|
|
||||||
class DummyTool(Tool):
|
class DummyTool(Tool):
|
||||||
name = "echo"
|
name = "echo"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from agents.tool import Tool
|
from agents.tools import Tool
|
||||||
|
|
||||||
|
|
||||||
class DummyTool(Tool):
|
class DummyTool(Tool):
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
from agents import Tool, CodeAgent
|
||||||
|
from agents.default_tools.search import VisitWebpageTool
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
LAUNCH_GRADIO = False
|
||||||
|
|
||||||
|
class GetCatImageTool(Tool):
|
||||||
|
name="get_cat_image"
|
||||||
|
description = "Get a cat image"
|
||||||
|
inputs = {}
|
||||||
|
output_type = "image"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
response = requests.get(self.url)
|
||||||
|
|
||||||
|
return Image.open(BytesIO(response.content))
|
||||||
|
|
||||||
|
get_cat_image = GetCatImageTool()
|
||||||
|
|
||||||
|
|
||||||
|
agent = CodeAgent(
|
||||||
|
tools = [get_cat_image, VisitWebpageTool()],
|
||||||
|
additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
|
||||||
|
use_e2b_executor=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if LAUNCH_GRADIO:
|
||||||
|
from agents.gradio_ui import GradioUI
|
||||||
|
|
||||||
|
GradioUI(agent).launch()
|
||||||
|
else:
|
||||||
|
agent.run(
|
||||||
|
"Return me an image of Lincoln's preferred pet",
|
||||||
|
additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/"
|
||||||
|
)
|
|
@ -24,22 +24,28 @@ from transformers.utils.import_utils import define_import_structure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import *
|
from .agents import *
|
||||||
from .default_tools import *
|
from .default_tools.base import *
|
||||||
|
from .default_tools.search import *
|
||||||
from .gradio_ui import *
|
from .gradio_ui import *
|
||||||
from .llm_engines import *
|
from .llm_engines import *
|
||||||
from .local_python_executor import *
|
from .local_python_executor import *
|
||||||
from .monitoring import *
|
from .monitoring import *
|
||||||
from .prompts import *
|
from .prompts import *
|
||||||
from .tools.search import *
|
from .tools import *
|
||||||
from .tool import *
|
|
||||||
from .types import *
|
from .types import *
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
from .default_tools.search import *
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
_file = globals()["__file__"]
|
_file = globals()["__file__"]
|
||||||
|
import_structure = define_import_structure(_file)
|
||||||
|
import_structure[""]= {"__version__": __version__}
|
||||||
sys.modules[__name__] = _LazyModule(
|
sys.modules[__name__] = _LazyModule(
|
||||||
__name__, _file, define_import_structure(_file), module_spec=__spec__
|
__name__,
|
||||||
|
_file,
|
||||||
|
import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
extra_objects={"__version__": __version__}
|
||||||
)
|
)
|
||||||
|
|
|
@ -27,7 +27,7 @@ from transformers.utils import is_torch_available
|
||||||
|
|
||||||
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
|
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
|
||||||
from .types import AgentAudio, AgentImage
|
from .types import AgentAudio, AgentImage
|
||||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool
|
from .default_tools.base import FinalAnswerTool
|
||||||
from .llm_engines import HfApiEngine, MessageRole
|
from .llm_engines import HfApiEngine, MessageRole
|
||||||
from .monitoring import Monitor
|
from .monitoring import Monitor
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
|
@ -42,8 +42,9 @@ from .prompts import (
|
||||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||||
SYSTEM_PROMPT_PLAN,
|
SYSTEM_PROMPT_PLAN,
|
||||||
)
|
)
|
||||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor
|
||||||
from .tool import (
|
from .e2b_executor import E2BExecutor
|
||||||
|
from .tools import (
|
||||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
Tool,
|
Tool,
|
||||||
get_tool_description_with_args,
|
get_tool_description_with_args,
|
||||||
|
@ -169,17 +170,6 @@ def format_prompt_with_managed_agents_descriptions(
|
||||||
else:
|
else:
|
||||||
return prompt_template.replace(agent_descriptions_placeholder, "")
|
return prompt_template.replace(agent_descriptions_placeholder, "")
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_imports(
|
|
||||||
prompt_template: str, authorized_imports: List[str]
|
|
||||||
) -> str:
|
|
||||||
if "<<authorized_imports>>" not in prompt_template:
|
|
||||||
raise AgentError(
|
|
||||||
"Tag '<<authorized_imports>>' should be provided in the prompt."
|
|
||||||
)
|
|
||||||
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent:
|
class BaseAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -264,11 +254,6 @@ class BaseAgent:
|
||||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
||||||
self.system_prompt, self.managed_agents
|
self.system_prompt, self.managed_agents
|
||||||
)
|
)
|
||||||
if hasattr(self, "authorized_imports"):
|
|
||||||
self.system_prompt = format_prompt_with_imports(
|
|
||||||
self.system_prompt,
|
|
||||||
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.system_prompt
|
return self.system_prompt
|
||||||
|
|
||||||
|
@ -439,9 +424,7 @@ class ReactAgent(BaseAgent):
|
||||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||||
"""
|
"""
|
||||||
available_tools = self.toolbox.tools
|
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||||
if self.managed_agents is not None:
|
|
||||||
available_tools = {**available_tools, **self.managed_agents}
|
|
||||||
if tool_name not in available_tools:
|
if tool_name not in available_tools:
|
||||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
console.print(f"[bold red]{error_msg}")
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
@ -674,8 +657,6 @@ Now begin!""",
|
||||||
),
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents)
|
show_agents_descriptions(self.managed_agents)
|
||||||
if self.managed_agents is not None
|
|
||||||
else ""
|
|
||||||
),
|
),
|
||||||
answer_facts=answer_facts,
|
answer_facts=answer_facts,
|
||||||
),
|
),
|
||||||
|
@ -729,8 +710,6 @@ Now begin!""",
|
||||||
),
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents)
|
show_agents_descriptions(self.managed_agents)
|
||||||
if self.managed_agents is not None
|
|
||||||
else ""
|
|
||||||
),
|
),
|
||||||
facts_update=facts_update,
|
facts_update=facts_update,
|
||||||
remaining_steps=(self.max_iterations - iteration),
|
remaining_steps=(self.max_iterations - iteration),
|
||||||
|
@ -891,6 +870,7 @@ class CodeAgent(ReactAgent):
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
|
use_e2b_executor: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if llm_engine is None:
|
if llm_engine is None:
|
||||||
|
@ -909,17 +889,24 @@ class CodeAgent(ReactAgent):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
|
||||||
self.additional_authorized_imports = (
|
self.additional_authorized_imports = (
|
||||||
additional_authorized_imports if additional_authorized_imports else []
|
additional_authorized_imports if additional_authorized_imports else []
|
||||||
)
|
)
|
||||||
|
all_tools = {**self.toolbox.tools, **self.managed_agents}
|
||||||
|
if use_e2b_executor:
|
||||||
|
self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values()))
|
||||||
|
else:
|
||||||
|
self.python_executor = LocalPythonExecutor(self.additional_authorized_imports, all_tools)
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||||
|
)
|
||||||
|
if "{{authorized_imports}}" not in self.system_prompt:
|
||||||
|
raise AgentError(
|
||||||
|
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||||
)
|
)
|
||||||
self.system_prompt = self.system_prompt.replace(
|
self.system_prompt = self.system_prompt.replace(
|
||||||
"{{authorized_imports}}", str(self.authorized_imports)
|
"{{authorized_imports}}", str(self.authorized_imports)
|
||||||
)
|
)
|
||||||
self.custom_tools = {}
|
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -991,22 +978,12 @@ class CodeAgent(ReactAgent):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
static_tools = {
|
output, execution_logs = self.python_executor(
|
||||||
**BASE_PYTHON_TOOLS.copy(),
|
|
||||||
**self.toolbox.tools,
|
|
||||||
}
|
|
||||||
if self.managed_agents is not None:
|
|
||||||
static_tools = {**static_tools, **self.managed_agents}
|
|
||||||
output = self.python_evaluator(
|
|
||||||
code_action,
|
code_action,
|
||||||
static_tools=static_tools,
|
|
||||||
custom_tools=self.custom_tools,
|
|
||||||
state=self.state,
|
|
||||||
authorized_imports=self.authorized_imports,
|
|
||||||
)
|
)
|
||||||
if len(self.state["print_outputs"]) > 0:
|
if len(execution_logs) > 0:
|
||||||
console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"])))
|
console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs)))
|
||||||
observation = "Print outputs:\n" + self.state["print_outputs"]
|
observation = "Execution logs:\n" + execution_logs
|
||||||
if output is not None:
|
if output is not None:
|
||||||
truncated_output = truncate_content(
|
truncated_output = truncate_content(
|
||||||
str(output)
|
str(output)
|
||||||
|
@ -1026,7 +1003,7 @@ class CodeAgent(ReactAgent):
|
||||||
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
|
console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green")))
|
||||||
log_entry.action_output = output
|
log_entry.action_output = output
|
||||||
return output
|
return output
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class ManagedAgent:
|
class ManagedAgent:
|
||||||
|
|
|
@ -15,75 +15,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import sqrt
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, list_spaces
|
from huggingface_hub import hf_hub_download, list_spaces
|
||||||
|
|
||||||
from transformers.utils import is_offline_mode
|
from transformers.utils import is_offline_mode
|
||||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code
|
||||||
from .tool import TOOL_CONFIG_FILE, Tool
|
from ..tools import TOOL_CONFIG_FILE, Tool
|
||||||
|
|
||||||
|
|
||||||
def custom_print(*args):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
BASE_PYTHON_TOOLS = {
|
|
||||||
"print": custom_print,
|
|
||||||
"isinstance": isinstance,
|
|
||||||
"range": range,
|
|
||||||
"float": float,
|
|
||||||
"int": int,
|
|
||||||
"bool": bool,
|
|
||||||
"str": str,
|
|
||||||
"set": set,
|
|
||||||
"list": list,
|
|
||||||
"dict": dict,
|
|
||||||
"tuple": tuple,
|
|
||||||
"round": round,
|
|
||||||
"ceil": math.ceil,
|
|
||||||
"floor": math.floor,
|
|
||||||
"log": math.log,
|
|
||||||
"exp": math.exp,
|
|
||||||
"sin": math.sin,
|
|
||||||
"cos": math.cos,
|
|
||||||
"tan": math.tan,
|
|
||||||
"asin": math.asin,
|
|
||||||
"acos": math.acos,
|
|
||||||
"atan": math.atan,
|
|
||||||
"atan2": math.atan2,
|
|
||||||
"degrees": math.degrees,
|
|
||||||
"radians": math.radians,
|
|
||||||
"pow": math.pow,
|
|
||||||
"sqrt": sqrt,
|
|
||||||
"len": len,
|
|
||||||
"sum": sum,
|
|
||||||
"max": max,
|
|
||||||
"min": min,
|
|
||||||
"abs": abs,
|
|
||||||
"enumerate": enumerate,
|
|
||||||
"zip": zip,
|
|
||||||
"reversed": reversed,
|
|
||||||
"sorted": sorted,
|
|
||||||
"all": all,
|
|
||||||
"any": any,
|
|
||||||
"map": map,
|
|
||||||
"filter": filter,
|
|
||||||
"ord": ord,
|
|
||||||
"chr": chr,
|
|
||||||
"next": next,
|
|
||||||
"iter": iter,
|
|
||||||
"divmod": divmod,
|
|
||||||
"callable": callable,
|
|
||||||
"getattr": getattr,
|
|
||||||
"hasattr": hasattr,
|
|
||||||
"setattr": setattr,
|
|
||||||
"issubclass": issubclass,
|
|
||||||
"type": type,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -136,10 +75,10 @@ class PythonInterpreterTool(Tool):
|
||||||
|
|
||||||
def __init__(self, *args, authorized_imports=None, **kwargs):
|
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||||
if authorized_imports is None:
|
if authorized_imports is None:
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
|
||||||
else:
|
else:
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(LIST_SAFE_MODULES) | set(authorized_imports)
|
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
|
||||||
)
|
)
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
"code": {
|
"code": {
|
|
@ -16,15 +16,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import requests
|
|
||||||
from requests.exceptions import RequestException
|
|
||||||
|
|
||||||
from ..tools import Tool
|
from ..tools import Tool
|
||||||
|
|
||||||
|
|
||||||
class DuckDuckGoSearchTool(Tool):
|
class DuckDuckGoSearchTool(Tool):
|
||||||
name = "web_search"
|
name = "web_search"
|
||||||
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||||
Each result has keys 'title', 'href' and 'body'."""
|
Each result has keys 'title', 'href' and 'body'."""
|
||||||
inputs = {
|
inputs = {
|
||||||
"query": {"type": "string", "description": "The search query to perform."}
|
"query": {"type": "string", "description": "The search query to perform."}
|
||||||
|
@ -56,9 +52,11 @@ class VisitWebpageTool(Tool):
|
||||||
def forward(self, url: str) -> str:
|
def forward(self, url: str) -> str:
|
||||||
try:
|
try:
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import RequestException
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
|
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# Send a GET request to the URL
|
# Send a GET request to the URL
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
||||||
import warnings
|
import warnings
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from agents.tool import Tool
|
from agents.tools import Tool
|
||||||
|
|
||||||
class DockerPythonInterpreter:
|
class DockerPythonInterpreter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -343,7 +343,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
||||||
from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES
|
from .local_python_executor import evaluate_python_code, BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
"""Execute code locally with state transfer."""
|
"""Execute code locally with state transfer."""
|
||||||
state_manager = StateManager(work_dir)
|
state_manager = StateManager(work_dir)
|
||||||
|
@ -363,7 +363,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any:
|
||||||
tools,
|
tools,
|
||||||
{},
|
{},
|
||||||
namespace,
|
namespace,
|
||||||
LIST_SAFE_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save state for Docker
|
# Save state for Docker
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import textwrap
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from e2b_code_interpreter import Sandbox
|
||||||
|
from typing import Dict, List, Callable, Tuple, Any
|
||||||
|
from .tool_validation import validate_tool_attributes
|
||||||
|
from .utils import instance_to_source, BASE_BUILTIN_MODULES
|
||||||
|
from .tools import Tool
|
||||||
|
from .types import AgentImage
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class E2BExecutor():
|
||||||
|
def __init__(self, additional_imports: List[str], tools: List[Tool]):
|
||||||
|
self.custom_tools = {}
|
||||||
|
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
|
||||||
|
# TODO: validate installing agents package or not
|
||||||
|
# print("Installing agents package on remote executor...")
|
||||||
|
# self.sbx.commands.run(
|
||||||
|
# "pip install git+https://github.com/huggingface/agents.git",
|
||||||
|
# timeout=300
|
||||||
|
# )
|
||||||
|
# print("Installation of agents package finished.")
|
||||||
|
if len(additional_imports) > 0:
|
||||||
|
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
|
||||||
|
if execution.error:
|
||||||
|
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||||
|
else:
|
||||||
|
print("Installation succeeded!")
|
||||||
|
|
||||||
|
tool_codes = []
|
||||||
|
for tool in tools:
|
||||||
|
validate_tool_attributes(tool.__class__, check_imports=False)
|
||||||
|
tool_code = instance_to_source(tool, base_cls=Tool)
|
||||||
|
tool_code = tool_code.replace("from agents.tools import Tool", "")
|
||||||
|
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
|
||||||
|
tool_codes.append(tool_code)
|
||||||
|
|
||||||
|
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
|
||||||
|
tool_definition_code += textwrap.dedent("""
|
||||||
|
class Tool:
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
pass # to be implemented in child class
|
||||||
|
""")
|
||||||
|
tool_definition_code += "\n\n".join(tool_codes)
|
||||||
|
|
||||||
|
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
|
||||||
|
print(tool_definition_execution.logs)
|
||||||
|
|
||||||
|
def run_code_raise_errors(self, code: str):
|
||||||
|
execution = self.sbx.run_code(
|
||||||
|
code,
|
||||||
|
)
|
||||||
|
if execution.error:
|
||||||
|
logs = 'Executing code yielded an error:'
|
||||||
|
logs += execution.error.name
|
||||||
|
logs += execution.error.value
|
||||||
|
logs += execution.error.traceback
|
||||||
|
raise ValueError(logs)
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def __call__(self, code_action: str) -> Tuple[Any, Any]:
|
||||||
|
execution = self.run_code_raise_errors(code_action)
|
||||||
|
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
|
||||||
|
if not execution.results:
|
||||||
|
return None, execution_logs
|
||||||
|
else:
|
||||||
|
for result in execution.results:
|
||||||
|
if result.is_main_result:
|
||||||
|
for attribute_name in ['jpeg', 'png']:
|
||||||
|
if getattr(result, attribute_name) is not None:
|
||||||
|
image_output = getattr(result, attribute_name)
|
||||||
|
decoded_bytes = base64.b64decode(image_output.encode('utf-8'))
|
||||||
|
return Image.open(BytesIO(decoded_bytes)), execution_logs
|
||||||
|
for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']:
|
||||||
|
if getattr(result, attribute_name) is not None:
|
||||||
|
return getattr(result, attribute_name), execution_logs
|
||||||
|
raise ValueError("No main result returned by executor!")
|
||||||
|
|
||||||
|
__all__ = ["E2BExecutor"]
|
|
@ -14,7 +14,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .types import AgentAudio, AgentImage, AgentText
|
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||||
from .agents import BaseAgent, AgentStep, ActionStep
|
from .agents import BaseAgent, AgentStep, ActionStep
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ def stream_to_gradio(
|
||||||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
final_answer = step_log # Last log is the run's final_answer
|
final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer
|
||||||
|
|
||||||
if isinstance(final_answer, AgentText):
|
if isinstance(final_answer, AgentText):
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
|
@ -93,7 +93,7 @@ class GradioUI:
|
||||||
yield messages
|
yield messages
|
||||||
yield messages
|
yield messages
|
||||||
|
|
||||||
def run(self):
|
def launch(self):
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
stored_message = gr.State([])
|
stored_message = gr.State([])
|
||||||
chatbot = gr.Chatbot(
|
chatbot = gr.Chatbot(
|
||||||
|
|
|
@ -19,12 +19,12 @@ import builtins
|
||||||
import difflib
|
import difflib
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from .utils import truncate_content
|
from .utils import truncate_content, BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
|
@ -43,24 +43,66 @@ ERRORS = {
|
||||||
and issubclass(getattr(builtins, name), BaseException)
|
and issubclass(getattr(builtins, name), BaseException)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
LIST_SAFE_MODULES = [
|
|
||||||
"random",
|
|
||||||
"collections",
|
|
||||||
"math",
|
|
||||||
"time",
|
|
||||||
"queue",
|
|
||||||
"itertools",
|
|
||||||
"re",
|
|
||||||
"stat",
|
|
||||||
"statistics",
|
|
||||||
"unicodedata",
|
|
||||||
]
|
|
||||||
|
|
||||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||||
|
|
||||||
|
def custom_print(*args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
BASE_PYTHON_TOOLS = {
|
||||||
|
"print": custom_print,
|
||||||
|
"isinstance": isinstance,
|
||||||
|
"range": range,
|
||||||
|
"float": float,
|
||||||
|
"int": int,
|
||||||
|
"bool": bool,
|
||||||
|
"str": str,
|
||||||
|
"set": set,
|
||||||
|
"list": list,
|
||||||
|
"dict": dict,
|
||||||
|
"tuple": tuple,
|
||||||
|
"round": round,
|
||||||
|
"ceil": math.ceil,
|
||||||
|
"floor": math.floor,
|
||||||
|
"log": math.log,
|
||||||
|
"exp": math.exp,
|
||||||
|
"sin": math.sin,
|
||||||
|
"cos": math.cos,
|
||||||
|
"tan": math.tan,
|
||||||
|
"asin": math.asin,
|
||||||
|
"acos": math.acos,
|
||||||
|
"atan": math.atan,
|
||||||
|
"atan2": math.atan2,
|
||||||
|
"degrees": math.degrees,
|
||||||
|
"radians": math.radians,
|
||||||
|
"pow": math.pow,
|
||||||
|
"sqrt": math.sqrt,
|
||||||
|
"len": len,
|
||||||
|
"sum": sum,
|
||||||
|
"max": max,
|
||||||
|
"min": min,
|
||||||
|
"abs": abs,
|
||||||
|
"enumerate": enumerate,
|
||||||
|
"zip": zip,
|
||||||
|
"reversed": reversed,
|
||||||
|
"sorted": sorted,
|
||||||
|
"all": all,
|
||||||
|
"any": any,
|
||||||
|
"map": map,
|
||||||
|
"filter": filter,
|
||||||
|
"ord": ord,
|
||||||
|
"chr": chr,
|
||||||
|
"next": next,
|
||||||
|
"iter": iter,
|
||||||
|
"divmod": divmod,
|
||||||
|
"callable": callable,
|
||||||
|
"getattr": getattr,
|
||||||
|
"hasattr": hasattr,
|
||||||
|
"setattr": setattr,
|
||||||
|
"issubclass": issubclass,
|
||||||
|
"type": type,
|
||||||
|
}
|
||||||
class BreakException(Exception):
|
class BreakException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -771,7 +813,7 @@ def evaluate_ast(
|
||||||
state: Dict[str, Any],
|
state: Dict[str, Any],
|
||||||
static_tools: Dict[str, Callable],
|
static_tools: Dict[str, Callable],
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
||||||
|
@ -949,7 +991,7 @@ def evaluate_python_code(
|
||||||
static_tools: Optional[Dict[str, Callable]] = None,
|
static_tools: Optional[Dict[str, Callable]] = None,
|
||||||
custom_tools: Optional[Dict[str, Callable]] = None,
|
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||||
state: Optional[Dict[str, Any]] = None,
|
state: Optional[Dict[str, Any]] = None,
|
||||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||||
|
@ -1001,4 +1043,30 @@ def evaluate_python_code(
|
||||||
raise InterpreterError(msg)
|
raise InterpreterError(msg)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["evaluate_python_code"]
|
class LocalPythonExecutor():
|
||||||
|
def __init__(self, additional_authorized_imports: List[str], tools: Dict):
|
||||||
|
self.custom_tools = {}
|
||||||
|
self.state = {}
|
||||||
|
self.additional_authorized_imports = additional_authorized_imports
|
||||||
|
self.authorized_imports = list(
|
||||||
|
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||||
|
)
|
||||||
|
# Add base trusted tools to list
|
||||||
|
self.static_tools = {
|
||||||
|
**tools,
|
||||||
|
**BASE_PYTHON_TOOLS.copy(),
|
||||||
|
}
|
||||||
|
# TODO: assert self.authorized imports are all installed locally
|
||||||
|
|
||||||
|
def __call__(self, code_action: str) -> Tuple[Any, str]:
|
||||||
|
output = evaluate_python_code(
|
||||||
|
code_action,
|
||||||
|
static_tools=self.static_tools,
|
||||||
|
custom_tools=self.custom_tools,
|
||||||
|
state=self.state,
|
||||||
|
authorized_imports=self.authorized_imports,
|
||||||
|
)
|
||||||
|
logs = self.state["print_outputs"]
|
||||||
|
return output, logs
|
||||||
|
|
||||||
|
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]
|
||||||
|
|
|
@ -370,7 +370,7 @@ Here are the rules you should always follow to solve your task:
|
||||||
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
||||||
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
|
||||||
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||||
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||||
|
|
||||||
|
|
|
@ -5,27 +5,30 @@ import builtins
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Set, Dict
|
from typing import List, Set, Dict
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from .utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
_BUILTIN_NAMES = set(vars(builtins))
|
_BUILTIN_NAMES = set(vars(builtins))
|
||||||
|
|
||||||
def is_local_import(module_name: str) -> bool:
|
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
|
def is_installed_package(module_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an import is from a local file or a package.
|
Check if an import is from an installed package.
|
||||||
Returns True if it's a local file import.
|
Returns False if it's not found or a local file import.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
spec = importlib.util.find_spec(module_name)
|
spec = importlib.util.find_spec(module_name)
|
||||||
if spec is None:
|
if spec is None:
|
||||||
return True # If we can't find the module, assume it's local
|
return False # If we can't find the module, assume it's local
|
||||||
|
|
||||||
# If the module is found and has a file path, check if it's in site-packages
|
# If the module is found and has a file path, check if it's in site-packages
|
||||||
if spec.origin and 'site-packages' not in spec.origin:
|
if spec.origin and 'site-packages' not in spec.origin:
|
||||||
# Check if it's a .py file in the current directory or subdirectories
|
# Check if it's a .py file in the current directory or subdirectories
|
||||||
return spec.origin.endswith('.py')
|
return not spec.origin.endswith('.py')
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return True # If there's an import error, assume it's local
|
return False # If there's an import error, assume it's local
|
||||||
|
|
||||||
class MethodChecker(ast.NodeVisitor):
|
class MethodChecker(ast.NodeVisitor):
|
||||||
"""
|
"""
|
||||||
|
@ -33,7 +36,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
- only uses defined names
|
- only uses defined names
|
||||||
- contains no local imports (e.g. numpy is ok but local_script is not)
|
- contains no local imports (e.g. numpy is ok but local_script is not)
|
||||||
"""
|
"""
|
||||||
def __init__(self, class_attributes: Set[str]):
|
def __init__(self, class_attributes: Set[str], check_imports: bool = True):
|
||||||
self.undefined_names = set()
|
self.undefined_names = set()
|
||||||
self.imports = {}
|
self.imports = {}
|
||||||
self.from_imports = {}
|
self.from_imports = {}
|
||||||
|
@ -41,6 +44,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
self.arg_names = set()
|
self.arg_names = set()
|
||||||
self.class_attributes = class_attributes
|
self.class_attributes = class_attributes
|
||||||
self.errors = []
|
self.errors = []
|
||||||
|
self.check_imports = check_imports
|
||||||
|
|
||||||
def visit_arguments(self, node):
|
def visit_arguments(self, node):
|
||||||
"""Collect function arguments"""
|
"""Collect function arguments"""
|
||||||
|
@ -53,16 +57,16 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
def visit_Import(self, node):
|
def visit_Import(self, node):
|
||||||
for name in node.names:
|
for name in node.names:
|
||||||
actual_name = name.asname or name.name
|
actual_name = name.asname or name.name
|
||||||
if is_local_import(actual_name):
|
if not is_installed_package(actual_name) and self.check_imports:
|
||||||
self.errors.append(f"Local import '{actual_name}'")
|
self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'")
|
||||||
self.imports[actual_name] = name.name
|
self.imports[actual_name] = name.name
|
||||||
|
|
||||||
def visit_ImportFrom(self, node):
|
def visit_ImportFrom(self, node):
|
||||||
module = node.module or ""
|
module = node.module or ""
|
||||||
for name in node.names:
|
for name in node.names:
|
||||||
actual_name = name.asname or name.name
|
actual_name = name.asname or name.name
|
||||||
if is_local_import(module):
|
if not is_installed_package(module) and self.check_imports:
|
||||||
self.errors.append(f"Local import '{module}'")
|
self.errors.append(f"Package not found in importlib, might be a local install: '{module}'")
|
||||||
self.from_imports[actual_name] = (module, name.name)
|
self.from_imports[actual_name] = (module, name.name)
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
|
@ -71,6 +75,20 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
self.assigned_names.add(target.id)
|
self.assigned_names.add(target.id)
|
||||||
self.visit(node.value)
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_With(self, node):
|
||||||
|
"""Track aliases in 'with' statements (the 'y' in 'with X as y')"""
|
||||||
|
for item in node.items:
|
||||||
|
if item.optional_vars: # This is the 'y' in 'with X as y'
|
||||||
|
if isinstance(item.optional_vars, ast.Name):
|
||||||
|
self.assigned_names.add(item.optional_vars.id)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_ExceptHandler(self, node):
|
||||||
|
"""Track exception aliases (the 'e' in 'except Exception as e')"""
|
||||||
|
if node.name: # This is the 'e' in 'except Exception as e'
|
||||||
|
self.assigned_names.add(node.name)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
def visit_AnnAssign(self, node):
|
def visit_AnnAssign(self, node):
|
||||||
"""Track annotated assignments."""
|
"""Track annotated assignments."""
|
||||||
if isinstance(node.target, ast.Name):
|
if isinstance(node.target, ast.Name):
|
||||||
|
@ -97,6 +115,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
if isinstance(node.ctx, ast.Load):
|
if isinstance(node.ctx, ast.Load):
|
||||||
if not (
|
if not (
|
||||||
node.id in _BUILTIN_NAMES
|
node.id in _BUILTIN_NAMES
|
||||||
|
or node.id in IMPORTED_PACKAGES
|
||||||
or node.id in self.arg_names
|
or node.id in self.arg_names
|
||||||
or node.id == "self"
|
or node.id == "self"
|
||||||
or node.id in self.class_attributes
|
or node.id in self.class_attributes
|
||||||
|
@ -110,6 +129,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
if isinstance(node.func, ast.Name):
|
if isinstance(node.func, ast.Name):
|
||||||
if not (
|
if not (
|
||||||
node.func.id in _BUILTIN_NAMES
|
node.func.id in _BUILTIN_NAMES
|
||||||
|
or node.func.id in IMPORTED_PACKAGES
|
||||||
or node.func.id in self.arg_names
|
or node.func.id in self.arg_names
|
||||||
or node.func.id == "self"
|
or node.func.id == "self"
|
||||||
or node.func.id in self.class_attributes
|
or node.func.id in self.class_attributes
|
||||||
|
@ -120,7 +140,7 @@ class MethodChecker(ast.NodeVisitor):
|
||||||
self.errors.append(f"Name '{node.func.id}' is undefined.")
|
self.errors.append(f"Name '{node.func.id}' is undefined.")
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
||||||
def validate_tool_attributes(cls) -> None:
|
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Validates that a Tool class follows the proper patterns:
|
Validates that a Tool class follows the proper patterns:
|
||||||
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
|
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
|
||||||
|
@ -156,8 +176,17 @@ def validate_tool_attributes(cls) -> None:
|
||||||
self.imported_names = set()
|
self.imported_names = set()
|
||||||
self.complex_attributes = set()
|
self.complex_attributes = set()
|
||||||
self.class_attributes = set()
|
self.class_attributes = set()
|
||||||
|
self.in_method = False
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
old_context = self.in_method
|
||||||
|
self.in_method = True
|
||||||
|
self.generic_visit(node)
|
||||||
|
self.in_method = old_context
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
|
if self.in_method:
|
||||||
|
return
|
||||||
# Track class attributes
|
# Track class attributes
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
|
@ -182,7 +211,7 @@ def validate_tool_attributes(cls) -> None:
|
||||||
# Run checks on all methods
|
# Run checks on all methods
|
||||||
for node in class_node.body:
|
for node in class_node.body:
|
||||||
if isinstance(node, ast.FunctionDef):
|
if isinstance(node, ast.FunctionDef):
|
||||||
method_checker = MethodChecker(class_level_checker.class_attributes)
|
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
|
||||||
method_checker.visit(node)
|
method_checker.visit(node)
|
||||||
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import textwrap
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union, Set
|
from typing import Any, Callable, Dict, List, Optional, Union, Set
|
||||||
|
import math
|
||||||
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
create_repo,
|
create_repo,
|
||||||
|
@ -48,7 +49,7 @@ from transformers.utils import (
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.dynamic_module_utils import get_imports
|
from transformers.dynamic_module_utils import get_imports
|
||||||
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
|
||||||
from .utils import instance_to_source
|
from .utils import instance_to_source
|
||||||
from .tool_validation import validate_tool_attributes, MethodChecker
|
from .tool_validation import validate_tool_attributes, MethodChecker
|
||||||
|
|
||||||
|
@ -66,7 +67,6 @@ if is_accelerate_available():
|
||||||
|
|
||||||
TOOL_CONFIG_FILE = "tool_config.json"
|
TOOL_CONFIG_FILE = "tool_config.json"
|
||||||
|
|
||||||
|
|
||||||
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
||||||
if repo_type is not None:
|
if repo_type is not None:
|
||||||
return repo_type
|
return repo_type
|
||||||
|
@ -197,12 +197,15 @@ class Tool:
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
|
||||||
if not self.is_initialized:
|
if not self.is_initialized:
|
||||||
self.setup()
|
self.setup()
|
||||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
if sanitize_inputs_outputs:
|
||||||
|
args, kwargs = handle_agent_input_types(*args, **kwargs)
|
||||||
outputs = self.forward(*args, **kwargs)
|
outputs = self.forward(*args, **kwargs)
|
||||||
return handle_agent_outputs(outputs, self.output_type)
|
if sanitize_inputs_outputs:
|
||||||
|
outputs = handle_agent_output_types(outputs, self.output_type)
|
||||||
|
return outputs
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
"""
|
"""
|
||||||
|
@ -267,8 +270,6 @@ class Tool:
|
||||||
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||||
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||||
|
|
||||||
with open(tool_file, "w", encoding="utf-8") as f:
|
|
||||||
f.write(tool_code)
|
|
||||||
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
|
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
|
||||||
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
|
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -278,6 +279,7 @@ class Tool:
|
||||||
validate_tool_attributes(self.__class__)
|
validate_tool_attributes(self.__class__)
|
||||||
|
|
||||||
tool_code = instance_to_source(self, base_cls=Tool)
|
tool_code = instance_to_source(self, base_cls=Tool)
|
||||||
|
|
||||||
with open(tool_file, "w", encoding="utf-8") as f:
|
with open(tool_file, "w", encoding="utf-8") as f:
|
||||||
f.write(tool_code)
|
f.write(tool_code)
|
||||||
|
|
||||||
|
@ -719,6 +721,9 @@ def launch_gradio_demo(tool: Tool):
|
||||||
"number": gr.Textbox,
|
"number": gr.Textbox,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def fn(*args, **kwargs):
|
||||||
|
return tool(*args, **kwargs, sanitize_inputs_outputs=True)
|
||||||
|
|
||||||
gradio_inputs = []
|
gradio_inputs = []
|
||||||
for input_name, input_details in tool.inputs.items():
|
for input_name, input_details in tool.inputs.items():
|
||||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||||
|
@ -733,7 +738,7 @@ def launch_gradio_demo(tool: Tool):
|
||||||
gradio_output = output_gradio_componentclass(label="Output")
|
gradio_output = output_gradio_componentclass(label="Output")
|
||||||
|
|
||||||
gr.Interface(
|
gr.Interface(
|
||||||
fn=tool, # This works because `tool` has a __call__ method
|
fn=fn,
|
||||||
inputs=gradio_inputs,
|
inputs=gradio_inputs,
|
||||||
outputs=gradio_output,
|
outputs=gradio_output,
|
||||||
title=tool.name,
|
title=tool.name,
|
||||||
|
@ -823,61 +828,6 @@ def add_description(description):
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
## Will move to the Hub
|
|
||||||
class EndpointClient:
|
|
||||||
def __init__(self, endpoint_url: str, token: Optional[str] = None):
|
|
||||||
self.headers = {
|
|
||||||
**build_hf_headers(token=token),
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
self.endpoint_url = endpoint_url
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def encode_image(image):
|
|
||||||
_bytes = io.BytesIO()
|
|
||||||
image.save(_bytes, format="PNG")
|
|
||||||
b64 = base64.b64encode(_bytes.getvalue())
|
|
||||||
return b64.decode("utf-8")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_image(raw_image):
|
|
||||||
if not is_vision_available():
|
|
||||||
raise ImportError(
|
|
||||||
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
|
|
||||||
)
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
b64 = base64.b64decode(raw_image)
|
|
||||||
_bytes = io.BytesIO(b64)
|
|
||||||
return Image.open(_bytes)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
|
|
||||||
params: Optional[Dict] = None,
|
|
||||||
data: Optional[bytes] = None,
|
|
||||||
output_image: bool = False,
|
|
||||||
) -> Any:
|
|
||||||
# Build payload
|
|
||||||
payload = {}
|
|
||||||
if inputs:
|
|
||||||
payload["inputs"] = inputs
|
|
||||||
if params:
|
|
||||||
payload["parameters"] = params
|
|
||||||
|
|
||||||
# Make API call
|
|
||||||
response = get_session().post(
|
|
||||||
self.endpoint_url, headers=self.headers, json=payload, data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
# By default, parse the response for the user.
|
|
||||||
if output_image:
|
|
||||||
return self.decode_image(response.content)
|
|
||||||
else:
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCollection:
|
class ToolCollection:
|
||||||
"""
|
"""
|
||||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
||||||
|
@ -1063,4 +1013,5 @@ __all__ = [
|
||||||
"load_tool",
|
"load_tool",
|
||||||
"launch_gradio_demo",
|
"launch_gradio_demo",
|
||||||
"Toolbox",
|
"Toolbox",
|
||||||
|
"ToolCollection",
|
||||||
]
|
]
|
|
@ -16,6 +16,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -105,6 +106,8 @@ class AgentImage(AgentType, ImageType):
|
||||||
|
|
||||||
if isinstance(value, ImageType):
|
if isinstance(value, ImageType):
|
||||||
self._raw = value
|
self._raw = value
|
||||||
|
elif isinstance(value, bytes):
|
||||||
|
self._raw = Image.open(BytesIO(value))
|
||||||
elif isinstance(value, (str, pathlib.Path)):
|
elif isinstance(value, (str, pathlib.Path)):
|
||||||
self._path = value
|
self._path = value
|
||||||
elif isinstance(value, torch.Tensor):
|
elif isinstance(value, torch.Tensor):
|
||||||
|
@ -241,13 +244,13 @@ class AgentAudio(AgentType, str):
|
||||||
|
|
||||||
|
|
||||||
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
|
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
|
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
|
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
|
||||||
|
|
||||||
|
|
||||||
def handle_agent_inputs(*args, **kwargs):
|
def handle_agent_input_types(*args, **kwargs):
|
||||||
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
||||||
kwargs = {
|
kwargs = {
|
||||||
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
|
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
|
||||||
|
@ -255,7 +258,7 @@ def handle_agent_inputs(*args, **kwargs):
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
def handle_agent_outputs(output, output_type=None):
|
def handle_agent_output_types(output, output_type=None):
|
||||||
if output_type in AGENT_TYPE_MAPPING:
|
if output_type in AGENT_TYPE_MAPPING:
|
||||||
# If the class has defined outputs, we can map directly according to the class definition
|
# If the class has defined outputs, we can map directly according to the class definition
|
||||||
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
|
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
|
||||||
|
|
|
@ -34,7 +34,18 @@ def is_pygments_available():
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
BASE_BUILTIN_MODULES = [
|
||||||
|
"random",
|
||||||
|
"collections",
|
||||||
|
"math",
|
||||||
|
"time",
|
||||||
|
"queue",
|
||||||
|
"itertools",
|
||||||
|
"re",
|
||||||
|
"stat",
|
||||||
|
"statistics",
|
||||||
|
"unicodedata",
|
||||||
|
]
|
||||||
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
try:
|
try:
|
||||||
first_accolade_index = json_blob.find("{")
|
first_accolade_index = json_blob.find("{")
|
||||||
|
@ -190,6 +201,9 @@ def instance_to_source(instance, base_cls=None):
|
||||||
|
|
||||||
for name, value in class_attrs.items():
|
for name, value in class_attrs.items():
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
if "\n" in value:
|
||||||
|
class_lines.append(f' {name} = """{value}"""')
|
||||||
|
else:
|
||||||
class_lines.append(f' {name} = "{value}"')
|
class_lines.append(f' {name} = "{value}"')
|
||||||
else:
|
else:
|
||||||
class_lines.append(f' {name} = {repr(value)}')
|
class_lines.append(f' {name} = {repr(value)}')
|
||||||
|
@ -230,7 +244,8 @@ def instance_to_source(instance, base_cls=None):
|
||||||
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
|
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
|
||||||
|
|
||||||
# Add discovered imports
|
# Add discovered imports
|
||||||
final_lines.extend(required_imports)
|
for package in required_imports:
|
||||||
|
final_lines.append(f"import {package}")
|
||||||
|
|
||||||
if final_lines: # Add empty line after imports
|
if final_lines: # Add empty line after imports
|
||||||
final_lines.append("")
|
final_lines.append("")
|
||||||
|
|
|
@ -29,7 +29,7 @@ from agents.agents import (
|
||||||
Toolbox,
|
Toolbox,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
)
|
)
|
||||||
from agents.tool import tool
|
from agents.tools import tool
|
||||||
from agents.default_tools import PythonInterpreterTool
|
from agents.default_tools import PythonInterpreterTool
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from agents.types import (
|
||||||
AgentImage,
|
AgentImage,
|
||||||
AgentText,
|
AgentText,
|
||||||
)
|
)
|
||||||
from agents.tool import Tool, tool, AUTHORIZED_TYPES
|
from agents.tools import Tool, tool, AUTHORIZED_TYPES
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue