Ruff formatting
This commit is contained in:
parent
851e177e71
commit
67deb6808f
|
@ -24,10 +24,25 @@ from transformers.utils import (
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"],
|
"agents": [
|
||||||
|
"Agent",
|
||||||
|
"CodeAgent",
|
||||||
|
"ManagedAgent",
|
||||||
|
"ReactAgent",
|
||||||
|
"CodeAgent",
|
||||||
|
"JsonAgent",
|
||||||
|
"Toolbox",
|
||||||
|
],
|
||||||
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
||||||
"monitoring": ["stream_to_gradio"],
|
"monitoring": ["stream_to_gradio"],
|
||||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
|
"tools": [
|
||||||
|
"PipelineTool",
|
||||||
|
"Tool",
|
||||||
|
"ToolCollection",
|
||||||
|
"launch_gradio_demo",
|
||||||
|
"load_tool",
|
||||||
|
"tool",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -45,10 +60,25 @@ else:
|
||||||
_import_structure["translation"] = ["TranslationTool"]
|
_import_structure["translation"] = ["TranslationTool"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox
|
from .agents import (
|
||||||
|
Agent,
|
||||||
|
CodeAgent,
|
||||||
|
ManagedAgent,
|
||||||
|
ReactAgent,
|
||||||
|
CodeAgent,
|
||||||
|
JsonAgent,
|
||||||
|
Toolbox,
|
||||||
|
)
|
||||||
from .llm_engine import HfApiEngine, TransformersEngine
|
from .llm_engine import HfApiEngine, TransformersEngine
|
||||||
from .monitoring import stream_to_gradio
|
from .monitoring import stream_to_gradio
|
||||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
|
from .tools import (
|
||||||
|
PipelineTool,
|
||||||
|
Tool,
|
||||||
|
ToolCollection,
|
||||||
|
launch_gradio_demo,
|
||||||
|
load_tool,
|
||||||
|
tool,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
|
@ -66,4 +96,6 @@ if TYPE_CHECKING:
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__, globals()["__file__"], _import_structure, module_spec=__spec__
|
||||||
|
)
|
||||||
|
|
|
@ -19,7 +19,11 @@ import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
from transformers.utils import (
|
||||||
|
is_soundfile_availble,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,7 +112,9 @@ class AgentImage(AgentType, ImageType):
|
||||||
elif isinstance(value, np.ndarray):
|
elif isinstance(value, np.ndarray):
|
||||||
self._tensor = torch.from_numpy(value)
|
self._tensor = torch.from_numpy(value)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
|
raise TypeError(
|
||||||
|
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
def _ipython_display_(self, include=None, exclude=None):
|
def _ipython_display_(self, include=None, exclude=None):
|
||||||
"""
|
"""
|
||||||
|
@ -159,7 +165,7 @@ class AgentImage(AgentType, ImageType):
|
||||||
|
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
def save(self, output_bytes, format : str = None, **params):
|
def save(self, output_bytes, format: str = None, **params):
|
||||||
"""
|
"""
|
||||||
Saves the image to a file.
|
Saves the image to a file.
|
||||||
Args:
|
Args:
|
||||||
|
@ -243,7 +249,9 @@ if is_torch_available():
|
||||||
|
|
||||||
def handle_agent_inputs(*args, **kwargs):
|
def handle_agent_inputs(*args, **kwargs):
|
||||||
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
||||||
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
|
kwargs = {
|
||||||
|
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
|
||||||
|
}
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
|
372
agents/agents.py
372
agents/agents.py
|
@ -15,12 +15,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
|
|
||||||
from langfuse.decorators import langfuse_context, observe
|
|
||||||
|
|
||||||
from transformers.utils import is_torch_available
|
from transformers.utils import is_torch_available
|
||||||
|
|
||||||
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
|
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
|
||||||
|
@ -51,6 +49,7 @@ from .tools import (
|
||||||
|
|
||||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
HUGGINGFACE_DEFAULT_TOOLS = {}
|
||||||
|
|
||||||
|
|
||||||
class AgentError(Exception):
|
class AgentError(Exception):
|
||||||
"""Base class for other agent-related exceptions"""
|
"""Base class for other agent-related exceptions"""
|
||||||
|
|
||||||
|
@ -60,7 +59,6 @@ class AgentError(Exception):
|
||||||
console.print(f"[bold red]{message}[/bold red]")
|
console.print(f"[bold red]{message}[/bold red]")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AgentParsingError(AgentError):
|
class AgentParsingError(AgentError):
|
||||||
"""Exception raised for errors in parsing in the agent"""
|
"""Exception raised for errors in parsing in the agent"""
|
||||||
|
|
||||||
|
@ -84,12 +82,14 @@ class AgentGenerationError(AgentError):
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentStep:
|
class AgentStep:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionStep(AgentStep):
|
class ActionStep(AgentStep):
|
||||||
tool_call: str | None = None
|
tool_call: Dict[str, str] | None = None
|
||||||
start_time: float | None = None
|
start_time: float | None = None
|
||||||
step_end_time: float | None = None
|
step_end_time: float | None = None
|
||||||
iteration: int | None = None
|
iteration: int | None = None
|
||||||
|
@ -97,32 +97,43 @@ class ActionStep(AgentStep):
|
||||||
error: AgentError | None = None
|
error: AgentError | None = None
|
||||||
step_duration: float | None = None
|
step_duration: float | None = None
|
||||||
llm_output: str | None = None
|
llm_output: str | None = None
|
||||||
|
observation: str | None = None
|
||||||
|
agent_memory: List[Dict[str, str]] | None = None
|
||||||
|
rationale: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlanningStep(AgentStep):
|
class PlanningStep(AgentStep):
|
||||||
plan: str
|
plan: str
|
||||||
facts: str
|
facts: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskStep(AgentStep):
|
class TaskStep(AgentStep):
|
||||||
task: str
|
task: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SystemPromptStep(AgentStep):
|
class SystemPromptStep(AgentStep):
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
def format_prompt_with_tools(
|
||||||
|
toolbox: Toolbox, prompt_template: str, tool_description_template: str
|
||||||
|
) -> str:
|
||||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
||||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||||
|
|
||||||
if "{{tool_names}}" in prompt:
|
if "{{tool_names}}" in prompt:
|
||||||
prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]))
|
prompt = prompt.replace(
|
||||||
|
"{{tool_names}}",
|
||||||
|
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
|
||||||
|
)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def show_agents_descriptions(managed_agents: list):
|
def show_agents_descriptions(managed_agents: Dict):
|
||||||
managed_agents_descriptions = """
|
managed_agents_descriptions = """
|
||||||
You can also give requests to team members.
|
You can also give requests to team members.
|
||||||
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
|
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
|
||||||
|
@ -133,16 +144,24 @@ Here is a list of the team members that you can call:"""
|
||||||
return managed_agents_descriptions
|
return managed_agents_descriptions
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
|
def format_prompt_with_managed_agents_descriptions(
|
||||||
|
prompt_template, managed_agents=None
|
||||||
|
) -> str:
|
||||||
if managed_agents is not None:
|
if managed_agents is not None:
|
||||||
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
|
return prompt_template.replace(
|
||||||
|
"<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return prompt_template.replace("<<managed_agents_descriptions>>", "")
|
return prompt_template.replace("<<managed_agents_descriptions>>", "")
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
|
def format_prompt_with_imports(
|
||||||
|
prompt_template: str, authorized_imports: List[str]
|
||||||
|
) -> str:
|
||||||
if "<<authorized_imports>>" not in prompt_template:
|
if "<<authorized_imports>>" not in prompt_template:
|
||||||
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
|
raise AgentError(
|
||||||
|
"Tag '<<authorized_imports>>' should be provided in the prompt."
|
||||||
|
)
|
||||||
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,7 +169,7 @@ class BaseAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tools: Union[List[Tool], Toolbox],
|
tools: Union[List[Tool], Toolbox],
|
||||||
llm_engine: Callable = None,
|
llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tool_description_template: Optional[str] = None,
|
tool_description_template: Optional[str] = None,
|
||||||
additional_args: Dict = {},
|
additional_args: Dict = {},
|
||||||
|
@ -159,10 +178,12 @@ class BaseAgent:
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
managed_agents: Optional[List] = None,
|
managed_agents: Optional[Dict] = None,
|
||||||
step_callbacks: Optional[List[Callable]] = None,
|
step_callbacks: Optional[List[Callable]] = None,
|
||||||
monitor_metrics: bool = True,
|
monitor_metrics: bool = True,
|
||||||
):
|
):
|
||||||
|
if llm_engine is None:
|
||||||
|
llm_engine = HfApiEngine()
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = CODE_SYSTEM_PROMPT
|
system_prompt = CODE_SYSTEM_PROMPT
|
||||||
if tool_parser is None:
|
if tool_parser is None:
|
||||||
|
@ -171,14 +192,16 @@ class BaseAgent:
|
||||||
self.llm_engine = llm_engine
|
self.llm_engine = llm_engine
|
||||||
self.system_prompt_template = system_prompt
|
self.system_prompt_template = system_prompt
|
||||||
self.tool_description_template = (
|
self.tool_description_template = (
|
||||||
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
tool_description_template
|
||||||
|
if tool_description_template
|
||||||
|
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
)
|
)
|
||||||
self.additional_args = additional_args
|
self.additional_args = additional_args
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
self.grammar = grammar
|
self.grammar = grammar
|
||||||
|
|
||||||
self.managed_agents = None
|
self.managed_agents = {}
|
||||||
if managed_agents is not None:
|
if managed_agents is not None:
|
||||||
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
||||||
|
|
||||||
|
@ -186,9 +209,13 @@ class BaseAgent:
|
||||||
self._toolbox = tools
|
self._toolbox = tools
|
||||||
if add_base_tools:
|
if add_base_tools:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("Using the base tools requires torch to be installed.")
|
raise ImportError(
|
||||||
|
"Using the base tools requires torch to be installed."
|
||||||
|
)
|
||||||
|
|
||||||
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent))
|
self._toolbox.add_base_tools(
|
||||||
|
add_python_interpreter=(self.__class__ == JsonAgent)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
||||||
self._toolbox.add_tool(FinalAnswerTool())
|
self._toolbox.add_tool(FinalAnswerTool())
|
||||||
|
@ -196,7 +223,9 @@ class BaseAgent:
|
||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = format_prompt_with_tools(
|
||||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
self._toolbox, self.system_prompt_template, self.tool_description_template
|
||||||
)
|
)
|
||||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
||||||
|
self.system_prompt, self.managed_agents
|
||||||
|
)
|
||||||
self.prompt_messages = None
|
self.prompt_messages = None
|
||||||
self.logs = []
|
self.logs = []
|
||||||
self.task = None
|
self.task = None
|
||||||
|
@ -222,15 +251,20 @@ class BaseAgent:
|
||||||
self.system_prompt_template,
|
self.system_prompt_template,
|
||||||
self.tool_description_template,
|
self.tool_description_template,
|
||||||
)
|
)
|
||||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
||||||
|
self.system_prompt, self.managed_agents
|
||||||
|
)
|
||||||
if hasattr(self, "authorized_imports"):
|
if hasattr(self, "authorized_imports"):
|
||||||
self.system_prompt = format_prompt_with_imports(
|
self.system_prompt = format_prompt_with_imports(
|
||||||
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
self.system_prompt,
|
||||||
|
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.system_prompt
|
return self.system_prompt
|
||||||
|
|
||||||
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
def write_inner_memory_from_logs(
|
||||||
|
self, summary_mode: Optional[bool] = False
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||||
that can be used as input to the LLM.
|
that can be used as input to the LLM.
|
||||||
|
@ -253,7 +287,10 @@ class BaseAgent:
|
||||||
memory.append(thought_message)
|
memory.append(thought_message)
|
||||||
|
|
||||||
if not summary_mode:
|
if not summary_mode:
|
||||||
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log.plan.strip()}
|
thought_message = {
|
||||||
|
"role": MessageRole.ASSISTANT,
|
||||||
|
"content": "[PLAN]:\n" + step_log.plan.strip(),
|
||||||
|
}
|
||||||
memory.append(thought_message)
|
memory.append(thought_message)
|
||||||
|
|
||||||
elif isinstance(step_log, TaskStep):
|
elif isinstance(step_log, TaskStep):
|
||||||
|
@ -265,13 +302,17 @@ class BaseAgent:
|
||||||
|
|
||||||
elif isinstance(step_log, ActionStep):
|
elif isinstance(step_log, ActionStep):
|
||||||
if step_log.llm_output is not None and not summary_mode:
|
if step_log.llm_output is not None and not summary_mode:
|
||||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log.llm_output.strip()}
|
thought_message = {
|
||||||
|
"role": MessageRole.ASSISTANT,
|
||||||
|
"content": step_log.llm_output.strip(),
|
||||||
|
}
|
||||||
memory.append(thought_message)
|
memory.append(thought_message)
|
||||||
|
|
||||||
if step_log.tool_call is not None and summary_mode:
|
if step_log.tool_call is not None and summary_mode:
|
||||||
tool_call_message = {
|
tool_call_message = {
|
||||||
"role": MessageRole.ASSISTANT,
|
"role": MessageRole.ASSISTANT,
|
||||||
"content": f"[STEP {i} TOOL CALL]: " + str(step_log.tool_call).strip(),
|
"content": f"[STEP {i} TOOL CALL]: "
|
||||||
|
+ str(step_log.tool_call).strip(),
|
||||||
}
|
}
|
||||||
memory.append(tool_call_message)
|
memory.append(tool_call_message)
|
||||||
|
|
||||||
|
@ -284,15 +325,21 @@ class BaseAgent:
|
||||||
)
|
)
|
||||||
elif step_log.observation is not None:
|
elif step_log.observation is not None:
|
||||||
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
|
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
|
||||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
tool_response_message = {
|
||||||
|
"role": MessageRole.TOOL_RESPONSE,
|
||||||
|
"content": message_content,
|
||||||
|
}
|
||||||
memory.append(tool_response_message)
|
memory.append(tool_response_message)
|
||||||
|
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def get_succinct_logs(self):
|
def get_succinct_logs(self):
|
||||||
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
|
return [
|
||||||
|
{key: value for key, value in log.items() if key != "agent_memory"}
|
||||||
|
for log in self.logs
|
||||||
|
]
|
||||||
|
|
||||||
def extract_action(self, llm_output: str, split_token: str) -> str:
|
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Parse action from the LLM output
|
Parse action from the LLM output
|
||||||
|
|
||||||
|
@ -312,54 +359,6 @@ class BaseAgent:
|
||||||
)
|
)
|
||||||
return rationale.strip(), action.strip()
|
return rationale.strip(), action.strip()
|
||||||
|
|
||||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
|
||||||
"""
|
|
||||||
Execute tool with the provided input and returns the result.
|
|
||||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
|
||||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
|
||||||
"""
|
|
||||||
available_tools = self.toolbox.tools
|
|
||||||
if self.managed_agents is not None:
|
|
||||||
available_tools = {**available_tools, **self.managed_agents}
|
|
||||||
if tool_name not in available_tools:
|
|
||||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(arguments, str):
|
|
||||||
observation = available_tools[tool_name](arguments)
|
|
||||||
elif isinstance(arguments, dict):
|
|
||||||
for key, value in arguments.items():
|
|
||||||
if isinstance(value, str) and value in self.state:
|
|
||||||
arguments[key] = self.state[value]
|
|
||||||
observation = available_tools[tool_name](**arguments)
|
|
||||||
else:
|
|
||||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
|
||||||
return observation
|
|
||||||
except Exception as e:
|
|
||||||
if tool_name in self.toolbox.tools:
|
|
||||||
tool_description = get_tool_description_with_args(available_tools[tool_name])
|
|
||||||
error_msg = (
|
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
|
||||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
|
||||||
)
|
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
|
||||||
elif tool_name in self.managed_agents:
|
|
||||||
error_msg = (
|
|
||||||
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
|
||||||
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
|
||||||
)
|
|
||||||
console.print(f"[bold red]{error_msg}")
|
|
||||||
raise AgentExecutionError(error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
"""To be implemented in the child class"""
|
"""To be implemented in the child class"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -382,8 +381,6 @@ class ReactAgent(BaseAgent):
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if llm_engine is None:
|
|
||||||
llm_engine = HfApiEngine()
|
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
system_prompt = CODE_SYSTEM_PROMPT
|
system_prompt = CODE_SYSTEM_PROMPT
|
||||||
if tool_description_template is None:
|
if tool_description_template is None:
|
||||||
|
@ -423,8 +420,67 @@ class ReactAgent(BaseAgent):
|
||||||
console.print(f"[bold red]{error_msg}[/bold red]")
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
@observe
|
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||||
def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs):
|
"""
|
||||||
|
Execute tool with the provided input and returns the result.
|
||||||
|
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||||
|
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||||
|
"""
|
||||||
|
available_tools = self.toolbox.tools
|
||||||
|
if self.managed_agents is not None:
|
||||||
|
available_tools = {**available_tools, **self.managed_agents}
|
||||||
|
if tool_name not in available_tools:
|
||||||
|
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
observation = available_tools[tool_name](arguments)
|
||||||
|
elif isinstance(arguments, dict):
|
||||||
|
for key, value in arguments.items():
|
||||||
|
if isinstance(value, str) and value in self.state:
|
||||||
|
arguments[key] = self.state[value]
|
||||||
|
observation = available_tools[tool_name](**arguments)
|
||||||
|
else:
|
||||||
|
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
|
return observation
|
||||||
|
except Exception as e:
|
||||||
|
if tool_name in self.toolbox.tools:
|
||||||
|
tool_description = get_tool_description_with_args(
|
||||||
|
available_tools[tool_name]
|
||||||
|
)
|
||||||
|
error_msg = (
|
||||||
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
|
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||||
|
)
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
|
elif tool_name in self.managed_agents:
|
||||||
|
error_msg = (
|
||||||
|
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||||
|
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||||
|
)
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
|
def step(self, log_entry: ActionStep):
|
||||||
|
"""To be implemented in children classes"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
stream: bool = False,
|
||||||
|
reset: bool = True,
|
||||||
|
oneshot: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Runs the agent for the given task.
|
Runs the agent for the given task.
|
||||||
|
|
||||||
|
@ -441,10 +497,11 @@ class ReactAgent(BaseAgent):
|
||||||
agent.run("What is the result of 2 power 3.7384?")
|
agent.run("What is the result of 2 power 3.7384?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
print("LANGFUSE REF:", langfuse_context.get_current_trace_url())
|
|
||||||
self.task = task
|
self.task = task
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
self.task += (
|
||||||
|
f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||||
|
)
|
||||||
self.state = kwargs.copy()
|
self.state = kwargs.copy()
|
||||||
|
|
||||||
self.initialize_system_prompt()
|
self.initialize_system_prompt()
|
||||||
|
@ -460,7 +517,7 @@ class ReactAgent(BaseAgent):
|
||||||
else:
|
else:
|
||||||
self.logs.append(system_prompt_step)
|
self.logs.append(system_prompt_step)
|
||||||
|
|
||||||
console.rule("[bold]New task", characters='=')
|
console.rule("[bold]New task", characters="=")
|
||||||
console.print(self.task)
|
console.print(self.task)
|
||||||
self.logs.append(TaskStep(task=task))
|
self.logs.append(TaskStep(task=task))
|
||||||
|
|
||||||
|
@ -489,8 +546,13 @@ class ReactAgent(BaseAgent):
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
if (
|
||||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
self.planning_interval is not None
|
||||||
|
and iteration % self.planning_interval == 0
|
||||||
|
):
|
||||||
|
self.planning_step(
|
||||||
|
task, is_first_step=(iteration == 0), iteration=iteration
|
||||||
|
)
|
||||||
console.rule("[bold]New step")
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
self.step(step_log)
|
||||||
if step_log.final_answer is not None:
|
if step_log.final_answer is not None:
|
||||||
|
@ -530,8 +592,13 @@ class ReactAgent(BaseAgent):
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
if (
|
||||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
self.planning_interval is not None
|
||||||
|
and iteration % self.planning_interval == 0
|
||||||
|
):
|
||||||
|
self.planning_step(
|
||||||
|
task, is_first_step=(iteration == 0), iteration=iteration
|
||||||
|
)
|
||||||
console.rule("[bold]New step")
|
console.rule("[bold]New step")
|
||||||
self.step(step_log)
|
self.step(step_log)
|
||||||
if step_log.final_answer is not None:
|
if step_log.final_answer is not None:
|
||||||
|
@ -559,7 +626,7 @@ class ReactAgent(BaseAgent):
|
||||||
|
|
||||||
return final_answer
|
return final_answer
|
||||||
|
|
||||||
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
|
def planning_step(self, task, is_first_step: bool, iteration: int):
|
||||||
"""
|
"""
|
||||||
Used periodically by the agent to plan the next steps to reach the objective.
|
Used periodically by the agent to plan the next steps to reach the objective.
|
||||||
|
|
||||||
|
@ -569,7 +636,10 @@ class ReactAgent(BaseAgent):
|
||||||
iteration (`int`): The number of the current step, used as an indication for the LLM.
|
iteration (`int`): The number of the current step, used as an indication for the LLM.
|
||||||
"""
|
"""
|
||||||
if is_first_step:
|
if is_first_step:
|
||||||
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
|
message_prompt_facts = {
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": SYSTEM_PROMPT_FACTS,
|
||||||
|
}
|
||||||
message_prompt_task = {
|
message_prompt_task = {
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": f"""Here is the task:
|
"content": f"""Here is the task:
|
||||||
|
@ -589,15 +659,20 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN.format(
|
"content": USER_PROMPT_PLAN.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
tool_descriptions=self._toolbox.show_tool_descriptions(
|
||||||
|
self.tool_description_template
|
||||||
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
show_agents_descriptions(self.managed_agents)
|
||||||
|
if self.managed_agents is not None
|
||||||
|
else ""
|
||||||
),
|
),
|
||||||
answer_facts=answer_facts,
|
answer_facts=answer_facts,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
answer_plan = self.llm_engine(
|
answer_plan = self.llm_engine(
|
||||||
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
|
[message_system_prompt_plan, message_user_prompt_plan],
|
||||||
|
stop_sequences=["<end_plan>"],
|
||||||
)
|
)
|
||||||
|
|
||||||
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
||||||
|
@ -608,7 +683,9 @@ Now begin!""",
|
||||||
```
|
```
|
||||||
{answer_facts}
|
{answer_facts}
|
||||||
```""".strip()
|
```""".strip()
|
||||||
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
self.logs.append(
|
||||||
|
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
||||||
|
)
|
||||||
console.rule("[orange]Initial plan")
|
console.rule("[orange]Initial plan")
|
||||||
console.print(final_plan_redaction)
|
console.print(final_plan_redaction)
|
||||||
else: # update plan
|
else: # update plan
|
||||||
|
@ -625,7 +702,9 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_FACTS_UPDATE,
|
"content": USER_PROMPT_FACTS_UPDATE,
|
||||||
}
|
}
|
||||||
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
|
facts_update = self.llm_engine(
|
||||||
|
[facts_update_system_prompt] + agent_memory + [facts_update_message]
|
||||||
|
)
|
||||||
|
|
||||||
# Redact updated plan
|
# Redact updated plan
|
||||||
plan_update_message = {
|
plan_update_message = {
|
||||||
|
@ -636,25 +715,34 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
tool_descriptions=self._toolbox.show_tool_descriptions(
|
||||||
|
self.tool_description_template
|
||||||
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
show_agents_descriptions(self.managed_agents)
|
||||||
|
if self.managed_agents is not None
|
||||||
|
else ""
|
||||||
),
|
),
|
||||||
facts_update=facts_update,
|
facts_update=facts_update,
|
||||||
remaining_steps=(self.max_iterations - iteration),
|
remaining_steps=(self.max_iterations - iteration),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
plan_update = self.llm_engine(
|
plan_update = self.llm_engine(
|
||||||
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
|
[plan_update_message] + agent_memory + [plan_update_message_user],
|
||||||
|
stop_sequences=["<end_plan>"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log final facts and plan
|
# Log final facts and plan
|
||||||
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
|
||||||
|
task=task, plan_update=plan_update
|
||||||
|
)
|
||||||
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
||||||
```
|
```
|
||||||
{facts_update}
|
{facts_update}
|
||||||
```"""
|
```"""
|
||||||
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
self.logs.append(
|
||||||
|
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
||||||
|
)
|
||||||
console.rule("[orange]Updated plan")
|
console.rule("[orange]Updated plan")
|
||||||
console.print(final_plan_redaction)
|
console.print(final_plan_redaction)
|
||||||
|
|
||||||
|
@ -705,14 +793,20 @@ class JsonAgent(ReactAgent):
|
||||||
log_entry.agent_memory = agent_memory.copy()
|
log_entry.agent_memory = agent_memory.copy()
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.rule("[italic]Calling LLM engine with this last message:", align="left")
|
console.rule(
|
||||||
|
"[italic]Calling LLM engine with this last message:", align="left"
|
||||||
|
)
|
||||||
console.print(self.prompt_messages[-1])
|
console.print(self.prompt_messages[-1])
|
||||||
console.rule()
|
console.rule()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
additional_args = (
|
||||||
|
{"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
)
|
||||||
llm_output = self.llm_engine(
|
llm_output = self.llm_engine(
|
||||||
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
self.prompt_messages,
|
||||||
|
stop_sequences=["<end_action>", "Observation:"],
|
||||||
|
**additional_args,
|
||||||
)
|
)
|
||||||
log_entry.llm_output = llm_output
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -723,7 +817,9 @@ class JsonAgent(ReactAgent):
|
||||||
console.print(llm_output)
|
console.print(llm_output)
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
rationale, action = self.extract_action(
|
||||||
|
llm_output=llm_output, split_token="Action:"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_name, arguments = self.tool_parser(action)
|
tool_name, arguments = self.tool_parser(action)
|
||||||
|
@ -807,12 +903,18 @@ class CodeAgent(ReactAgent):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = (
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
)
|
||||||
|
self.authorized_imports = list(
|
||||||
|
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
|
||||||
|
)
|
||||||
|
self.system_prompt = self.system_prompt.replace(
|
||||||
|
"<<authorized_imports>>", str(self.authorized_imports)
|
||||||
|
)
|
||||||
self.custom_tools = {}
|
self.custom_tools = {}
|
||||||
|
|
||||||
def step(self, log_entry: Dict[str, Any]):
|
def step(self, log_entry: ActionStep):
|
||||||
"""
|
"""
|
||||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||||
The errors are raised here, they are caught and logged in the run() method.
|
The errors are raised here, they are caught and logged in the run() method.
|
||||||
|
@ -825,14 +927,20 @@ class CodeAgent(ReactAgent):
|
||||||
log_entry.agent_memory = agent_memory.copy()
|
log_entry.agent_memory = agent_memory.copy()
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.rule("[italic]Calling LLM engine with these last messages:", align="left")
|
console.rule(
|
||||||
|
"[italic]Calling LLM engine with these last messages:", align="left"
|
||||||
|
)
|
||||||
console.print(self.prompt_messages[-2:])
|
console.print(self.prompt_messages[-2:])
|
||||||
console.rule()
|
console.rule()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
additional_args = (
|
||||||
|
{"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
)
|
||||||
llm_output = self.llm_engine(
|
llm_output = self.llm_engine(
|
||||||
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
self.prompt_messages,
|
||||||
|
stop_sequences=["<end_action>", "Observation:"],
|
||||||
|
**additional_args,
|
||||||
)
|
)
|
||||||
log_entry.llm_output = llm_output
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -840,13 +948,19 @@ class CodeAgent(ReactAgent):
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.rule("[italic]Output message of the LLM:")
|
console.rule("[italic]Output message of the LLM:")
|
||||||
console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
|
console.print(
|
||||||
|
Syntax(llm_output, lexer="markdown", background_color="default")
|
||||||
|
)
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
rationale, raw_code_action = self.extract_action(
|
||||||
|
llm_output=llm_output, split_token="Code:"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
|
console.print(
|
||||||
|
f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
|
||||||
|
)
|
||||||
rationale, raw_code_action = llm_output, llm_output
|
rationale, raw_code_action = llm_output, llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -856,14 +970,17 @@ class CodeAgent(ReactAgent):
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
log_entry.rationale = rationale
|
log_entry.rationale = rationale
|
||||||
log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
log_entry.tool_call = {
|
||||||
|
"tool_name": "code interpreter",
|
||||||
|
"tool_arguments": code_action,
|
||||||
|
}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
console.rule("[italic]Agent thoughts")
|
console.rule("[italic]Agent thoughts")
|
||||||
console.print(rationale)
|
console.print(rationale)
|
||||||
console.rule("[bold]Agent is executing the code below:", align="left")
|
console.rule("[bold]Agent is executing the code below:", align="left")
|
||||||
console.print(Syntax(code_action, lexer='python', background_color='default'))
|
console.print(Syntax(code_action, lexer="python", background_color="default"))
|
||||||
console.rule("", align="left")
|
console.rule("", align="left")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -886,7 +1003,9 @@ class CodeAgent(ReactAgent):
|
||||||
if result is not None:
|
if result is not None:
|
||||||
console.rule("Last output from code snippet:", align="left")
|
console.rule("Last output from code snippet:", align="left")
|
||||||
console.print(str(result))
|
console.print(str(result))
|
||||||
observation += "Last output from code snippet:\n" + truncate_content(str(result))
|
observation += "Last output from code snippet:\n" + truncate_content(
|
||||||
|
str(result)
|
||||||
|
)
|
||||||
log_entry.observation = observation
|
log_entry.observation = observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||||
|
@ -902,7 +1021,14 @@ class CodeAgent(ReactAgent):
|
||||||
|
|
||||||
|
|
||||||
class ManagedAgent:
|
class ManagedAgent:
|
||||||
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent,
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
additional_prompting=None,
|
||||||
|
provide_run_summary=False,
|
||||||
|
):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
|
@ -925,18 +1051,22 @@ Your final_answer WILL HAVE to contain these parts:
|
||||||
|
|
||||||
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
|
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
|
||||||
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
||||||
<<additional_prompting>>"""
|
{{additional_prompting}}"""
|
||||||
if self.additional_prompting:
|
if self.additional_prompting:
|
||||||
full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
|
full_task = full_task.replace(
|
||||||
|
"\n{{additional_prompting}}", self.additional_prompting
|
||||||
|
).strip()
|
||||||
else:
|
else:
|
||||||
full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
|
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
||||||
return full_task
|
return full_task
|
||||||
|
|
||||||
def __call__(self, request, **kwargs):
|
def __call__(self, request, **kwargs):
|
||||||
full_task = self.write_full_task(request)
|
full_task = self.write_full_task(request)
|
||||||
output = self.agent.run(full_task, **kwargs)
|
output = self.agent.run(full_task, **kwargs)
|
||||||
if self.provide_run_summary:
|
if self.provide_run_summary:
|
||||||
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
|
answer = (
|
||||||
|
f"Here is the final answer from your managed agent '{self.name}':\n"
|
||||||
|
)
|
||||||
answer += str(output)
|
answer += str(output)
|
||||||
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
||||||
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
||||||
|
|
|
@ -105,7 +105,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
||||||
tools = {}
|
tools = {}
|
||||||
for space_info in spaces:
|
for space_info in spaces:
|
||||||
repo_id = space_info.id
|
repo_id = space_info.id
|
||||||
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
resolved_config_file = hf_hub_download(
|
||||||
|
repo_id, TOOL_CONFIG_FILE, repo_type="space"
|
||||||
|
)
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||||
config = json.load(reader)
|
config = json.load(reader)
|
||||||
task = repo_id.split("/")[-1]
|
task = repo_id.split("/")[-1]
|
||||||
|
@ -131,7 +133,9 @@ class PythonInterpreterTool(Tool):
|
||||||
if authorized_imports is None:
|
if authorized_imports is None:
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
||||||
else:
|
else:
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
self.authorized_imports = list(
|
||||||
|
set(LIST_SAFE_MODULES) | set(authorized_imports)
|
||||||
|
)
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
"code": {
|
"code": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -145,7 +149,11 @@ class PythonInterpreterTool(Tool):
|
||||||
|
|
||||||
def forward(self, code):
|
def forward(self, code):
|
||||||
output = str(
|
output = str(
|
||||||
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
|
evaluate_python_code(
|
||||||
|
code,
|
||||||
|
static_tools=BASE_PYTHON_TOOLS,
|
||||||
|
authorized_imports=self.authorized_imports,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -153,16 +161,21 @@ class PythonInterpreterTool(Tool):
|
||||||
class FinalAnswerTool(Tool):
|
class FinalAnswerTool(Tool):
|
||||||
name = "final_answer"
|
name = "final_answer"
|
||||||
description = "Provides a final answer to the given problem."
|
description = "Provides a final answer to the given problem."
|
||||||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
inputs = {
|
||||||
|
"answer": {"type": "any", "description": "The final answer to the problem"}
|
||||||
|
}
|
||||||
output_type = "any"
|
output_type = "any"
|
||||||
|
|
||||||
def forward(self, answer):
|
def forward(self, answer):
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
|
||||||
class UserInputTool(Tool):
|
class UserInputTool(Tool):
|
||||||
name = "user_input"
|
name = "user_input"
|
||||||
description = "Asks for user's input on a specific question"
|
description = "Asks for user's input on a specific question"
|
||||||
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
|
inputs = {
|
||||||
|
"question": {"type": "string", "description": "The question to ask the user"}
|
||||||
|
}
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def forward(self, question):
|
def forward(self, question):
|
||||||
|
|
|
@ -18,6 +18,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText
|
||||||
from .agents import BaseAgent, AgentStep, ActionStep
|
from .agents import BaseAgent, AgentStep, ActionStep
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
"""Extract ChatMessage objects from agent steps"""
|
"""Extract ChatMessage objects from agent steps"""
|
||||||
if isinstance(step_log, ActionStep):
|
if isinstance(step_log, ActionStep):
|
||||||
|
@ -33,7 +34,9 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
content=str(content),
|
content=str(content),
|
||||||
)
|
)
|
||||||
if step_log.observation is not None:
|
if step_log.observation is not None:
|
||||||
yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observation}\n```")
|
yield gr.ChatMessage(
|
||||||
|
role="assistant", content=f"```\n{step_log.observation}\n```"
|
||||||
|
)
|
||||||
if step_log.error is not None:
|
if step_log.error is not None:
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
@ -42,7 +45,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs):
|
def stream_to_gradio(
|
||||||
|
agent,
|
||||||
|
task: str,
|
||||||
|
test_mode: bool = False,
|
||||||
|
reset_agent_memory: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||||
|
|
||||||
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
|
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
|
||||||
|
@ -52,7 +61,10 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
|
||||||
final_answer = step_log # Last log is the run's final_answer
|
final_answer = step_log # Last log is the run's final_answer
|
||||||
|
|
||||||
if isinstance(final_answer, AgentText):
|
if isinstance(final_answer, AgentText):
|
||||||
yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
|
yield gr.ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
|
||||||
|
)
|
||||||
elif isinstance(final_answer, AgentImage):
|
elif isinstance(final_answer, AgentImage):
|
||||||
yield gr.ChatMessage(
|
yield gr.ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
@ -67,8 +79,9 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
|
||||||
yield gr.ChatMessage(role="assistant", content=str(final_answer))
|
yield gr.ChatMessage(role="assistant", content=str(final_answer))
|
||||||
|
|
||||||
|
|
||||||
class GradioUI():
|
class GradioUI:
|
||||||
"""A one-line interface to launch your agent in Gradio"""
|
"""A one-line interface to launch your agent in Gradio"""
|
||||||
|
|
||||||
def __init__(self, agent: BaseAgent):
|
def __init__(self, agent: BaseAgent):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|
||||||
|
@ -83,10 +96,17 @@ class GradioUI():
|
||||||
def run(self):
|
def run(self):
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
stored_message = gr.State([])
|
stored_message = gr.State([])
|
||||||
chatbot = gr.Chatbot(label="Agent",
|
chatbot = gr.Chatbot(
|
||||||
|
label="Agent",
|
||||||
type="messages",
|
type="messages",
|
||||||
avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
|
avatar_images=(
|
||||||
|
None,
|
||||||
|
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
||||||
|
),
|
||||||
|
)
|
||||||
text_input = gr.Textbox(lines=1, label="Chat Message")
|
text_input = gr.Textbox(lines=1, label="Chat Message")
|
||||||
text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
|
text_input.submit(
|
||||||
|
lambda s: (s, ""), [text_input], [stored_message, text_input]
|
||||||
|
).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
|
||||||
|
|
||||||
demo.launch()
|
demo.launch()
|
|
@ -39,7 +39,9 @@ class MessageRole(str, Enum):
|
||||||
return [r.value for r in cls]
|
return [r.value for r in cls]
|
||||||
|
|
||||||
|
|
||||||
def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
|
def get_clean_message_list(
|
||||||
|
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Subsequent messages with the same role will be concatenated to a single message.
|
Subsequent messages with the same role will be concatenated to a single message.
|
||||||
|
|
||||||
|
@ -54,12 +56,17 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
|
||||||
|
|
||||||
role = message["role"]
|
role = message["role"]
|
||||||
if role not in MessageRole.roles():
|
if role not in MessageRole.roles():
|
||||||
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
|
raise ValueError(
|
||||||
|
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
|
||||||
|
)
|
||||||
|
|
||||||
if role in role_conversions:
|
if role in role_conversions:
|
||||||
message["role"] = role_conversions[role]
|
message["role"] = role_conversions[role]
|
||||||
|
|
||||||
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
|
if (
|
||||||
|
len(final_message_list) > 0
|
||||||
|
and message["role"] == final_message_list[-1]["role"]
|
||||||
|
):
|
||||||
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
|
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
|
||||||
else:
|
else:
|
||||||
final_message_list.append(message)
|
final_message_list.append(message)
|
||||||
|
@ -81,8 +88,12 @@ class HfEngine:
|
||||||
try:
|
try:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
|
logger.warning(
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
|
f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
|
||||||
|
)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
||||||
|
)
|
||||||
|
|
||||||
def get_token_counts(self):
|
def get_token_counts(self):
|
||||||
return {
|
return {
|
||||||
|
@ -91,12 +102,18 @@ class HfEngine:
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
grammar: Optional[str] = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
grammar: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Process the input messages and return the model's response.
|
"""Process the input messages and return the model's response.
|
||||||
|
|
||||||
|
@ -127,11 +144,15 @@ class HfEngine:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if not isinstance(messages, List):
|
if not isinstance(messages, List):
|
||||||
raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
|
raise ValueError(
|
||||||
|
"Messages should be a list of dictionaries with 'role' and 'content' keys."
|
||||||
|
)
|
||||||
if stop_sequences is None:
|
if stop_sequences is None:
|
||||||
stop_sequences = []
|
stop_sequences = []
|
||||||
response = self.generate(messages, stop_sequences, grammar)
|
response = self.generate(messages, stop_sequences, grammar)
|
||||||
self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
|
self.last_input_token_count = len(
|
||||||
|
self.tokenizer.apply_chat_template(messages, tokenize=True)
|
||||||
|
)
|
||||||
self.last_output_token_count = len(self.tokenizer.encode(response))
|
self.last_output_token_count = len(self.tokenizer.encode(response))
|
||||||
|
|
||||||
# Remove stop sequences from LLM output
|
# Remove stop sequences from LLM output
|
||||||
|
@ -175,18 +196,28 @@ class HfApiEngine(HfEngine):
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
grammar: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Get clean message list
|
# Get clean message list
|
||||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
messages = get_clean_message_list(
|
||||||
|
messages, role_conversions=llama_role_conversions
|
||||||
|
)
|
||||||
|
|
||||||
# Send messages to the Hugging Face Inference API
|
# Send messages to the Hugging Face Inference API
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
response = self.client.chat_completion(
|
response = self.client.chat_completion(
|
||||||
messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
|
messages,
|
||||||
|
stop=stop_sequences,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
response_format=grammar,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
|
response = self.client.chat_completion(
|
||||||
|
messages, stop=stop_sequences, max_tokens=self.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
response = response.choices[0].message.content
|
response = response.choices[0].message.content
|
||||||
return response
|
return response
|
||||||
|
@ -207,7 +238,9 @@ class TransformersEngine(HfEngine):
|
||||||
max_length: int = 1500,
|
max_length: int = 1500,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Get clean message list
|
# Get clean message list
|
||||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
messages = get_clean_message_list(
|
||||||
|
messages, role_conversions=llama_role_conversions
|
||||||
|
)
|
||||||
|
|
||||||
# Get LLM output
|
# Get LLM output
|
||||||
if stop_sequences is not None and len(stop_sequences) > 0:
|
if stop_sequences is not None and len(stop_sequences) > 0:
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
from .utils import console
|
from .utils import console
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Monitor:
|
class Monitor:
|
||||||
def __init__(self, tracked_llm_engine):
|
def __init__(self, tracked_llm_engine):
|
||||||
self.step_durations = []
|
self.step_durations = []
|
||||||
self.tracked_llm_engine = tracked_llm_engine
|
self.tracked_llm_engine = tracked_llm_engine
|
||||||
if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
|
if (
|
||||||
|
getattr(self.tracked_llm_engine, "last_input_token_count", "Not found")
|
||||||
|
!= "Not found"
|
||||||
|
):
|
||||||
self.total_input_token_count = 0
|
self.total_input_token_count = 0
|
||||||
self.total_output_token_count = 0
|
self.total_output_token_count = 0
|
||||||
|
|
||||||
|
@ -33,103 +35,11 @@ class Monitor:
|
||||||
console.print(f"- Time taken: {step_duration:.2f} seconds")
|
console.print(f"- Time taken: {step_duration:.2f} seconds")
|
||||||
|
|
||||||
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
|
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
|
||||||
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
|
self.total_input_token_count += (
|
||||||
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
|
self.tracked_llm_engine.last_input_token_count
|
||||||
|
)
|
||||||
|
self.total_output_token_count += (
|
||||||
|
self.tracked_llm_engine.last_output_token_count
|
||||||
|
)
|
||||||
console.print(f"- Input tokens: {self.total_input_token_count:,}")
|
console.print(f"- Input tokens: {self.total_input_token_count:,}")
|
||||||
console.print(f"- Output tokens: {self.total_output_token_count:,}")
|
console.print(f"- Output tokens: {self.total_output_token_count:,}")
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Union, List, Any
|
|
||||||
import httpx
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from langfuse.client import Langfuse, StatefulTraceClient, StatefulSpanClient, StateType
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def call(cls, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class LangfuseTracker(BaseTracker):
|
|
||||||
log = logging.getLogger("langfuse")
|
|
||||||
|
|
||||||
def __init__(self, *, public_key: Optional[str] = None, secret_key: Optional[str] = None,
|
|
||||||
host: Optional[str] = None, debug: bool = False, stateful_client: Optional[
|
|
||||||
Union[StatefulTraceClient, StatefulSpanClient]
|
|
||||||
] = None, update_stateful_client: bool = False, version: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None, user_id: Optional[str] = None, trace_name: Optional[str] = None,
|
|
||||||
release: Optional[str] = None, metadata: Optional[Any] = None, tags: Optional[List[str]] = None,
|
|
||||||
threads: Optional[int] = None, flush_at: Optional[int] = None, flush_interval: Optional[int] = None,
|
|
||||||
max_retries: Optional[int] = None, timeout: Optional[int] = None, enabled: Optional[bool] = None,
|
|
||||||
httpx_client: Optional[httpx.Client] = None, sdk_integration: str = "default") -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.version = version
|
|
||||||
self.session_id = session_id
|
|
||||||
self.user_id = user_id
|
|
||||||
self.trace_name = trace_name
|
|
||||||
self.release = release
|
|
||||||
self.metadata = metadata
|
|
||||||
self.tags = tags
|
|
||||||
|
|
||||||
self.root_span = None
|
|
||||||
self.update_stateful_client = update_stateful_client
|
|
||||||
self.langfuse = None
|
|
||||||
|
|
||||||
prio_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY")
|
|
||||||
prio_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY")
|
|
||||||
prio_host = host or os.environ.get(
|
|
||||||
"LANGFUSE_HOST", "https://cloud.langfuse.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
if stateful_client and isinstance(stateful_client, StatefulTraceClient):
|
|
||||||
self.trace = stateful_client
|
|
||||||
self._task_manager = stateful_client.task_manager
|
|
||||||
return
|
|
||||||
|
|
||||||
elif stateful_client and isinstance(stateful_client, StatefulSpanClient):
|
|
||||||
self.root_span = stateful_client
|
|
||||||
self.trace = StatefulTraceClient(
|
|
||||||
stateful_client.client,
|
|
||||||
stateful_client.trace_id,
|
|
||||||
StateType.TRACE,
|
|
||||||
stateful_client.trace_id,
|
|
||||||
stateful_client.task_manager,
|
|
||||||
)
|
|
||||||
self._task_manager = stateful_client.task_manager
|
|
||||||
return
|
|
||||||
|
|
||||||
args = {
|
|
||||||
"public_key": prio_public_key,
|
|
||||||
"secret_key": prio_secret_key,
|
|
||||||
"host": prio_host,
|
|
||||||
"debug": debug,
|
|
||||||
}
|
|
||||||
|
|
||||||
if release is not None:
|
|
||||||
args["release"] = release
|
|
||||||
if threads is not None:
|
|
||||||
args["threads"] = threads
|
|
||||||
if flush_at is not None:
|
|
||||||
args["flush_at"] = flush_at
|
|
||||||
if flush_interval is not None:
|
|
||||||
args["flush_interval"] = flush_interval
|
|
||||||
if max_retries is not None:
|
|
||||||
args["max_retries"] = max_retries
|
|
||||||
if timeout is not None:
|
|
||||||
args["timeout"] = timeout
|
|
||||||
if enabled is not None:
|
|
||||||
args["enabled"] = enabled
|
|
||||||
if httpx_client is not None:
|
|
||||||
args["httpx_client"] = httpx_client
|
|
||||||
args["sdk_integration"] = sdk_integration
|
|
||||||
|
|
||||||
self.langfuse = Langfuse(**args)
|
|
||||||
self.trace: Optional[StatefulTraceClient] = None
|
|
||||||
self._task_manager = self.langfuse.task_manager
|
|
||||||
|
|
||||||
def call(self, i, o, name=None, **kwargs):
|
|
||||||
self.langfuse.trace(input=i, output=o, name=name, metadata=kwargs)
|
|
|
@ -42,7 +42,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
|
||||||
return prompt_or_repo_id
|
return prompt_or_repo_id
|
||||||
|
|
||||||
prompt_file = cached_file(
|
prompt_file = cached_file(
|
||||||
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
|
prompt_or_repo_id,
|
||||||
|
PROMPT_FILES[mode],
|
||||||
|
repo_type="dataset",
|
||||||
|
user_agent={"agent": agent_name},
|
||||||
)
|
)
|
||||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
|
@ -26,6 +26,7 @@ import pandas as pd
|
||||||
|
|
||||||
from .utils import truncate_content
|
from .utils import truncate_content
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
"""
|
"""
|
||||||
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
||||||
|
@ -38,7 +39,8 @@ class InterpreterError(ValueError):
|
||||||
ERRORS = {
|
ERRORS = {
|
||||||
name: getattr(builtins, name)
|
name: getattr(builtins, name)
|
||||||
for name in dir(builtins)
|
for name in dir(builtins)
|
||||||
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
|
if isinstance(getattr(builtins, name), type)
|
||||||
|
and issubclass(getattr(builtins, name), BaseException)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +94,9 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
||||||
elif isinstance(expression.op, ast.Invert):
|
elif isinstance(expression.op, ast.Invert):
|
||||||
return ~operand
|
return ~operand
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
raise InterpreterError(
|
||||||
|
f"Unary operation {expression.op.__class__.__name__} is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
||||||
|
@ -102,7 +106,9 @@ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
for arg, value in zip(args, values):
|
for arg, value in zip(args, values):
|
||||||
new_state[arg] = value
|
new_state[arg] = value
|
||||||
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
|
return evaluate_ast(
|
||||||
|
lambda_expression.body, new_state, static_tools, custom_tools
|
||||||
|
)
|
||||||
|
|
||||||
return lambda_func
|
return lambda_func
|
||||||
|
|
||||||
|
@ -120,7 +126,9 @@ def evaluate_while(while_loop, state, static_tools, custom_tools):
|
||||||
break
|
break
|
||||||
iterations += 1
|
iterations += 1
|
||||||
if iterations > max_iterations:
|
if iterations > max_iterations:
|
||||||
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
raise InterpreterError(
|
||||||
|
f"Maximum number of {max_iterations} iterations in While loop exceeded"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,7 +136,10 @@ def create_function(func_def, state, static_tools, custom_tools):
|
||||||
def new_func(*args, **kwargs):
|
def new_func(*args, **kwargs):
|
||||||
func_state = state.copy()
|
func_state = state.copy()
|
||||||
arg_names = [arg.arg for arg in func_def.args.args]
|
arg_names = [arg.arg for arg in func_def.args.args]
|
||||||
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
|
default_values = [
|
||||||
|
evaluate_ast(d, state, static_tools, custom_tools)
|
||||||
|
for d in func_def.args.defaults
|
||||||
|
]
|
||||||
|
|
||||||
# Apply default values
|
# Apply default values
|
||||||
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
||||||
|
@ -180,26 +191,39 @@ def create_class(class_name, class_bases, class_body):
|
||||||
|
|
||||||
|
|
||||||
def evaluate_function_def(func_def, state, static_tools, custom_tools):
|
def evaluate_function_def(func_def, state, static_tools, custom_tools):
|
||||||
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
|
custom_tools[func_def.name] = create_function(
|
||||||
|
func_def, state, static_tools, custom_tools
|
||||||
|
)
|
||||||
return custom_tools[func_def.name]
|
return custom_tools[func_def.name]
|
||||||
|
|
||||||
|
|
||||||
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
||||||
class_name = class_def.name
|
class_name = class_def.name
|
||||||
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
|
bases = [
|
||||||
|
evaluate_ast(base, state, static_tools, custom_tools)
|
||||||
|
for base in class_def.bases
|
||||||
|
]
|
||||||
class_dict = {}
|
class_dict = {}
|
||||||
|
|
||||||
for stmt in class_def.body:
|
for stmt in class_def.body:
|
||||||
if isinstance(stmt, ast.FunctionDef):
|
if isinstance(stmt, ast.FunctionDef):
|
||||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
|
class_dict[stmt.name] = evaluate_function_def(
|
||||||
|
stmt, state, static_tools, custom_tools
|
||||||
|
)
|
||||||
elif isinstance(stmt, ast.Assign):
|
elif isinstance(stmt, ast.Assign):
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
class_dict[target.id] = evaluate_ast(
|
||||||
|
stmt.value, state, static_tools, custom_tools
|
||||||
|
)
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
class_dict[target.attr] = evaluate_ast(
|
||||||
|
stmt.value, state, static_tools, custom_tools
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
raise InterpreterError(
|
||||||
|
f"Unsupported statement in class body: {stmt.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
new_class = type(class_name, tuple(bases), class_dict)
|
new_class = type(class_name, tuple(bases), class_dict)
|
||||||
state[class_name] = new_class
|
state[class_name] = new_class
|
||||||
|
@ -223,7 +247,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||||
elif isinstance(target, ast.List):
|
elif isinstance(target, ast.List):
|
||||||
return [get_current_value(elt) for elt in target.elts]
|
return [get_current_value(elt) for elt in target.elts]
|
||||||
else:
|
else:
|
||||||
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
raise InterpreterError(
|
||||||
|
"AugAssign not supported for {type(target)} targets."
|
||||||
|
)
|
||||||
|
|
||||||
current_value = get_current_value(expression.target)
|
current_value = get_current_value(expression.target)
|
||||||
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
|
@ -232,7 +258,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||||
if isinstance(expression.op, ast.Add):
|
if isinstance(expression.op, ast.Add):
|
||||||
if isinstance(current_value, list):
|
if isinstance(current_value, list):
|
||||||
if not isinstance(value_to_add, list):
|
if not isinstance(value_to_add, list):
|
||||||
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
|
raise InterpreterError(
|
||||||
|
f"Cannot add non-list value {value_to_add} to a list."
|
||||||
|
)
|
||||||
updated_value = current_value + value_to_add
|
updated_value = current_value + value_to_add
|
||||||
else:
|
else:
|
||||||
updated_value = current_value + value_to_add
|
updated_value = current_value + value_to_add
|
||||||
|
@ -259,7 +287,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||||
elif isinstance(expression.op, ast.RShift):
|
elif isinstance(expression.op, ast.RShift):
|
||||||
updated_value = current_value >> value_to_add
|
updated_value = current_value >> value_to_add
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
raise InterpreterError(
|
||||||
|
f"Operation {type(expression.op).__name__} is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
# Update the state
|
# Update the state
|
||||||
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
||||||
|
@ -311,7 +341,9 @@ def evaluate_binop(binop, state, static_tools, custom_tools):
|
||||||
elif isinstance(binop.op, ast.RShift):
|
elif isinstance(binop.op, ast.RShift):
|
||||||
return left_val >> right_val
|
return left_val >> right_val
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
raise NotImplementedError(
|
||||||
|
f"Binary operation {type(binop.op).__name__} is not implemented."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_assign(assign, state, static_tools, custom_tools):
|
def evaluate_assign(assign, state, static_tools, custom_tools):
|
||||||
|
@ -321,7 +353,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
|
||||||
set_value(target, result, state, static_tools, custom_tools)
|
set_value(target, result, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
if len(assign.targets) != len(result):
|
if len(assign.targets) != len(result):
|
||||||
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
raise InterpreterError(
|
||||||
|
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
|
||||||
|
)
|
||||||
expanded_values = []
|
expanded_values = []
|
||||||
for tgt in assign.targets:
|
for tgt in assign.targets:
|
||||||
if isinstance(tgt, ast.Starred):
|
if isinstance(tgt, ast.Starred):
|
||||||
|
@ -336,7 +370,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
|
||||||
def set_value(target, value, state, static_tools, custom_tools):
|
def set_value(target, value, state, static_tools, custom_tools):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in static_tools:
|
if target.id in static_tools:
|
||||||
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
raise InterpreterError(
|
||||||
|
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
|
||||||
|
)
|
||||||
state[target.id] = value
|
state[target.id] = value
|
||||||
elif isinstance(target, ast.Tuple):
|
elif isinstance(target, ast.Tuple):
|
||||||
if not isinstance(value, tuple):
|
if not isinstance(value, tuple):
|
||||||
|
@ -399,9 +435,14 @@ def evaluate_call(call, state, static_tools, custom_tools):
|
||||||
else:
|
else:
|
||||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||||
|
|
||||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
|
kwargs = {
|
||||||
|
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools)
|
||||||
|
for keyword in call.keywords
|
||||||
|
}
|
||||||
|
|
||||||
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
if (
|
||||||
|
isinstance(func, type) and len(func.__module__.split(".")) > 1
|
||||||
|
): # Check for user-defined classes
|
||||||
# Instantiate the class using its constructor
|
# Instantiate the class using its constructor
|
||||||
obj = func.__new__(func) # Create a new instance of the class
|
obj = func.__new__(func) # Create a new instance of the class
|
||||||
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
|
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
|
||||||
|
@ -441,7 +482,9 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||||
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||||
|
|
||||||
if isinstance(value, str) and isinstance(index, str):
|
if isinstance(value, str) and isinstance(index, str):
|
||||||
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
|
raise InterpreterError(
|
||||||
|
"You're trying to subscript a string with a string index, which is impossible"
|
||||||
|
)
|
||||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||||
parent_object = value.obj
|
parent_object = value.obj
|
||||||
return parent_object.loc[index]
|
return parent_object.loc[index]
|
||||||
|
@ -453,11 +496,15 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||||
return value[index]
|
return value[index]
|
||||||
elif isinstance(value, (list, tuple)):
|
elif isinstance(value, (list, tuple)):
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
raise InterpreterError(
|
||||||
|
f"Index {index} out of bounds for list of length {len(value)}"
|
||||||
|
)
|
||||||
return value[int(index)]
|
return value[int(index)]
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
raise InterpreterError(
|
||||||
|
f"Index {index} out of bounds for string of length {len(value)}"
|
||||||
|
)
|
||||||
return value[index]
|
return value[index]
|
||||||
elif index in value:
|
elif index in value:
|
||||||
return value[index]
|
return value[index]
|
||||||
|
@ -483,7 +530,10 @@ def evaluate_name(name, state, static_tools, custom_tools):
|
||||||
|
|
||||||
def evaluate_condition(condition, state, static_tools, custom_tools):
|
def evaluate_condition(condition, state, static_tools, custom_tools):
|
||||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
||||||
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
comparators = [
|
||||||
|
evaluate_ast(c, state, static_tools, custom_tools)
|
||||||
|
for c in condition.comparators
|
||||||
|
]
|
||||||
ops = [type(op) for op in condition.ops]
|
ops = [type(op) for op in condition.ops]
|
||||||
|
|
||||||
result = True
|
result = True
|
||||||
|
@ -561,9 +611,13 @@ def evaluate_for(for_loop, state, static_tools, custom_tools):
|
||||||
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
||||||
def inner_evaluate(generators, index, current_state):
|
def inner_evaluate(generators, index, current_state):
|
||||||
if index >= len(generators):
|
if index >= len(generators):
|
||||||
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
|
return [
|
||||||
|
evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)
|
||||||
|
]
|
||||||
generator = generators[index]
|
generator = generators[index]
|
||||||
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
|
iter_value = evaluate_ast(
|
||||||
|
generator.iter, current_state, static_tools, custom_tools
|
||||||
|
)
|
||||||
result = []
|
result = []
|
||||||
for value in iter_value:
|
for value in iter_value:
|
||||||
new_state = current_state.copy()
|
new_state = current_state.copy()
|
||||||
|
@ -572,7 +626,10 @@ def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
||||||
new_state[elem.id] = value[idx]
|
new_state[elem.id] = value[idx]
|
||||||
else:
|
else:
|
||||||
new_state[generator.target.id] = value
|
new_state[generator.target.id] = value
|
||||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
|
if all(
|
||||||
|
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
|
||||||
|
for if_clause in generator.ifs
|
||||||
|
):
|
||||||
result.extend(inner_evaluate(generators, index + 1, new_state))
|
result.extend(inner_evaluate(generators, index + 1, new_state))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -586,7 +643,9 @@ def evaluate_try(try_node, state, static_tools, custom_tools):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
matched = False
|
matched = False
|
||||||
for handler in try_node.handlers:
|
for handler in try_node.handlers:
|
||||||
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
|
if handler.type is None or isinstance(
|
||||||
|
e, evaluate_ast(handler.type, state, static_tools, custom_tools)
|
||||||
|
):
|
||||||
matched = True
|
matched = True
|
||||||
if handler.name:
|
if handler.name:
|
||||||
state[handler.name] = e
|
state[handler.name] = e
|
||||||
|
@ -638,7 +697,9 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
||||||
def evaluate_with(with_node, state, static_tools, custom_tools):
|
def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||||
contexts = []
|
contexts = []
|
||||||
for item in with_node.items:
|
for item in with_node.items:
|
||||||
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
context_expr = evaluate_ast(
|
||||||
|
item.context_expr, state, static_tools, custom_tools
|
||||||
|
)
|
||||||
if item.optional_vars:
|
if item.optional_vars:
|
||||||
state[item.optional_vars.id] = context_expr.__enter__()
|
state[item.optional_vars.id] = context_expr.__enter__()
|
||||||
contexts.append(state[item.optional_vars.id])
|
contexts.append(state[item.optional_vars.id])
|
||||||
|
@ -661,7 +722,9 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||||
def import_modules(expression, state, authorized_imports):
|
def import_modules(expression, state, authorized_imports):
|
||||||
def check_module_authorized(module_name):
|
def check_module_authorized(module_name):
|
||||||
module_path = module_name.split(".")
|
module_path = module_name.split(".")
|
||||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
module_subpaths = [
|
||||||
|
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
|
||||||
|
]
|
||||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
|
||||||
if isinstance(expression, ast.Import):
|
if isinstance(expression, ast.Import):
|
||||||
|
@ -676,7 +739,9 @@ def import_modules(expression, state, authorized_imports):
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if check_module_authorized(expression.module):
|
if check_module_authorized(expression.module):
|
||||||
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
module = __import__(
|
||||||
|
expression.module, fromlist=[alias.name for alias in expression.names]
|
||||||
|
)
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||||
else:
|
else:
|
||||||
|
@ -691,9 +756,14 @@ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
|
||||||
for value in iter_value:
|
for value in iter_value:
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
||||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
|
if all(
|
||||||
|
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
|
||||||
|
for if_clause in gen.ifs
|
||||||
|
):
|
||||||
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
||||||
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
|
val = evaluate_ast(
|
||||||
|
dictcomp.value, new_state, static_tools, custom_tools
|
||||||
|
)
|
||||||
result[key] = val
|
result[key] = val
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -744,7 +814,10 @@ def evaluate_ast(
|
||||||
# Constant -> just return the value
|
# Constant -> just return the value
|
||||||
return expression.value
|
return expression.value
|
||||||
elif isinstance(expression, ast.Tuple):
|
elif isinstance(expression, ast.Tuple):
|
||||||
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
return tuple(
|
||||||
|
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||||
|
for elt in expression.elts
|
||||||
|
)
|
||||||
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||||
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.UnaryOp):
|
elif isinstance(expression, ast.UnaryOp):
|
||||||
|
@ -770,8 +843,13 @@ def evaluate_ast(
|
||||||
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Dict):
|
elif isinstance(expression, ast.Dict):
|
||||||
# Dict -> evaluate all keys and values
|
# Dict -> evaluate all keys and values
|
||||||
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
keys = [
|
||||||
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys
|
||||||
|
]
|
||||||
|
values = [
|
||||||
|
evaluate_ast(v, state, static_tools, custom_tools)
|
||||||
|
for v in expression.values
|
||||||
|
]
|
||||||
return dict(zip(keys, values))
|
return dict(zip(keys, values))
|
||||||
elif isinstance(expression, ast.Expr):
|
elif isinstance(expression, ast.Expr):
|
||||||
# Expression -> evaluate the content
|
# Expression -> evaluate the content
|
||||||
|
@ -788,10 +866,18 @@ def evaluate_ast(
|
||||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.JoinedStr):
|
elif isinstance(expression, ast.JoinedStr):
|
||||||
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
return "".join(
|
||||||
|
[
|
||||||
|
str(evaluate_ast(v, state, static_tools, custom_tools))
|
||||||
|
for v in expression.values
|
||||||
|
]
|
||||||
|
)
|
||||||
elif isinstance(expression, ast.List):
|
elif isinstance(expression, ast.List):
|
||||||
# List -> evaluate all elements
|
# List -> evaluate all elements
|
||||||
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
return [
|
||||||
|
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||||
|
for elt in expression.elts
|
||||||
|
]
|
||||||
elif isinstance(expression, ast.Name):
|
elif isinstance(expression, ast.Name):
|
||||||
# Name -> pick up the value in the state
|
# Name -> pick up the value in the state
|
||||||
return evaluate_name(expression, state, static_tools, custom_tools)
|
return evaluate_name(expression, state, static_tools, custom_tools)
|
||||||
|
@ -815,7 +901,9 @@ def evaluate_ast(
|
||||||
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
||||||
if expression.upper is not None
|
if expression.upper is not None
|
||||||
else None,
|
else None,
|
||||||
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
evaluate_ast(expression.step, state, static_tools, custom_tools)
|
||||||
|
if expression.step is not None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
elif isinstance(expression, ast.DictComp):
|
elif isinstance(expression, ast.DictComp):
|
||||||
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
||||||
|
@ -834,17 +922,24 @@ def evaluate_ast(
|
||||||
elif isinstance(expression, ast.With):
|
elif isinstance(expression, ast.With):
|
||||||
return evaluate_with(expression, state, static_tools, custom_tools)
|
return evaluate_with(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Set):
|
elif isinstance(expression, ast.Set):
|
||||||
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
return {
|
||||||
|
evaluate_ast(elt, state, static_tools, custom_tools)
|
||||||
|
for elt in expression.elts
|
||||||
|
}
|
||||||
elif isinstance(expression, ast.Return):
|
elif isinstance(expression, ast.Return):
|
||||||
raise ReturnException(
|
raise ReturnException(
|
||||||
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
|
if expression.value
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# For now we refuse anything else. Let's add things as we need them.
|
# For now we refuse anything else. Let's add things as we need them.
|
||||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
|
def truncate_print_outputs(
|
||||||
|
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
|
||||||
|
) -> str:
|
||||||
if len(print_outputs) < max_len_outputs:
|
if len(print_outputs) < max_len_outputs:
|
||||||
return print_outputs
|
return print_outputs
|
||||||
else:
|
else:
|
||||||
|
@ -895,8 +990,12 @@ def evaluate_python_code(
|
||||||
OPERATIONS_COUNT = 0
|
OPERATIONS_COUNT = 0
|
||||||
try:
|
try:
|
||||||
for node in expression.body:
|
for node in expression.body:
|
||||||
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
result = evaluate_ast(
|
||||||
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
node, state, static_tools, custom_tools, authorized_imports
|
||||||
|
)
|
||||||
|
state["print_outputs"] = truncate_content(
|
||||||
|
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
||||||
|
|
|
@ -26,7 +26,9 @@ class DuckDuckGoSearchTool(Tool):
|
||||||
name = "web_search"
|
name = "web_search"
|
||||||
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||||
Each result has keys 'title', 'href' and 'body'."""
|
Each result has keys 'title', 'href' and 'body'."""
|
||||||
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
inputs = {
|
||||||
|
"query": {"type": "string", "description": "The search query to perform."}
|
||||||
|
}
|
||||||
output_type = "any"
|
output_type = "any"
|
||||||
|
|
||||||
def forward(self, query: str) -> str:
|
def forward(self, query: str) -> str:
|
||||||
|
|
141
agents/tools.py
141
agents/tools.py
|
@ -26,7 +26,13 @@ from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
|
from huggingface_hub import (
|
||||||
|
create_repo,
|
||||||
|
get_collection,
|
||||||
|
hf_hub_download,
|
||||||
|
metadata_update,
|
||||||
|
upload_folder,
|
||||||
|
)
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
@ -73,7 +79,9 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
||||||
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
|
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
|
||||||
return "model"
|
return "model"
|
||||||
except RepositoryNotFoundError:
|
except RepositoryNotFoundError:
|
||||||
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
|
raise EnvironmentError(
|
||||||
|
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return "model"
|
return "model"
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -158,7 +166,15 @@ class Tool:
|
||||||
"inputs": dict,
|
"inputs": dict,
|
||||||
"output_type": str,
|
"output_type": str,
|
||||||
}
|
}
|
||||||
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
|
authorized_types = [
|
||||||
|
"string",
|
||||||
|
"integer",
|
||||||
|
"number",
|
||||||
|
"image",
|
||||||
|
"audio",
|
||||||
|
"any",
|
||||||
|
"boolean",
|
||||||
|
]
|
||||||
|
|
||||||
for attr, expected_type in required_attributes.items():
|
for attr, expected_type in required_attributes.items():
|
||||||
attr_value = getattr(self, attr, None)
|
attr_value = getattr(self, attr, None)
|
||||||
|
@ -169,7 +185,9 @@ class Tool:
|
||||||
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
||||||
)
|
)
|
||||||
for input_name, input_content in self.inputs.items():
|
for input_name, input_content in self.inputs.items():
|
||||||
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
assert isinstance(
|
||||||
|
input_content, dict
|
||||||
|
), f"Input '{input_name}' should be a dictionary."
|
||||||
assert (
|
assert (
|
||||||
"type" in input_content and "description" in input_content
|
"type" in input_content and "description" in input_content
|
||||||
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
||||||
|
@ -251,7 +269,11 @@ class Tool:
|
||||||
# Save app file
|
# Save app file
|
||||||
app_file = os.path.join(output_dir, "app.py")
|
app_file = os.path.join(output_dir, "app.py")
|
||||||
with open(app_file, "w", encoding="utf-8") as f:
|
with open(app_file, "w", encoding="utf-8") as f:
|
||||||
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
|
f.write(
|
||||||
|
APP_FILE_TEMPLATE.format(
|
||||||
|
module_name=last_module, class_name=self.__class__.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Save requirements file
|
# Save requirements file
|
||||||
requirements_file = os.path.join(output_dir, "requirements.txt")
|
requirements_file = os.path.join(output_dir, "requirements.txt")
|
||||||
|
@ -343,7 +365,9 @@ class Tool:
|
||||||
custom_tool = config
|
custom_tool = config
|
||||||
|
|
||||||
tool_class = custom_tool["tool_class"]
|
tool_class = custom_tool["tool_class"]
|
||||||
tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
|
tool_class = get_class_from_dynamic_module(
|
||||||
|
tool_class, repo_id, token=token, **hub_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if len(tool_class.name) == 0:
|
if len(tool_class.name) == 0:
|
||||||
tool_class.name = custom_tool["name"]
|
tool_class.name = custom_tool["name"]
|
||||||
|
@ -420,7 +444,9 @@ class Tool:
|
||||||
with tempfile.TemporaryDirectory() as work_dir:
|
with tempfile.TemporaryDirectory() as work_dir:
|
||||||
# Save all files.
|
# Save all files.
|
||||||
self.save(work_dir)
|
self.save(work_dir)
|
||||||
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
logger.info(
|
||||||
|
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
|
||||||
|
)
|
||||||
return upload_folder(
|
return upload_folder(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
commit_message=commit_message,
|
commit_message=commit_message,
|
||||||
|
@ -432,7 +458,11 @@ class Tool:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_space(
|
def from_space(
|
||||||
space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
|
space_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
api_name: Optional[str] = None,
|
||||||
|
token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a [`Tool`] from a Space given its id on the Hub.
|
Creates a [`Tool`] from a Space given its id on the Hub.
|
||||||
|
@ -485,7 +515,9 @@ class Tool:
|
||||||
self.client = Client(space_id, hf_token=token)
|
self.client = Client(space_id, hf_token=token)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
|
space_description = self.client.view_api(
|
||||||
|
return_format="dict", print_info=False
|
||||||
|
)["named_endpoints"]
|
||||||
|
|
||||||
# If api_name is not defined, take the first of the available APIs for this space
|
# If api_name is not defined, take the first of the available APIs for this space
|
||||||
if api_name is None:
|
if api_name is None:
|
||||||
|
@ -498,7 +530,9 @@ class Tool:
|
||||||
try:
|
try:
|
||||||
space_description_api = space_description[api_name]
|
space_description_api = space_description[api_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError(f"Could not find specified {api_name=} among available api names.")
|
raise KeyError(
|
||||||
|
f"Could not find specified {api_name=} among available api names."
|
||||||
|
)
|
||||||
|
|
||||||
self.inputs = {}
|
self.inputs = {}
|
||||||
for parameter in space_description_api["parameters"]:
|
for parameter in space_description_api["parameters"]:
|
||||||
|
@ -523,9 +557,11 @@ class Tool:
|
||||||
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||||
arg.save(temp_file.name)
|
arg.save(temp_file.name)
|
||||||
arg = temp_file.name
|
arg = temp_file.name
|
||||||
if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
|
if (
|
||||||
arg
|
isinstance(arg, (str, Path))
|
||||||
):
|
and Path(arg).exists()
|
||||||
|
and Path(arg).is_file()
|
||||||
|
) or is_http_url_like(arg):
|
||||||
arg = handle_file(arg)
|
arg = handle_file(arg)
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
|
@ -544,7 +580,9 @@ class Tool:
|
||||||
] # Sometime the space also returns the generation seed, in which case the result is at index 0
|
] # Sometime the space also returns the generation seed, in which case the result is at index 0
|
||||||
return output
|
return output
|
||||||
|
|
||||||
return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
|
return SpaceToolWrapper(
|
||||||
|
space_id, name, description, api_name=api_name, token=token
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_gradio(gradio_tool):
|
def from_gradio(gradio_tool):
|
||||||
|
@ -561,7 +599,8 @@ class Tool:
|
||||||
self._gradio_tool = _gradio_tool
|
self._gradio_tool = _gradio_tool
|
||||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
|
key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
|
||||||
|
for key, value in func_args
|
||||||
}
|
}
|
||||||
self.forward = self._gradio_tool.run
|
self.forward = self._gradio_tool.run
|
||||||
|
|
||||||
|
@ -603,7 +642,9 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
|
def get_tool_description_with_args(
|
||||||
|
tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
|
) -> str:
|
||||||
compiled_template = compile_jinja_template(description_template)
|
compiled_template = compile_jinja_template(description_template)
|
||||||
rendered = compiled_template.render(
|
rendered = compiled_template.render(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
|
@ -621,7 +662,10 @@ def compile_jinja_template(template):
|
||||||
raise ImportError("template requires jinja2 to be installed.")
|
raise ImportError("template requires jinja2 to be installed.")
|
||||||
|
|
||||||
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
||||||
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
|
raise ImportError(
|
||||||
|
"template requires jinja2>=3.1.0 to be installed. Your version is "
|
||||||
|
f"{jinja2.__version__}."
|
||||||
|
)
|
||||||
|
|
||||||
def raise_exception(message):
|
def raise_exception(message):
|
||||||
raise TemplateError(message)
|
raise TemplateError(message)
|
||||||
|
@ -697,7 +741,9 @@ class PipelineTool(Tool):
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
if self.default_checkpoint is None:
|
if self.default_checkpoint is None:
|
||||||
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
|
raise ValueError(
|
||||||
|
"This tool does not implement a default checkpoint, you need to pass one."
|
||||||
|
)
|
||||||
model = self.default_checkpoint
|
model = self.default_checkpoint
|
||||||
if pre_processor is None:
|
if pre_processor is None:
|
||||||
pre_processor = model
|
pre_processor = model
|
||||||
|
@ -720,15 +766,21 @@ class PipelineTool(Tool):
|
||||||
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
||||||
"""
|
"""
|
||||||
if isinstance(self.pre_processor, str):
|
if isinstance(self.pre_processor, str):
|
||||||
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
|
self.pre_processor = self.pre_processor_class.from_pretrained(
|
||||||
|
self.pre_processor, **self.hub_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(self.model, str):
|
if isinstance(self.model, str):
|
||||||
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
|
self.model = self.model_class.from_pretrained(
|
||||||
|
self.model, **self.model_kwargs, **self.hub_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if self.post_processor is None:
|
if self.post_processor is None:
|
||||||
self.post_processor = self.pre_processor
|
self.post_processor = self.pre_processor
|
||||||
elif isinstance(self.post_processor, str):
|
elif isinstance(self.post_processor, str):
|
||||||
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
|
self.post_processor = self.post_processor_class.from_pretrained(
|
||||||
|
self.post_processor, **self.hub_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
if self.device_map is not None:
|
if self.device_map is not None:
|
||||||
|
@ -768,8 +820,12 @@ class PipelineTool(Tool):
|
||||||
|
|
||||||
encoded_inputs = self.encode(*args, **kwargs)
|
encoded_inputs = self.encode(*args, **kwargs)
|
||||||
|
|
||||||
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
|
tensor_inputs = {
|
||||||
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
|
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
|
non_tensor_inputs = {
|
||||||
|
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
|
|
||||||
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
||||||
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
||||||
|
@ -790,7 +846,9 @@ def launch_gradio_demo(tool_class: Tool):
|
||||||
try:
|
try:
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
raise ImportError(
|
||||||
|
"Gradio should be installed in order to launch a gradio demo."
|
||||||
|
)
|
||||||
|
|
||||||
tool = tool_class()
|
tool = tool_class()
|
||||||
|
|
||||||
|
@ -807,11 +865,15 @@ def launch_gradio_demo(tool_class: Tool):
|
||||||
|
|
||||||
gradio_inputs = []
|
gradio_inputs = []
|
||||||
for input_name, input_details in tool_class.inputs.items():
|
for input_name, input_details in tool_class.inputs.items():
|
||||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
|
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||||
|
input_details["type"]
|
||||||
|
]
|
||||||
new_component = input_gradio_component_class(label=input_name)
|
new_component = input_gradio_component_class(label=input_name)
|
||||||
gradio_inputs.append(new_component)
|
gradio_inputs.append(new_component)
|
||||||
|
|
||||||
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
|
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||||
|
tool_class.output_type
|
||||||
|
]
|
||||||
gradio_output = output_gradio_componentclass(label=input_name)
|
gradio_output = output_gradio_componentclass(label=input_name)
|
||||||
|
|
||||||
gr.Interface(
|
gr.Interface(
|
||||||
|
@ -875,7 +937,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
||||||
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
|
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
|
||||||
f"code that you have checked."
|
f"code that you have checked."
|
||||||
)
|
)
|
||||||
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
|
return Tool.from_hub(
|
||||||
|
task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_description(description):
|
def add_description(description):
|
||||||
|
@ -935,7 +999,9 @@ class EndpointClient:
|
||||||
payload["parameters"] = params
|
payload["parameters"] = params
|
||||||
|
|
||||||
# Make API call
|
# Make API call
|
||||||
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
|
response = get_session().post(
|
||||||
|
self.endpoint_url, headers=self.headers, json=payload, data=data
|
||||||
|
)
|
||||||
|
|
||||||
# By default, parse the response for the user.
|
# By default, parse the response for the user.
|
||||||
if output_image:
|
if output_image:
|
||||||
|
@ -972,7 +1038,9 @@ class ToolCollection:
|
||||||
|
|
||||||
def __init__(self, collection_slug: str, token: Optional[str] = None):
|
def __init__(self, collection_slug: str, token: Optional[str] = None):
|
||||||
self._collection = get_collection(collection_slug, token=token)
|
self._collection = get_collection(collection_slug, token=token)
|
||||||
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
|
self._hub_repo_ids = {
|
||||||
|
item.item_id for item in self._collection.items if item.item_type == "space"
|
||||||
|
}
|
||||||
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
|
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
|
||||||
|
|
||||||
|
|
||||||
|
@ -986,7 +1054,9 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
"""
|
"""
|
||||||
parameters = get_json_schema(tool_function)["function"]
|
parameters = get_json_schema(tool_function)["function"]
|
||||||
if "return" not in parameters:
|
if "return" not in parameters:
|
||||||
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
raise TypeHintParsingException(
|
||||||
|
"Tool return type not found: make sure your function has a return type hint!"
|
||||||
|
)
|
||||||
class_name = f"{parameters['name'].capitalize()}Tool"
|
class_name = f"{parameters['name'].capitalize()}Tool"
|
||||||
|
|
||||||
class SpecificTool(Tool):
|
class SpecificTool(Tool):
|
||||||
|
@ -1000,9 +1070,9 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
return tool_function(*args, **kwargs)
|
return tool_function(*args, **kwargs)
|
||||||
|
|
||||||
original_signature = inspect.signature(tool_function)
|
original_signature = inspect.signature(tool_function)
|
||||||
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
|
new_parameters = [
|
||||||
original_signature.parameters.values()
|
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||||
)
|
] + list(original_signature.parameters.values())
|
||||||
new_signature = original_signature.replace(parameters=new_parameters)
|
new_signature = original_signature.replace(parameters=new_parameters)
|
||||||
SpecificTool.forward.__signature__ = new_signature
|
SpecificTool.forward.__signature__ = new_signature
|
||||||
|
|
||||||
|
@ -1049,7 +1119,10 @@ class Toolbox:
|
||||||
The template to use to describe the tools. If not provided, the default template will be used.
|
The template to use to describe the tools. If not provided, the default template will be used.
|
||||||
"""
|
"""
|
||||||
return "\n".join(
|
return "\n".join(
|
||||||
[get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
|
[
|
||||||
|
get_tool_description_with_args(tool, tool_description_template)
|
||||||
|
for tool in self._tools.values()
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_tool(self, tool: Tool):
|
def add_tool(self, tool: Tool):
|
||||||
|
|
|
@ -16,32 +16,29 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Tuple, Dict
|
from typing import Tuple, Dict, Union
|
||||||
|
|
||||||
from transformers.utils.import_utils import _is_package_available
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
_pygments_available = _is_package_available("pygments")
|
_pygments_available = _is_package_available("pygments")
|
||||||
|
|
||||||
|
|
||||||
def is_pygments_available():
|
def is_pygments_available():
|
||||||
return _pygments_available
|
return _pygments_available
|
||||||
|
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
LENGTH_TRUNCATE_REPORTS = 10000
|
|
||||||
|
|
||||||
def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS):
|
|
||||||
if len(content) < max_length:
|
|
||||||
return content
|
|
||||||
else:
|
|
||||||
return content[:max_length//2] + "\n..._(Content was truncated because too long)_...\n---" + content[-max_length//2:]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
try:
|
try:
|
||||||
first_accolade_index = json_blob.find("{")
|
first_accolade_index = json_blob.find("{")
|
||||||
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
||||||
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
|
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
|
||||||
|
'\\"', "'"
|
||||||
|
)
|
||||||
json_data = json.loads(json_blob, strict=False)
|
json_data = json.loads(json_blob, strict=False)
|
||||||
return json_data
|
return json_data
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
@ -63,7 +60,12 @@ def parse_code_blob(code_blob: str) -> str:
|
||||||
try:
|
try:
|
||||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||||
match = re.search(pattern, code_blob, re.DOTALL)
|
match = re.search(pattern, code_blob, re.DOTALL)
|
||||||
|
if match is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"No match ground for regex pattern {pattern} in {code_blob=}."
|
||||||
|
)
|
||||||
return match.group(1).strip()
|
return match.group(1).strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"""
|
f"""
|
||||||
|
@ -77,7 +79,7 @@ Code:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
|
||||||
json_blob = json_blob.replace("```json", "").replace("```", "")
|
json_blob = json_blob.replace("```json", "").replace("```", "")
|
||||||
tool_call = parse_json_blob(json_blob)
|
tool_call = parse_json_blob(json_blob)
|
||||||
if "action" in tool_call and "action_input" in tool_call:
|
if "action" in tool_call and "action_input" in tool_call:
|
||||||
|
@ -85,7 +87,25 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
||||||
elif "action" in tool_call:
|
elif "action" in tool_call:
|
||||||
return tool_call["action"], None
|
return tool_call["action"], None
|
||||||
else:
|
else:
|
||||||
missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
|
missing_keys = [
|
||||||
|
key for key in ["action", "action_input"] if key not in tool_call
|
||||||
|
]
|
||||||
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
|
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
|
||||||
console.print(f"[bold red]{error_msg}[/bold red]")
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
MAX_LENGTH_TRUNCATE_CONTENT = 20000
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_content(
|
||||||
|
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
|
||||||
|
) -> str:
|
||||||
|
if len(content) <= max_length:
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
content[: MAX_LENGTH_TRUNCATE_CONTENT // 2]
|
||||||
|
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
||||||
|
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
||||||
|
)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,271 @@
|
||||||
|
# This file was autogenerated by uv via the following command:
|
||||||
|
# uv pip compile requirements.in -o requirements.txt
|
||||||
|
aiofiles==23.2.1
|
||||||
|
# via gradio
|
||||||
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
|
anyio==4.7.0
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# httpx
|
||||||
|
# starlette
|
||||||
|
appnope==0.1.4
|
||||||
|
# via ipykernel
|
||||||
|
asttokens==3.0.0
|
||||||
|
# via stack-data
|
||||||
|
beautifulsoup4==4.12.3
|
||||||
|
# via markdownify
|
||||||
|
certifi==2024.8.30
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
charset-normalizer==3.4.0
|
||||||
|
# via requests
|
||||||
|
click==8.1.7
|
||||||
|
# via
|
||||||
|
# duckduckgo-search
|
||||||
|
# typer
|
||||||
|
# uvicorn
|
||||||
|
comm==0.2.2
|
||||||
|
# via ipykernel
|
||||||
|
debugpy==1.8.10
|
||||||
|
# via ipykernel
|
||||||
|
decorator==5.1.1
|
||||||
|
# via ipython
|
||||||
|
diskcache==5.6.3
|
||||||
|
# via llama-cpp-python
|
||||||
|
duckduckgo-search==6.4.1
|
||||||
|
# via -r requirements.in
|
||||||
|
executing==2.1.0
|
||||||
|
# via stack-data
|
||||||
|
fastapi==0.115.6
|
||||||
|
# via gradio
|
||||||
|
ffmpy==0.4.0
|
||||||
|
# via gradio
|
||||||
|
filelock==3.16.1
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
|
fsspec==2024.10.0
|
||||||
|
# via
|
||||||
|
# gradio-client
|
||||||
|
# huggingface-hub
|
||||||
|
gradio==5.8.0
|
||||||
|
# via -r requirements.in
|
||||||
|
gradio-client==1.5.1
|
||||||
|
# via gradio
|
||||||
|
h11==0.14.0
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# uvicorn
|
||||||
|
httpcore==1.0.7
|
||||||
|
# via httpx
|
||||||
|
httpx==0.28.1
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# gradio-client
|
||||||
|
# safehttpx
|
||||||
|
huggingface-hub==0.26.5
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# gradio-client
|
||||||
|
# tokenizers
|
||||||
|
# transformers
|
||||||
|
idna==3.10
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
iniconfig==2.0.0
|
||||||
|
# via pytest
|
||||||
|
ipykernel==6.29.5
|
||||||
|
# via -r requirements.in
|
||||||
|
ipython==8.30.0
|
||||||
|
# via ipykernel
|
||||||
|
jedi==0.19.2
|
||||||
|
# via ipython
|
||||||
|
jinja2==3.1.4
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# gradio
|
||||||
|
# llama-cpp-python
|
||||||
|
jupyter-client==8.6.3
|
||||||
|
# via ipykernel
|
||||||
|
jupyter-core==5.7.2
|
||||||
|
# via
|
||||||
|
# ipykernel
|
||||||
|
# jupyter-client
|
||||||
|
llama-cpp-python==0.3.5
|
||||||
|
# via -r requirements.in
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
# via rich
|
||||||
|
markdownify==0.14.1
|
||||||
|
# via -r requirements.in
|
||||||
|
markupsafe==2.1.5
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# jinja2
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
# via
|
||||||
|
# ipykernel
|
||||||
|
# ipython
|
||||||
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
# via ipykernel
|
||||||
|
numpy==2.2.0
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# llama-cpp-python
|
||||||
|
# pandas
|
||||||
|
# transformers
|
||||||
|
orjson==3.10.12
|
||||||
|
# via gradio
|
||||||
|
packaging==24.2
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# gradio-client
|
||||||
|
# huggingface-hub
|
||||||
|
# ipykernel
|
||||||
|
# pytest
|
||||||
|
# transformers
|
||||||
|
pandas==2.2.3
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# gradio
|
||||||
|
parso==0.8.4
|
||||||
|
# via jedi
|
||||||
|
pexpect==4.9.0
|
||||||
|
# via ipython
|
||||||
|
pillow==11.0.0
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# gradio
|
||||||
|
platformdirs==4.3.6
|
||||||
|
# via jupyter-core
|
||||||
|
pluggy==1.5.0
|
||||||
|
# via pytest
|
||||||
|
primp==0.8.3
|
||||||
|
# via duckduckgo-search
|
||||||
|
prompt-toolkit==3.0.48
|
||||||
|
# via ipython
|
||||||
|
psutil==6.1.0
|
||||||
|
# via ipykernel
|
||||||
|
ptyprocess==0.7.0
|
||||||
|
# via pexpect
|
||||||
|
pure-eval==0.2.3
|
||||||
|
# via stack-data
|
||||||
|
pydantic==2.10.3
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# gradio
|
||||||
|
pydantic-core==2.27.1
|
||||||
|
# via pydantic
|
||||||
|
pydub==0.25.1
|
||||||
|
# via gradio
|
||||||
|
pygments==2.18.0
|
||||||
|
# via
|
||||||
|
# ipython
|
||||||
|
# rich
|
||||||
|
pytest==8.3.4
|
||||||
|
# via -r requirements.in
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via
|
||||||
|
# jupyter-client
|
||||||
|
# pandas
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
# via -r requirements.in
|
||||||
|
python-multipart==0.0.19
|
||||||
|
# via gradio
|
||||||
|
pytz==2024.2
|
||||||
|
# via pandas
|
||||||
|
pyyaml==6.0.2
|
||||||
|
# via
|
||||||
|
# gradio
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
|
pyzmq==26.2.0
|
||||||
|
# via
|
||||||
|
# ipykernel
|
||||||
|
# jupyter-client
|
||||||
|
regex==2024.11.6
|
||||||
|
# via transformers
|
||||||
|
requests==2.32.3
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
|
rich==13.9.4
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# typer
|
||||||
|
ruff==0.8.2
|
||||||
|
# via gradio
|
||||||
|
safehttpx==0.1.6
|
||||||
|
# via gradio
|
||||||
|
safetensors==0.4.5
|
||||||
|
# via transformers
|
||||||
|
semantic-version==2.10.0
|
||||||
|
# via gradio
|
||||||
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
|
six==1.17.0
|
||||||
|
# via
|
||||||
|
# markdownify
|
||||||
|
# python-dateutil
|
||||||
|
sniffio==1.3.1
|
||||||
|
# via anyio
|
||||||
|
soupsieve==2.6
|
||||||
|
# via beautifulsoup4
|
||||||
|
stack-data==0.6.3
|
||||||
|
# via ipython
|
||||||
|
starlette==0.41.3
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# gradio
|
||||||
|
tokenizers==0.21.0
|
||||||
|
# via transformers
|
||||||
|
tomlkit==0.13.2
|
||||||
|
# via gradio
|
||||||
|
tornado==6.4.2
|
||||||
|
# via
|
||||||
|
# ipykernel
|
||||||
|
# jupyter-client
|
||||||
|
tqdm==4.67.1
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# transformers
|
||||||
|
traitlets==5.14.3
|
||||||
|
# via
|
||||||
|
# comm
|
||||||
|
# ipykernel
|
||||||
|
# ipython
|
||||||
|
# jupyter-client
|
||||||
|
# jupyter-core
|
||||||
|
# matplotlib-inline
|
||||||
|
transformers==4.47.0
|
||||||
|
# via -r requirements.in
|
||||||
|
typer==0.15.1
|
||||||
|
# via gradio
|
||||||
|
typing-extensions==4.12.2
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# fastapi
|
||||||
|
# gradio
|
||||||
|
# gradio-client
|
||||||
|
# huggingface-hub
|
||||||
|
# llama-cpp-python
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# typer
|
||||||
|
tzdata==2024.2
|
||||||
|
# via pandas
|
||||||
|
urllib3==2.2.3
|
||||||
|
# via requests
|
||||||
|
uvicorn==0.32.1
|
||||||
|
# via gradio
|
||||||
|
wcwidth==0.2.13
|
||||||
|
# via prompt-toolkit
|
||||||
|
websockets==14.1
|
||||||
|
# via gradio-client
|
122
setup.py
122
setup.py
|
@ -1,122 +0,0 @@
|
||||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
|
|
||||||
extras = {}
|
|
||||||
extras["quality"] = [
|
|
||||||
"black ~= 23.1", # hf-doc-builder has a hidden dependency on `black`
|
|
||||||
"hf-doc-builder >= 0.3.0",
|
|
||||||
"ruff ~= 0.6.4",
|
|
||||||
]
|
|
||||||
extras["docs"] = []
|
|
||||||
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
|
|
||||||
extras["test_dev"] = [
|
|
||||||
"datasets",
|
|
||||||
"diffusers",
|
|
||||||
"evaluate",
|
|
||||||
"torchdata>=0.8.0",
|
|
||||||
"torchpippy>=0.2.0",
|
|
||||||
"transformers",
|
|
||||||
"scipy",
|
|
||||||
"scikit-learn",
|
|
||||||
"tqdm",
|
|
||||||
"bitsandbytes",
|
|
||||||
"timm",
|
|
||||||
]
|
|
||||||
extras["testing"] = extras["test_prod"] + extras["test_dev"]
|
|
||||||
extras["deepspeed"] = ["deepspeed"]
|
|
||||||
extras["rich"] = ["rich"]
|
|
||||||
|
|
||||||
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
|
|
||||||
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
|
|
||||||
|
|
||||||
extras["sagemaker"] = [
|
|
||||||
"sagemaker", # boto3 is a required package in sagemaker
|
|
||||||
]
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name="accelerate",
|
|
||||||
version="1.2.0.dev0",
|
|
||||||
description="Accelerate",
|
|
||||||
long_description=open("README.md", encoding="utf-8").read(),
|
|
||||||
long_description_content_type="text/markdown",
|
|
||||||
keywords="deep learning",
|
|
||||||
license="Apache",
|
|
||||||
author="The HuggingFace team",
|
|
||||||
author_email="zach.mueller@huggingface.co",
|
|
||||||
url="https://github.com/huggingface/accelerate",
|
|
||||||
package_dir={"": "src"},
|
|
||||||
packages=find_packages("src"),
|
|
||||||
entry_points={
|
|
||||||
"console_scripts": [
|
|
||||||
"accelerate=accelerate.commands.accelerate_cli:main",
|
|
||||||
"accelerate-config=accelerate.commands.config:main",
|
|
||||||
"accelerate-estimate-memory=accelerate.commands.estimate:main",
|
|
||||||
"accelerate-launch=accelerate.commands.launch:main",
|
|
||||||
"accelerate-merge-weights=accelerate.commands.merge:main",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
python_requires=">=3.9.0",
|
|
||||||
install_requires=[
|
|
||||||
"numpy>=1.17,<3.0.0",
|
|
||||||
"packaging>=20.0",
|
|
||||||
"psutil",
|
|
||||||
"pyyaml",
|
|
||||||
"torch>=1.10.0",
|
|
||||||
"huggingface_hub>=0.21.0",
|
|
||||||
"safetensors>=0.4.3",
|
|
||||||
],
|
|
||||||
extras_require=extras,
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 5 - Production/Stable",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"Intended Audience :: Education",
|
|
||||||
"Intended Audience :: Science/Research",
|
|
||||||
"License :: OSI Approved :: Apache Software License",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Release checklist
|
|
||||||
# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):
|
|
||||||
# git checkout -b vXX.xx-release
|
|
||||||
# The -b is only necessary for creation (so remove it when doing a patch)
|
|
||||||
# 2. Change the version in __init__.py and setup.py to the proper value.
|
|
||||||
# 3. Commit these changes with the message: "Release: v<VERSION>"
|
|
||||||
# 4. Add a tag in git to mark the release:
|
|
||||||
# git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi'
|
|
||||||
# Push the tag and release commit to git: git push --tags origin vXX.xx-release
|
|
||||||
# 5. Run the following commands in the top-level directory:
|
|
||||||
# rm -rf dist
|
|
||||||
# rm -rf build
|
|
||||||
# python setup.py bdist_wheel
|
|
||||||
# python setup.py sdist
|
|
||||||
# 6. Upload the package to the pypi test server first:
|
|
||||||
# twine upload dist/* -r testpypi
|
|
||||||
# 7. Check that you can install it in a virtualenv by running:
|
|
||||||
# pip install accelerate
|
|
||||||
# pip uninstall accelerate
|
|
||||||
# pip install -i https://testpypi.python.org/pypi accelerate
|
|
||||||
# accelerate env
|
|
||||||
# accelerate test
|
|
||||||
# 8. Upload the final version to actual pypi:
|
|
||||||
# twine upload dist/* -r pypi
|
|
||||||
# 9. Add release notes to the tag in github once everything is looking hunky-dory.
|
|
||||||
# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to
|
|
||||||
# main.
|
|
|
@ -19,19 +19,25 @@ import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
||||||
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
from transformers.testing_utils import (
|
||||||
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
get_tests_dir,
|
||||||
|
require_soundfile,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
is_soundfile_availble,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if is_soundfile_availble():
|
if is_soundfile_availble():
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
if is_vision_available():
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def get_new_path(suffix="") -> str:
|
def get_new_path(suffix="") -> str:
|
||||||
directory = tempfile.mkdtemp()
|
directory = tempfile.mkdtemp()
|
||||||
|
|
|
@ -19,17 +19,16 @@ import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers.agents.agent_types import AgentText
|
from agents.agent_types import AgentText
|
||||||
from transformers.agents.agents import (
|
from agents.agents import (
|
||||||
AgentMaxIterationsError,
|
AgentMaxIterationsError,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
ReactCodeAgent,
|
CodeAgent,
|
||||||
ReactJsonAgent,
|
JsonAgent,
|
||||||
Toolbox,
|
Toolbox,
|
||||||
)
|
)
|
||||||
from transformers.agents.default_tools import PythonInterpreterTool
|
from agents.default_tools import PythonInterpreterTool
|
||||||
from transformers.testing_utils import require_torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_new_path(suffix="") -> str:
|
def get_new_path(suffix="") -> str:
|
||||||
|
@ -149,19 +148,26 @@ print(result)
|
||||||
|
|
||||||
class AgentTests(unittest.TestCase):
|
class AgentTests(unittest.TestCase):
|
||||||
def test_fake_code_agent(self):
|
def test_fake_code_agent(self):
|
||||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
|
agent = CodeAgent(
|
||||||
|
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
|
||||||
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert output == "7.2904"
|
assert output == "7.2904"
|
||||||
|
|
||||||
def test_fake_react_json_agent(self):
|
def test_fake_react_json_agent(self):
|
||||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
|
agent = JsonAgent(
|
||||||
|
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
|
||||||
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert output == "7.2904"
|
assert output == "7.2904"
|
||||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||||
assert agent.logs[1]["observation"] == "7.2904"
|
assert agent.logs[1]["observation"] == "7.2904"
|
||||||
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
|
assert (
|
||||||
|
agent.logs[1]["rationale"].strip()
|
||||||
|
== "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
agent.logs[2]["llm_output"]
|
agent.logs[2]["llm_output"]
|
||||||
== """
|
== """
|
||||||
|
@ -175,7 +181,9 @@ Action:
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fake_react_code_agent(self):
|
def test_fake_react_code_agent(self):
|
||||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
agent = CodeAgent(
|
||||||
|
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
|
||||||
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, float)
|
assert isinstance(output, float)
|
||||||
assert output == 7.2904
|
assert output == 7.2904
|
||||||
|
@ -186,17 +194,19 @@ Action:
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_react_code_agent_code_errors_show_offending_lines(self):
|
def test_react_code_agent_code_errors_show_offending_lines(self):
|
||||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
|
agent = CodeAgent(
|
||||||
|
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
|
||||||
|
)
|
||||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
assert isinstance(output, AgentText)
|
assert isinstance(output, AgentText)
|
||||||
assert output == "got an error"
|
assert output == "got an error"
|
||||||
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
||||||
|
|
||||||
def test_setup_agent_with_empty_toolbox(self):
|
def test_setup_agent_with_empty_toolbox(self):
|
||||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
JsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||||
|
|
||||||
def test_react_fails_max_iterations(self):
|
def test_react_fails_max_iterations(self):
|
||||||
agent = ReactCodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[PythonInterpreterTool()],
|
tools=[PythonInterpreterTool()],
|
||||||
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
||||||
max_iterations=5,
|
max_iterations=5,
|
||||||
|
@ -208,51 +218,62 @@ Action:
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 1
|
len(agent.toolbox.tools) == 1
|
||||||
) # when no tools are provided, only the final_answer tool is added by default
|
) # when no tools are provided, only the final_answer tool is added by default
|
||||||
|
|
||||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.toolbox.tools) == 2
|
||||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||||
|
|
||||||
toolset_3 = Toolbox(toolset_2)
|
toolset_3 = Toolbox(toolset_2)
|
||||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.toolbox.tools) == 2
|
||||||
) # same as previous one, where toolset_3 is an instantiation of previous one
|
) # same as previous one, where toolset_3 is an instantiation of previous one
|
||||||
|
|
||||||
# check that add_base_tools will not interfere with existing tools
|
# check that add_base_tools will not interfere with existing tools
|
||||||
with pytest.raises(KeyError) as e:
|
with pytest.raises(KeyError) as e:
|
||||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
agent = JsonAgent(
|
||||||
|
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
|
||||||
|
)
|
||||||
assert "already exists in the toolbox" in str(e)
|
assert "already exists in the toolbox" in str(e)
|
||||||
|
|
||||||
# check that python_interpreter base tool does not get added to code agents
|
# check that python_interpreter base tool does not get added to code agents
|
||||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||||
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
|
assert (
|
||||||
|
len(agent.toolbox.tools) == 7
|
||||||
|
) # added final_answer tool + 6 base tools (excluding interpreter)
|
||||||
|
|
||||||
def test_function_persistence_across_steps(self):
|
def test_function_persistence_across_steps(self):
|
||||||
agent = ReactCodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
|
tools=[],
|
||||||
|
llm_engine=fake_react_code_functiondef,
|
||||||
|
max_iterations=2,
|
||||||
|
additional_authorized_imports=["numpy"],
|
||||||
)
|
)
|
||||||
res = agent.run("ok")
|
res = agent.run("ok")
|
||||||
assert res[0] == 0.5
|
assert res[0] == 0.5
|
||||||
|
|
||||||
def test_init_managed_agent(self):
|
def test_init_managed_agent(self):
|
||||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
assert managed_agent.name == "managed_agent"
|
assert managed_agent.name == "managed_agent"
|
||||||
assert managed_agent.description == "Empty"
|
assert managed_agent.description == "Empty"
|
||||||
|
|
||||||
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
||||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||||
manager_agent = ReactCodeAgent(
|
manager_agent = CodeAgent(
|
||||||
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
|
tools=[],
|
||||||
|
llm_engine=fake_react_code_functiondef,
|
||||||
|
managed_agents=[managed_agent],
|
||||||
)
|
)
|
||||||
assert "You can also give requests to team members." not in agent.system_prompt
|
assert "You can also give requests to team members." not in agent.system_prompt
|
||||||
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
|
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
|
||||||
assert "You can also give requests to team members." in manager_agent.system_prompt
|
assert (
|
||||||
|
"You can also give requests to team members." in manager_agent.system_prompt
|
||||||
|
)
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
from transformers import load_tool
|
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = load_tool("document_question_answering")
|
|
||||||
self.tool.setup()
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
|
||||||
document = dataset[0]["image"]
|
|
||||||
|
|
||||||
result = self.tool(document, "When is the coffee break?")
|
|
||||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
|
||||||
document = dataset[0]["image"]
|
|
||||||
|
|
||||||
self.tool(document=document, question="When is the coffee break?")
|
|
|
@ -87,7 +87,11 @@ class ExampleDifferenceTests(unittest.TestCase):
|
||||||
examples_path = Path("examples").resolve()
|
examples_path = Path("examples").resolve()
|
||||||
|
|
||||||
def one_complete_example(
|
def one_complete_example(
|
||||||
self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None
|
self,
|
||||||
|
complete_file_name: str,
|
||||||
|
parser_only: bool,
|
||||||
|
secondary_filename: str = None,
|
||||||
|
special_strings: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Tests a single `complete` example against all of the implemented `by_feature` scripts
|
Tests a single `complete` example against all of the implemented `by_feature` scripts
|
||||||
|
@ -112,10 +116,15 @@ class ExampleDifferenceTests(unittest.TestCase):
|
||||||
with self.subTest(
|
with self.subTest(
|
||||||
tested_script=complete_file_name,
|
tested_script=complete_file_name,
|
||||||
feature_script=item,
|
feature_script=item,
|
||||||
tested_section="main()" if parser_only else "training_function()",
|
tested_section="main()"
|
||||||
|
if parser_only
|
||||||
|
else "training_function()",
|
||||||
):
|
):
|
||||||
diff = compare_against_test(
|
diff = compare_against_test(
|
||||||
self.examples_path / complete_file_name, item_path, parser_only, secondary_filename
|
self.examples_path / complete_file_name,
|
||||||
|
item_path,
|
||||||
|
parser_only,
|
||||||
|
secondary_filename,
|
||||||
)
|
)
|
||||||
diff = "\n".join(diff)
|
diff = "\n".join(diff)
|
||||||
if special_strings is not None:
|
if special_strings is not None:
|
||||||
|
@ -140,8 +149,12 @@ class ExampleDifferenceTests(unittest.TestCase):
|
||||||
" " * 12,
|
" " * 12,
|
||||||
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
|
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
|
||||||
]
|
]
|
||||||
self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
|
self.one_complete_example(
|
||||||
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)
|
"complete_cv_example.py", True, cv_path, special_strings
|
||||||
|
)
|
||||||
|
self.one_complete_example(
|
||||||
|
"complete_cv_example.py", False, cv_path, special_strings
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
|
@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})
|
||||||
|
|
|
@ -47,9 +47,9 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def create_inputs(self):
|
def create_inputs(self):
|
||||||
inputs_text = {"answer": "Text input"}
|
inputs_text = {"answer": "Text input"}
|
||||||
inputs_image = {
|
inputs_image = {
|
||||||
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
|
"answer": Image.open(
|
||||||
(512, 512)
|
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||||
)
|
).resize((512, 512))
|
||||||
}
|
}
|
||||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||||
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from transformers import is_vision_available, load_tool
|
|
||||||
from transformers.testing_utils import get_tests_dir
|
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = load_tool("image_question_answering")
|
|
||||||
self.tool.setup()
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
|
||||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
|
||||||
self.assertEqual(result, "2")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
|
||||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
|
||||||
self.assertEqual(result, "2")
|
|
|
@ -129,7 +129,9 @@ final_answer('This is the final answer.')
|
||||||
|
|
||||||
def test_streaming_agent_image_output(self):
|
def test_streaming_agent_image_output(self):
|
||||||
def dummy_llm_engine(prompt, **kwargs):
|
def dummy_llm_engine(prompt, **kwargs):
|
||||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
return (
|
||||||
|
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||||
|
)
|
||||||
|
|
||||||
agent = ReactJsonAgent(
|
agent = ReactJsonAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
@ -138,7 +140,14 @@ final_answer('This is the final answer.')
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use stream_to_gradio to capture the output
|
# Use stream_to_gradio to capture the output
|
||||||
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
|
outputs = list(
|
||||||
|
stream_to_gradio(
|
||||||
|
agent,
|
||||||
|
task="Test task",
|
||||||
|
image=AgentImage(value="path.png"),
|
||||||
|
test_mode=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(outputs), 2)
|
self.assertEqual(len(outputs), 2)
|
||||||
final_message = outputs[-1]
|
final_message = outputs[-1]
|
||||||
|
|
|
@ -21,7 +21,10 @@ import pytest
|
||||||
from transformers import load_tool
|
from transformers import load_tool
|
||||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
|
from transformers.agents.python_interpreter import (
|
||||||
|
InterpreterError,
|
||||||
|
evaluate_python_code,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
from .test_tools_common import ToolTesterMixin
|
||||||
|
|
||||||
|
@ -57,7 +60,12 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||||
input_type = expected_input["type"]
|
input_type = expected_input["type"]
|
||||||
if isinstance(input_type, list):
|
if isinstance(input_type, list):
|
||||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
_inputs.append(
|
||||||
|
[
|
||||||
|
AGENT_TYPE_MAPPING[_input_type](_input)
|
||||||
|
for _input_type in input_type
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
|
@ -91,7 +99,10 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
code = "print = '3'"
|
code = "print = '3'"
|
||||||
with pytest.raises(InterpreterError) as e:
|
with pytest.raises(InterpreterError) as e:
|
||||||
evaluate_python_code(code, {"print": print}, state={})
|
evaluate_python_code(code, {"print": print}, state={})
|
||||||
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
|
assert (
|
||||||
|
"Cannot assign to name 'print': doing this would erase the existing tool!"
|
||||||
|
in str(e)
|
||||||
|
)
|
||||||
|
|
||||||
def test_evaluate_call(self):
|
def test_evaluate_call(self):
|
||||||
code = "y = add_two(x)"
|
code = "y = add_two(x)"
|
||||||
|
@ -117,7 +128,9 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
self.assertDictEqual(
|
||||||
|
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
||||||
|
)
|
||||||
|
|
||||||
def test_evaluate_expression(self):
|
def test_evaluate_expression(self):
|
||||||
code = "x = 3\ny = 5"
|
code = "x = 3\ny = 5"
|
||||||
|
@ -133,7 +146,9 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
result = evaluate_python_code(code, {}, state=state)
|
result = evaluate_python_code(code, {}, state=state)
|
||||||
# evaluate returns the value of the last assignment.
|
# evaluate returns the value of the last assignment.
|
||||||
assert result == "This is x: 3."
|
assert result == "This is x: 3."
|
||||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
self.assertDictEqual(
|
||||||
|
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
|
||||||
|
)
|
||||||
|
|
||||||
def test_evaluate_if(self):
|
def test_evaluate_if(self):
|
||||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||||
|
@ -174,11 +189,15 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
self.assertDictEqual(
|
||||||
|
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
|
||||||
|
)
|
||||||
|
|
||||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
evaluate_python_code(
|
||||||
|
code, {"min": min, "print": print, "round": round}, state=state
|
||||||
|
)
|
||||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||||
|
|
||||||
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
||||||
|
@ -292,7 +311,16 @@ print(check_digits)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(
|
evaluate_python_code(
|
||||||
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
|
code,
|
||||||
|
{
|
||||||
|
"range": range,
|
||||||
|
"print": print,
|
||||||
|
"sum": sum,
|
||||||
|
"enumerate": enumerate,
|
||||||
|
"int": int,
|
||||||
|
"str": str,
|
||||||
|
},
|
||||||
|
state,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_listcomp(self):
|
def test_listcomp(self):
|
||||||
|
@ -325,7 +353,9 @@ print(check_digits)
|
||||||
assert result == {0: 0, 1: 1, 2: 4}
|
assert result == {0: 0, 1: 1, 2: 4}
|
||||||
|
|
||||||
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||||
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
||||||
|
)
|
||||||
assert result == {102: "b"}
|
assert result == {102: "b"}
|
||||||
|
|
||||||
code = """
|
code = """
|
||||||
|
@ -373,7 +403,9 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
result = evaluate_python_code(
|
||||||
|
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
||||||
|
)
|
||||||
assert result == "Brooklyn"
|
assert result == "Brooklyn"
|
||||||
|
|
||||||
code = """if d > e and a < b:
|
code = """if d > e and a < b:
|
||||||
|
@ -384,7 +416,9 @@ else:
|
||||||
best_city = "Manhattan"
|
best_city = "Manhattan"
|
||||||
best_city
|
best_city
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
result = evaluate_python_code(
|
||||||
|
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
|
||||||
|
)
|
||||||
assert result == "Sacramento"
|
assert result == "Sacramento"
|
||||||
|
|
||||||
def test_if_conditions(self):
|
def test_if_conditions(self):
|
||||||
|
@ -400,7 +434,9 @@ if char.isalpha():
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == 2.0
|
assert result == 2.0
|
||||||
|
|
||||||
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
code = (
|
||||||
|
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||||
|
)
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "lose"
|
assert result == "lose"
|
||||||
|
|
||||||
|
@ -434,10 +470,14 @@ if char.isalpha():
|
||||||
|
|
||||||
# Test submodules are handled properly, thus not raising error
|
# Test submodules are handled properly, thus not raising error
|
||||||
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
result = evaluate_python_code(
|
||||||
|
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
||||||
|
)
|
||||||
|
|
||||||
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
result = evaluate_python_code(
|
||||||
|
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
|
||||||
|
)
|
||||||
|
|
||||||
def test_additional_imports(self):
|
def test_additional_imports(self):
|
||||||
code = "import numpy as np"
|
code = "import numpy as np"
|
||||||
|
@ -554,7 +594,11 @@ cat_sound = cat.sound()
|
||||||
cat_str = str(cat)
|
cat_str = str(cat)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
evaluate_python_code(
|
||||||
|
code,
|
||||||
|
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
# Assert results
|
# Assert results
|
||||||
assert state["dog1_sound"] == "The dog barks."
|
assert state["dog1_sound"] == "The dog barks."
|
||||||
|
@ -588,7 +632,11 @@ except ValueError as e:
|
||||||
exception_message = str(e)
|
exception_message = str(e)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
evaluate_python_code(
|
||||||
|
code,
|
||||||
|
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
assert state["exception_message"] == "An error occurred"
|
assert state["exception_message"] == "An error occurred"
|
||||||
|
|
||||||
def test_print(self):
|
def test_print(self):
|
||||||
|
@ -600,7 +648,9 @@ except ValueError as e:
|
||||||
def test_types_as_objects(self):
|
def test_types_as_objects(self):
|
||||||
code = "type_a = float(2); type_b = str; type_c = int"
|
code = "type_a = float(2); type_b = str; type_c = int"
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
result = evaluate_python_code(
|
||||||
|
code, {"float": float, "str": str, "int": int}, state=state
|
||||||
|
)
|
||||||
assert result is int
|
assert result is int
|
||||||
|
|
||||||
def test_tuple_id(self):
|
def test_tuple_id(self):
|
||||||
|
@ -731,7 +781,9 @@ def add_one(n, shift):
|
||||||
add_one(1, 1)
|
add_one(1, 1)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
||||||
|
)
|
||||||
assert result == 2
|
assert result == 2
|
||||||
|
|
||||||
# test returning None
|
# test returning None
|
||||||
|
@ -742,7 +794,9 @@ def returns_none(a):
|
||||||
returns_none(1)
|
returns_none(1)
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
|
||||||
|
)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_nested_for_loop(self):
|
def test_nested_for_loop(self):
|
||||||
|
@ -758,7 +812,9 @@ out = [i for sublist in all_res for i in sublist]
|
||||||
out[:10]
|
out[:10]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print, "range": range}, state=state
|
||||||
|
)
|
||||||
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||||
|
|
||||||
def test_pandas(self):
|
def test_pandas(self):
|
||||||
|
@ -773,7 +829,9 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
|
||||||
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
|
result = evaluate_python_code(
|
||||||
|
code, {}, state=state, authorized_imports=["pandas"]
|
||||||
|
)
|
||||||
assert np.array_equal(result, [-1, 5])
|
assert np.array_equal(result, [-1, 5])
|
||||||
|
|
||||||
code = """
|
code = """
|
||||||
|
@ -785,7 +843,9 @@ print("HH0")
|
||||||
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
||||||
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print}, state={}, authorized_imports=["pandas"]
|
||||||
|
)
|
||||||
assert np.array_equal(result.values[0], [104, 1])
|
assert np.array_equal(result.values[0], [104, 1])
|
||||||
|
|
||||||
code = """import pandas as pd
|
code = """import pandas as pd
|
||||||
|
@ -818,7 +878,9 @@ coords_barcelona = (41.3869, 2.1660)
|
||||||
|
|
||||||
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
||||||
"""
|
"""
|
||||||
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
|
result = evaluate_python_code(
|
||||||
|
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
|
||||||
|
)
|
||||||
assert round(result, 1) == 622395.4
|
assert round(result, 1) == 622395.4
|
||||||
|
|
||||||
def test_for(self):
|
def test_for(self):
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from transformers import load_tool
|
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = load_tool("speech_to_text")
|
|
||||||
self.tool.setup()
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
result = self.tool(np.ones(3000))
|
|
||||||
self.assertEqual(result, " Thank you.")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
result = self.tool(audio=np.ones(3000))
|
|
||||||
self.assertEqual(result, " Thank you.")
|
|
|
@ -1,50 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from transformers import load_tool
|
|
||||||
from transformers.utils import is_torch_available
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch
|
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
|
||||||
def setUp(self):
|
|
||||||
self.tool = load_tool("text_to_speech")
|
|
||||||
self.tool.setup()
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
|
||||||
# SpeechT5 isn't deterministic
|
|
||||||
torch.manual_seed(0)
|
|
||||||
result = self.tool("hey")
|
|
||||||
resulting_tensor = result.to_raw()
|
|
||||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
|
||||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
|
||||||
# SpeechT5 isn't deterministic
|
|
||||||
torch.manual_seed(0)
|
|
||||||
result = self.tool("hey")
|
|
||||||
resulting_tensor = result.to_raw()
|
|
||||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
|
||||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
|
|
@ -20,7 +20,12 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
from transformers.agents.agent_types import (
|
||||||
|
AGENT_TYPE_MAPPING,
|
||||||
|
AgentAudio,
|
||||||
|
AgentImage,
|
||||||
|
AgentText,
|
||||||
|
)
|
||||||
from transformers.agents.tools import Tool, tool
|
from transformers.agents.tools import Tool, tool
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,9 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
self.assertEqual(result, "- Hé, comment ça va?")
|
self.assertEqual(result, "- Hé, comment ça va?")
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
def test_exact_match_kwarg(self):
|
||||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
result = self.tool(
|
||||||
|
text="Hey, what's up?", src_lang="English", tgt_lang="French"
|
||||||
|
)
|
||||||
self.assertEqual(result, "- Hé, comment ça va?")
|
self.assertEqual(result, "- Hé, comment ça va?")
|
||||||
|
|
||||||
def test_call(self):
|
def test_call(self):
|
||||||
|
|
Loading…
Reference in New Issue