Ruff formatting

Aymeric 2024-12-11 16:16:18 +01:00
parent 851e177e71
commit 67deb6808f
28 changed files with 1153 additions and 3460 deletions
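
Nearly every hunk below follows one mechanical pattern: Ruff's formatter (Black-compatible, assuming the default 88-character line limit) explodes any call, literal, or condition that overruns the limit into one item per line with a trailing comma. A minimal sketch of the transformation — the function name here is a stand-in, not from the diff:

```python
def do_something(a, b, c, d):
    # Stub so the sketch runs; any long call in the hunks below gets the same treatment.
    return (a, b, c, d)

# Before: a single call that overruns the 88-character default line limit.
result = do_something("argument_one", "argument_two", "argument_three", "argument_four")

# After: the formatter splits it one argument per line and adds a trailing comma.
result = do_something(
    "argument_one",
    "argument_two",
    "argument_three",
    "argument_four",
)
```
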

View File

@ -24,10 +24,25 @@ from transformers.utils import (
_import_structure = {
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"],
"agents": [
"Agent",
"CodeAgent",
"ManagedAgent",
"ReactAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",
],
"llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
"tools": [
"PipelineTool",
"Tool",
"ToolCollection",
"launch_gradio_demo",
"load_tool",
"tool",
],
}
try:
@ -45,10 +60,25 @@ else:
_import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox
from .agents import (
Agent,
CodeAgent,
ManagedAgent,
ReactAgent,
CodeAgent,
JsonAgent,
Toolbox,
)
from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
from .tools import (
PipelineTool,
Tool,
ToolCollection,
launch_gradio_demo,
load_tool,
tool,
)
try:
if not is_torch_available():
@ -66,4 +96,6 @@ if TYPE_CHECKING:
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
sys.modules[__name__] = _LazyModule(
__name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
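
The hunk above only reflows the lazy-import boilerplate, but the pattern is worth spelling out: `_import_structure` maps submodules to exported names, and the package's entry in `sys.modules` is swapped for a proxy that imports a submodule only when one of its attributes is first accessed. A simplified stand-in — transformers' real `_LazyModule` does more bookkeeping than this:

```python
import importlib
import sys
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for transformers' _LazyModule: resolves attributes
    by importing the owning submodule on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name back to the submodule that defines it.
        self._name_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        submodule = self._name_to_module[attr]
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        return getattr(module, attr)


# Usage mirrors the hunk: the package module object is replaced by the proxy.
# sys.modules[__name__] = LazyModule(__name__, _import_structure)
```
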

View File

@ -19,7 +19,11 @@ import uuid
import numpy as np
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
from transformers.utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
)
import logging
@ -108,7 +112,9 @@ class AgentImage(AgentType, ImageType):
elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value)
else:
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
raise TypeError(
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
)
def _ipython_display_(self, include=None, exclude=None):
"""
@ -159,7 +165,7 @@ class AgentImage(AgentType, ImageType):
return self._path
def save(self, output_bytes, format : str = None, **params):
def save(self, output_bytes, format: str = None, **params):
"""
Saves the image to a file.
Args:
@ -243,7 +249,9 @@ if is_torch_available():
def handle_agent_inputs(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
}
return args, kwargs
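
`handle_agent_inputs` unwraps any `AgentType` wrapper (such as `AgentImage` above) back to its raw payload before a tool runs. A self-contained sketch of the same unwrap logic, with a minimal stand-in for the wrapper class:

```python
class AgentType:
    """Minimal stand-in: wraps a raw value and can hand it back."""

    def __init__(self, value):
        self._value = value

    def to_raw(self):
        return self._value


def handle_agent_inputs(*args, **kwargs):
    # Same shape as the hunk above: unwrap AgentType instances, pass the rest through.
    args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
    kwargs = {
        k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
    }
    return args, kwargs


assert handle_agent_inputs(AgentType(3), 4, x=AgentType("hi")) == ([3, 4], {"x": "hi"})
```
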

View File

@ -15,12 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass
from rich.syntax import Syntax
from langfuse.decorators import langfuse_context, observe
from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
@ -51,6 +49,7 @@ from .tools import (
HUGGINGFACE_DEFAULT_TOOLS = {}
class AgentError(Exception):
"""Base class for other agent-related exceptions"""
@ -60,7 +59,6 @@ class AgentError(Exception):
console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError):
"""Exception raised for errors in parsing in the agent"""
@ -84,12 +82,14 @@ class AgentGenerationError(AgentError):
pass
class AgentStep:
pass
@dataclass
class ActionStep(AgentStep):
tool_call: str | None = None
tool_call: Dict[str, str] | None = None
start_time: float | None = None
step_end_time: float | None = None
iteration: int | None = None
@ -97,32 +97,43 @@ class ActionStep(AgentStep):
error: AgentError | None = None
step_duration: float | None = None
llm_output: str | None = None
observation: str | None = None
agent_memory: List[Dict[str, str]] | None = None
rationale: str | None = None
@dataclass
class PlanningStep(AgentStep):
plan: str
facts: str
@dataclass
class TaskStep(AgentStep):
task: str
@dataclass
class SystemPromptStep(AgentStep):
system_prompt: str
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
def format_prompt_with_tools(
toolbox: Toolbox, prompt_template: str, tool_description_template: str
) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "{{tool_names}}" in prompt:
prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]))
prompt = prompt.replace(
"{{tool_names}}",
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
)
return prompt
def show_agents_descriptions(managed_agents: list):
def show_agents_descriptions(managed_agents: Dict):
managed_agents_descriptions = """
You can also give requests to team members.
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaining your request.
@ -133,16 +144,24 @@ Here is a list of the team members that you can call:"""
return managed_agents_descriptions
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
def format_prompt_with_managed_agents_descriptions(
prompt_template, managed_agents=None
) -> str:
if managed_agents is not None:
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
return prompt_template.replace(
"<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
)
else:
return prompt_template.replace("<<managed_agents_descriptions>>", "")
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
def format_prompt_with_imports(
prompt_template: str, authorized_imports: List[str]
) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
raise AgentError(
"Tag '<<authorized_imports>>' should be provided in the prompt."
)
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
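
The `format_prompt_with_*` helpers above are plain placeholder substitution: `{{tool_descriptions}}`, `{{tool_names}}`, `<<managed_agents_descriptions>>`, and `<<authorized_imports>>` are replaced with rendered strings, and a missing `<<authorized_imports>>` tag raises. A minimal sketch of the tool-description substitution, with made-up tool names:

```python
def format_prompt_with_tools(tools: dict, prompt_template: str) -> str:
    # tools maps name -> description; mirrors the replace() chain in the hunk above.
    descriptions = "\n".join(f"- {name}: {desc}" for name, desc in tools.items())
    prompt = prompt_template.replace("{{tool_descriptions}}", descriptions)
    if "{{tool_names}}" in prompt:
        prompt = prompt.replace(
            "{{tool_names}}", ", ".join(f"'{name}'" for name in tools)
        )
    return prompt


template = "Tools:\n{{tool_descriptions}}\nCall one of: {{tool_names}}"
print(format_prompt_with_tools({"web_search": "Search the web."}, template))
```
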
@ -150,7 +169,7 @@ class BaseAgent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = None,
llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
additional_args: Dict = {},
@ -159,10 +178,12 @@ class BaseAgent:
add_base_tools: bool = False,
verbose: bool = False,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
managed_agents: Optional[Dict] = None,
step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
if tool_parser is None:
@ -171,14 +192,16 @@ class BaseAgent:
self.llm_engine = llm_engine
self.system_prompt_template = system_prompt
self.tool_description_template = (
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
tool_description_template
if tool_description_template
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
)
self.additional_args = additional_args
self.max_iterations = max_iterations
self.tool_parser = tool_parser
self.grammar = grammar
self.managed_agents = None
self.managed_agents = {}
if managed_agents is not None:
self.managed_agents = {agent.name: agent for agent in managed_agents}
@ -186,9 +209,13 @@ class BaseAgent:
self._toolbox = tools
if add_base_tools:
if not is_torch_available():
raise ImportError("Using the base tools requires torch to be installed.")
raise ImportError(
"Using the base tools requires torch to be installed."
)
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent))
self._toolbox.add_base_tools(
add_python_interpreter=(self.__class__ == JsonAgent)
)
else:
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
self._toolbox.add_tool(FinalAnswerTool())
@ -196,7 +223,9 @@ class BaseAgent:
self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
self.prompt_messages = None
self.logs = []
self.task = None
@ -222,15 +251,20 @@ class BaseAgent:
self.system_prompt_template,
self.tool_description_template,
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
self.system_prompt,
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
)
return self.system_prompt
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
def write_inner_memory_from_logs(
self, summary_mode: Optional[bool] = False
) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM.
@ -253,7 +287,10 @@ class BaseAgent:
memory.append(thought_message)
if not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log.plan.strip()}
thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[PLAN]:\n" + step_log.plan.strip(),
}
memory.append(thought_message)
elif isinstance(step_log, TaskStep):
@ -265,13 +302,17 @@ class BaseAgent:
elif isinstance(step_log, ActionStep):
if step_log.llm_output is not None and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log.llm_output.strip()}
thought_message = {
"role": MessageRole.ASSISTANT,
"content": step_log.llm_output.strip(),
}
memory.append(thought_message)
if step_log.tool_call is not None and summary_mode:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log.tool_call).strip(),
"content": f"[STEP {i} TOOL CALL]: "
+ str(step_log.tool_call).strip(),
}
memory.append(tool_call_message)
@ -284,15 +325,21 @@ class BaseAgent:
)
elif step_log.observation is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": message_content,
}
memory.append(tool_response_message)
return memory
def get_succinct_logs(self):
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
return [
{key: value for key, value in log.items() if key != "agent_memory"}
for log in self.logs
]
def extract_action(self, llm_output: str, split_token: str) -> str:
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
"""
Parse action from the LLM output
@ -312,54 +359,6 @@ class BaseAgent:
)
return rationale.strip(), action.strip()
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
try:
if isinstance(arguments, str):
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
tool_description = get_tool_description_with_args(available_tools[tool_name])
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def run(self, **kwargs):
"""To be implemented in the child class"""
raise NotImplementedError
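
(The removed block above is `execute_tool_call`, which is reintroduced under `ReactAgent` further down.) `extract_action`, retyped in this hunk to return `Tuple[str, str]`, splits the raw LLM output on a marker such as `"Action:"` or `"Code:"` into the free-text rationale and the actionable part. Its body is not shown in the diff; a plausible sketch, assuming the split keeps the last occurrence of the token:

```python
from typing import Tuple


def extract_action(llm_output: str, split_token: str) -> Tuple[str, str]:
    # Split on the last occurrence so a token quoted inside the rationale
    # does not cut the action short; raise if the token never appears.
    try:
        rationale, action = llm_output.rsplit(split_token, 1)
    except ValueError:
        raise ValueError(f"No '{split_token}' token found in the LLM output.")
    return rationale.strip(), action.strip()


rationale, action = extract_action("I should search.\nAction: web_search", "Action:")
assert (rationale, action) == ("I should search.", "web_search")
```
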
@ -382,8 +381,6 @@ class ReactAgent(BaseAgent):
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None:
@ -423,8 +420,67 @@ class ReactAgent(BaseAgent):
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
@observe
def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs):
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
try:
if isinstance(arguments, str):
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
tool_description = get_tool_description_with_args(
available_tools[tool_name]
)
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep):
"""To be implemented in children classes"""
pass
def run(
self,
task: str,
stream: bool = False,
reset: bool = True,
oneshot: bool = False,
**kwargs,
):
"""
Runs the agent for the given task.
@ -441,10 +497,11 @@ class ReactAgent(BaseAgent):
agent.run("What is the result of 2 power 3.7384?")
```
"""
print("LANGFUSE REF:", langfuse_context.get_current_trace_url())
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.task += (
f"\nYou have been provided with these initial arguments: {str(kwargs)}."
)
self.state = kwargs.copy()
self.initialize_system_prompt()
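
One subtle behavior in `execute_tool_call` above is worth pulling out: string arguments that name a key in `self.state` are replaced by the stored value before the tool is invoked, so a tool can receive, say, an image produced by an earlier step. A minimal sketch of just that substitution:

```python
def resolve_arguments(arguments: dict, state: dict) -> dict:
    # Mirrors the loop in the hunk above: a string argument whose value is a
    # key in the agent's state gets swapped for the stored object.
    return {
        key: (state[value] if isinstance(value, str) and value in state else value)
        for key, value in arguments.items()
    }


state = {"image_1": "<PIL.Image at 0x...>"}
print(resolve_arguments({"image": "image_1", "size": 256}, state))
# {'image': '<PIL.Image at 0x...>', 'size': 256}
```
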
@ -460,7 +517,7 @@ class ReactAgent(BaseAgent):
else:
self.logs.append(system_prompt_step)
console.rule("[bold]New task", characters='=')
console.rule("[bold]New task", characters="=")
console.print(self.task)
self.logs.append(TaskStep(task=task))
@ -489,8 +546,13 @@ class ReactAgent(BaseAgent):
step_start_time = time.time()
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
if (
self.planning_interval is not None
and iteration % self.planning_interval == 0
):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None:
@ -530,8 +592,13 @@ class ReactAgent(BaseAgent):
step_start_time = time.time()
step_log = ActionStep(iteration=iteration, start_time=step_start_time)
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
if (
self.planning_interval is not None
and iteration % self.planning_interval == 0
):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step")
self.step(step_log)
if step_log.final_answer is not None:
@ -559,7 +626,7 @@ class ReactAgent(BaseAgent):
return final_answer
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
def planning_step(self, task, is_first_step: bool, iteration: int):
"""
Used periodically by the agent to plan the next steps to reach the objective.
@ -569,7 +636,10 @@ class ReactAgent(BaseAgent):
iteration (`int`): The number of the current step, used as an indication for the LLM.
"""
if is_first_step:
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_FACTS,
}
message_prompt_task = {
"role": MessageRole.USER,
"content": f"""Here is the task:
@ -589,15 +659,20 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
tool_descriptions=self._toolbox.show_tool_descriptions(
self.tool_description_template
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
),
answer_facts=answer_facts,
),
}
answer_plan = self.llm_engine(
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
[message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=["<end_plan>"],
)
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
@ -608,10 +683,12 @@ Now begin!""",
```
{answer_facts}
```""".strip()
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Initial plan")
console.print(final_plan_redaction)
else: # update plan
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
) # This will not log the plan but will log facts
@ -625,7 +702,9 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE,
}
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
facts_update = self.llm_engine(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
)
# Redact updated plan
plan_update_message = {
@ -636,25 +715,34 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
tool_descriptions=self._toolbox.show_tool_descriptions(
self.tool_description_template
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
),
facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration),
),
}
plan_update = self.llm_engine(
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
[plan_update_message] + agent_memory + [plan_update_message_user],
stop_sequences=["<end_plan>"],
)
# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
task=task, plan_update=plan_update
)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Updated plan")
console.print(final_plan_redaction)
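
The planning machinery above runs on a fixed cadence: when `planning_interval` is set, iteration 0 produces the initial facts and plan, and every `planning_interval`-th iteration afterwards redacts an update. The trigger condition from `run()` reduces to:

```python
def should_plan(iteration: int, planning_interval: int | None) -> bool:
    # Same condition as in run(): plan on step 0, then every planning_interval steps.
    return planning_interval is not None and iteration % planning_interval == 0


assert [should_plan(i, 3) for i in range(7)] == [True, False, False, True, False, False, True]
assert not any(should_plan(i, None) for i in range(7))
```
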
@ -705,14 +793,20 @@ class JsonAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule("[italic]Calling LLM engine with this last message:", align="left")
console.rule(
"[italic]Calling LLM engine with this last message:", align="left"
)
console.print(self.prompt_messages[-1])
console.rule()
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.llm_engine(
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
self.prompt_messages,
stop_sequences=["<end_action>", "Observation:"],
**additional_args,
)
log_entry.llm_output = llm_output
except Exception as e:
@ -721,9 +815,11 @@ class JsonAgent(ReactAgent):
if self.verbose:
console.rule("[italic]Output message of the LLM:")
console.print(llm_output)
# Parse
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
rationale, action = self.extract_action(
llm_output=llm_output, split_token="Action:"
)
try:
tool_name, arguments = self.tool_parser(action)
@ -807,12 +903,18 @@ class CodeAgent(ReactAgent):
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
)
self.system_prompt = self.system_prompt.replace(
"<<authorized_imports>>", str(self.authorized_imports)
)
self.custom_tools = {}
def step(self, log_entry: Dict[str, Any]):
def step(self, log_entry: ActionStep):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
@ -825,14 +927,20 @@ class CodeAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy()
if self.verbose:
console.rule("[italic]Calling LLM engine with these last messages:", align="left")
console.rule(
"[italic]Calling LLM engine with these last messages:", align="left"
)
console.print(self.prompt_messages[-2:])
console.rule()
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.llm_engine(
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args
self.prompt_messages,
stop_sequences=["<end_action>", "Observation:"],
**additional_args,
)
log_entry.llm_output = llm_output
except Exception as e:
@ -840,13 +948,19 @@ class CodeAgent(ReactAgent):
if self.verbose:
console.rule("[italic]Output message of the LLM:")
console.print(Syntax(llm_output, lexer='markdown', background_color='default'))
console.print(
Syntax(llm_output, lexer="markdown", background_color="default")
)
# Parse
try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
rationale, raw_code_action = self.extract_action(
llm_output=llm_output, split_token="Code:"
)
except Exception as e:
console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
console.print(
f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
)
rationale, raw_code_action = llm_output, llm_output
try:
@ -856,14 +970,17 @@ class CodeAgent(ReactAgent):
raise AgentParsingError(error_msg)
log_entry.rationale = rationale
log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action}
log_entry.tool_call = {
"tool_name": "code interpreter",
"tool_arguments": code_action,
}
# Execute
if self.verbose:
console.rule("[italic]Agent thoughts")
console.print(rationale)
console.rule("[bold]Agent is executing the code below:", align="left")
console.print(Syntax(code_action, lexer='python', background_color='default'))
console.print(Syntax(code_action, lexer="python", background_color="default"))
console.rule("", align="left")
try:
@ -886,7 +1003,9 @@ class CodeAgent(ReactAgent):
if result is not None:
console.rule("Last output from code snippet:", align="left")
console.print(str(result))
observation += "Last output from code snippet:\n" + truncate_content(str(result))
observation += "Last output from code snippet:\n" + truncate_content(
str(result)
)
log_entry.observation = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
@ -902,7 +1021,14 @@ class CodeAgent(ReactAgent):
class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
def __init__(
self,
agent,
name,
description,
additional_prompting=None,
provide_run_summary=False,
):
self.agent = agent
self.name = name
self.description = description
@ -925,18 +1051,22 @@ Your final_answer WILL HAVE to contain these parts:
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
<<additional_prompting>>"""
{{additional_prompting}}"""
if self.additional_prompting:
full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
full_task = full_task.replace(
"\n{{additional_prompting}}", self.additional_prompting
).strip()
else:
full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
return full_task
def __call__(self, request, **kwargs):
full_task = self.write_full_task(request)
output = self.agent.run(full_task, **kwargs)
if self.provide_run_summary:
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
answer = (
f"Here is the final answer from your managed agent '{self.name}':\n"
)
answer += str(output)
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
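
Note the placeholder rename in this hunk: `<<additional_prompting>>` becomes `{{additional_prompting}}` in `ManagedAgent.write_full_task`, and the `replace()` calls are updated to match. The substitution itself is simple; a sketch with the template abbreviated (the real one also spells out the required `final_answer` structure):

```python
def write_full_task(task: str, name: str, additional_prompting: str | None) -> str:
    # The {{additional_prompting}} slot is either filled verbatim or stripped,
    # exactly as in the hunk above.
    full_task = (
        f"You're a helpful agent named '{name}'.\n"
        f"You have been submitted this task by your manager.\n"
        f"---\nTask:\n{task}\n---\n"
        "{{additional_prompting}}"
    )
    if additional_prompting:
        return full_task.replace("\n{{additional_prompting}}", additional_prompting).strip()
    return full_task.replace("\n{{additional_prompting}}", "").strip()


print(write_full_task("Summarize the report.", "writer", None))
```
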

View File

@ -105,7 +105,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
resolved_config_file = hf_hub_download(
repo_id, TOOL_CONFIG_FILE, repo_type="space"
)
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
task = repo_id.split("/")[-1]
@ -131,7 +133,9 @@ class PythonInterpreterTool(Tool):
if authorized_imports is None:
self.authorized_imports = list(set(LIST_SAFE_MODULES))
else:
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(authorized_imports)
)
self.inputs = {
"code": {
"type": "string",
@ -145,7 +149,11 @@ class PythonInterpreterTool(Tool):
def forward(self, code):
output = str(
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
evaluate_python_code(
code,
static_tools=BASE_PYTHON_TOOLS,
authorized_imports=self.authorized_imports,
)
)
return output
@ -153,16 +161,21 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
inputs = {
"answer": {"type": "any", "description": "The final answer to the problem"}
}
output_type = "any"
def forward(self, answer):
return answer
class UserInputTool(Tool):
name = "user_input"
description = "Asks for user's input on a specific question"
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
inputs = {
"question": {"type": "string", "description": "The question to ask the user"}
}
output_type = "string"
def forward(self, question):
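
`FinalAnswerTool` and `UserInputTool` above show the whole surface of a tool: class attributes `name`, `description`, `inputs` (a dict of `{"type", "description"}` specs), and `output_type`, plus a `forward` method. A minimal sketch of a custom tool in the same shape, with the base class stubbed (the real `Tool` also validates these attributes and handles Hub serialization):

```python
class Tool:
    """Simplified stand-in for the real base class."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class AdderTool(Tool):
    name = "adder"
    description = "Adds two numbers and returns the sum."
    inputs = {
        "a": {"type": "number", "description": "First operand"},
        "b": {"type": "number", "description": "Second operand"},
    }
    output_type = "number"

    def forward(self, a, b):
        return a + b


assert AdderTool()(2, 3) == 5
```
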

View File

@ -18,6 +18,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
@ -33,7 +34,9 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
content=str(content),
)
if step_log.observation is not None:
yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observation}\n```")
yield gr.ChatMessage(
role="assistant", content=f"```\n{step_log.observation}\n```"
)
if step_log.error is not None:
yield gr.ChatMessage(
role="assistant",
@ -42,7 +45,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
)
def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs):
def stream_to_gradio(
agent,
task: str,
test_mode: bool = False,
reset_agent_memory: bool = False,
**kwargs,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
@ -52,7 +61,10 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
final_answer = step_log # Last log is the run's final_answer
if isinstance(final_answer, AgentText):
yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
yield gr.ChatMessage(
role="assistant",
content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
)
elif isinstance(final_answer, AgentImage):
yield gr.ChatMessage(
role="assistant",
@ -67,10 +79,11 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
yield gr.ChatMessage(role="assistant", content=str(final_answer))
class GradioUI():
class GradioUI:
"""A one-line interface to launch your agent in Gradio"""
def __init__(self, agent: BaseAgent):
self.agent = agent
self.agent = agent
def interact_with_agent(self, prompt, messages):
messages.append(gr.ChatMessage(role="user", content=prompt))
@ -83,10 +96,17 @@ class GradioUI():
def run(self):
with gr.Blocks() as demo:
stored_message = gr.State([])
chatbot = gr.Chatbot(label="Agent",
type="messages",
avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
chatbot = gr.Chatbot(
label="Agent",
type="messages",
avatar_images=(
None,
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
),
)
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input]
).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
demo.launch()
demo.launch()
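
`stream_to_gradio` relies on `agent.run(..., stream=True)` being a generator: each yielded step log becomes one or more `gr.ChatMessage` objects, and the last yielded item is the run's final answer. The streaming shape, sketched with the gradio dependency stubbed out so it runs standalone:

```python
from dataclasses import dataclass


@dataclass
class ChatMessage:
    # Stand-in for gr.ChatMessage so the sketch runs without gradio installed.
    role: str
    content: str


def fake_agent_run(task):
    # Stand-in for agent.run(task, stream=True): step logs first, final answer last.
    yield {"observation": "2 + 2 = 4"}
    yield "4"


def stream_to_messages(task):
    final_answer = None
    for step_log in fake_agent_run(task):
        if isinstance(step_log, dict) and step_log.get("observation") is not None:
            yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
        final_answer = step_log  # the last yielded item is the run's final answer
    yield ChatMessage(role="assistant", content=f"**Final answer:**\n{final_answer}")


for message in stream_to_messages("What is 2 + 2?"):
    print(message.role, "->", message.content)
```
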

View File

@ -39,7 +39,9 @@ class MessageRole(str, Enum):
return [r.value for r in cls]
def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
):
"""
Subsequent messages with the same role will be concatenated to a single message.
@ -54,12 +56,17 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
role = message["role"]
if role not in MessageRole.roles():
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
raise ValueError(
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
if role in role_conversions:
message["role"] = role_conversions[role]
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
if (
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else:
final_message_list.append(message)
@ -81,8 +88,12 @@ class HfEngine:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
except Exception as e:
logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
logger.warning(
f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
)
self.tokenizer = AutoTokenizer.from_pretrained(
"HuggingFaceTB/SmolLM2-1.7B-Instruct"
)
def get_token_counts(self):
return {
@ -91,12 +102,18 @@ class HfEngine:
}
def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
):
raise NotImplementedError
def __call__(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
) -> str:
"""Process the input messages and return the model's response.
@ -127,11 +144,15 @@ class HfEngine:
```
"""
if not isinstance(messages, List):
raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
raise ValueError(
"Messages should be a list of dictionaries with 'role' and 'content' keys."
)
if stop_sequences is None:
stop_sequences = []
response = self.generate(messages, stop_sequences, grammar)
self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
self.last_input_token_count = len(
self.tokenizer.apply_chat_template(messages, tokenize=True)
)
self.last_output_token_count = len(self.tokenizer.encode(response))
# Remove stop sequences from LLM output
@ -175,18 +196,28 @@ class HfApiEngine(HfEngine):
self.max_tokens = max_tokens
def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
# Send messages to the Hugging Face Inference API
if grammar is not None:
response = self.client.chat_completion(
messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
messages,
stop=stop_sequences,
max_tokens=self.max_tokens,
response_format=grammar,
)
else:
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
response = self.client.chat_completion(
messages, stop=stop_sequences, max_tokens=self.max_tokens
)
response = response.choices[0].message.content
return response
@ -207,7 +238,9 @@ class TransformersEngine(HfEngine):
max_length: int = 1500,
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
# Get LLM output
if stop_sequences is not None and len(stop_sequences) > 0:
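
`get_clean_message_list` enforces two invariants before a chat-completion call: every role must be a known `MessageRole`, and consecutive messages with the same role are merged (joined by `\n=======\n`), since most chat APIs reject back-to-back same-role turns. A minimal sketch of the merge, with the role validation elided:

```python
def get_clean_message_list(message_list, role_conversions=None):
    role_conversions = role_conversions or {}
    final_message_list = []
    for message in message_list:
        message = dict(message)  # don't mutate the caller's messages
        message["role"] = role_conversions.get(message["role"], message["role"])
        if final_message_list and message["role"] == final_message_list[-1]["role"]:
            # Concatenate subsequent same-role messages into a single turn.
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list


merged = get_clean_message_list(
    [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "Are you there?"},
    ]
)
assert len(merged) == 1 and "=======" in merged[0]["content"]
```
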

View File

@ -17,12 +17,14 @@
from .utils import console
class Monitor:
def __init__(self, tracked_llm_engine):
self.step_durations = []
self.tracked_llm_engine = tracked_llm_engine
if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
if (
getattr(self.tracked_llm_engine, "last_input_token_count", "Not found")
!= "Not found"
):
self.total_input_token_count = 0
self.total_output_token_count = 0
@ -33,103 +35,11 @@ class Monitor:
console.print(f"- Time taken: {step_duration:.2f} seconds")
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
self.total_input_token_count += (
self.tracked_llm_engine.last_input_token_count
)
self.total_output_token_count += (
self.tracked_llm_engine.last_output_token_count
)
console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}")
from typing import Optional, Union, List, Any
import httpx
import logging
import os
from langfuse.client import Langfuse, StatefulTraceClient, StatefulSpanClient, StateType
class BaseTracker:
def __init__(self):
pass
@classmethod
def call(cls, *args, **kwargs):
pass
class LangfuseTracker(BaseTracker):
log = logging.getLogger("langfuse")
def __init__(self, *, public_key: Optional[str] = None, secret_key: Optional[str] = None,
host: Optional[str] = None, debug: bool = False, stateful_client: Optional[
Union[StatefulTraceClient, StatefulSpanClient]
] = None, update_stateful_client: bool = False, version: Optional[str] = None,
session_id: Optional[str] = None, user_id: Optional[str] = None, trace_name: Optional[str] = None,
release: Optional[str] = None, metadata: Optional[Any] = None, tags: Optional[List[str]] = None,
threads: Optional[int] = None, flush_at: Optional[int] = None, flush_interval: Optional[int] = None,
max_retries: Optional[int] = None, timeout: Optional[int] = None, enabled: Optional[bool] = None,
httpx_client: Optional[httpx.Client] = None, sdk_integration: str = "default") -> None:
super().__init__()
self.version = version
self.session_id = session_id
self.user_id = user_id
self.trace_name = trace_name
self.release = release
self.metadata = metadata
self.tags = tags
self.root_span = None
self.update_stateful_client = update_stateful_client
self.langfuse = None
prio_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY")
prio_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY")
prio_host = host or os.environ.get(
"LANGFUSE_HOST", "https://cloud.langfuse.com"
)
if stateful_client and isinstance(stateful_client, StatefulTraceClient):
self.trace = stateful_client
self._task_manager = stateful_client.task_manager
return
elif stateful_client and isinstance(stateful_client, StatefulSpanClient):
self.root_span = stateful_client
self.trace = StatefulTraceClient(
stateful_client.client,
stateful_client.trace_id,
StateType.TRACE,
stateful_client.trace_id,
stateful_client.task_manager,
)
self._task_manager = stateful_client.task_manager
return
args = {
"public_key": prio_public_key,
"secret_key": prio_secret_key,
"host": prio_host,
"debug": debug,
}
if release is not None:
args["release"] = release
if threads is not None:
args["threads"] = threads
if flush_at is not None:
args["flush_at"] = flush_at
if flush_interval is not None:
args["flush_interval"] = flush_interval
if max_retries is not None:
args["max_retries"] = max_retries
if timeout is not None:
args["timeout"] = timeout
if enabled is not None:
args["enabled"] = enabled
if httpx_client is not None:
args["httpx_client"] = httpx_client
args["sdk_integration"] = sdk_integration
self.langfuse = Langfuse(**args)
self.trace: Optional[StatefulTraceClient] = None
self._task_manager = self.langfuse.task_manager
def call(self, i, o, name=None, **kwargs):
self.langfuse.trace(input=i, output=o, name=name, metadata=kwargs)
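
`Monitor` is wired in as a step callback: after each step it records the duration and, when the engine exposes `last_input_token_count` / `last_output_token_count`, accumulates running token totals. The aggregation logic, simplified (a dict stands in for the step log) and stripped of the console output:

```python
class Monitor:
    def __init__(self, tracked_llm_engine):
        self.step_durations = []
        self.tracked_llm_engine = tracked_llm_engine
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def update_metrics(self, step_log):
        # Called once per agent step with the finished step's log.
        self.step_durations.append(step_log["step_duration"])
        if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
            self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
            self.total_output_token_count += self.tracked_llm_engine.last_output_token_count


class FakeEngine:
    last_input_token_count = 120
    last_output_token_count = 40


monitor = Monitor(FakeEngine())
monitor.update_metrics({"step_duration": 1.5})
assert (monitor.total_input_token_count, monitor.total_output_token_count) == (120, 40)
```
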

View File

@ -42,7 +42,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
return prompt_or_repo_id
prompt_file = cached_file(
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
prompt_or_repo_id,
PROMPT_FILES[mode],
repo_type="dataset",
user_agent={"agent": agent_name},
)
with open(prompt_file, "r", encoding="utf-8") as f:
return f.read()

View File

@ -26,6 +26,7 @@ import pandas as pd
from .utils import truncate_content
class InterpreterError(ValueError):
"""
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
@ -38,7 +39,8 @@ class InterpreterError(ValueError):
ERRORS = {
name: getattr(builtins, name)
for name in dir(builtins)
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
if isinstance(getattr(builtins, name), type)
and issubclass(getattr(builtins, name), BaseException)
}
@ -92,7 +94,9 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools):
elif isinstance(expression.op, ast.Invert):
return ~operand
else:
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
raise InterpreterError(
f"Unary operation {expression.op.__class__.__name__} is not supported."
)
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
@ -102,7 +106,9 @@ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
new_state = state.copy()
for arg, value in zip(args, values):
new_state[arg] = value
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
return evaluate_ast(
lambda_expression.body, new_state, static_tools, custom_tools
)
return lambda_func
@ -120,7 +126,9 @@ def evaluate_while(while_loop, state, static_tools, custom_tools):
break
iterations += 1
if iterations > max_iterations:
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
raise InterpreterError(
f"Maximum number of {max_iterations} iterations in While loop exceeded"
)
return None
@ -128,7 +136,10 @@ def create_function(func_def, state, static_tools, custom_tools):
def new_func(*args, **kwargs):
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
default_values = [
evaluate_ast(d, state, static_tools, custom_tools)
for d in func_def.args.defaults
]
# Apply default values
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
@ -180,26 +191,39 @@ def create_class(class_name, class_bases, class_body):
def evaluate_function_def(func_def, state, static_tools, custom_tools):
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
custom_tools[func_def.name] = create_function(
func_def, state, static_tools, custom_tools
)
return custom_tools[func_def.name]
def evaluate_class_def(class_def, state, static_tools, custom_tools):
class_name = class_def.name
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
bases = [
evaluate_ast(base, state, static_tools, custom_tools)
for base in class_def.bases
]
class_dict = {}
for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
class_dict[stmt.name] = evaluate_function_def(
stmt, state, static_tools, custom_tools
)
elif isinstance(stmt, ast.Assign):
for target in stmt.targets:
if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
class_dict[target.id] = evaluate_ast(
stmt.value, state, static_tools, custom_tools
)
elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
class_dict[target.attr] = evaluate_ast(
stmt.value, state, static_tools, custom_tools
)
else:
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
raise InterpreterError(
f"Unsupported statement in class body: {stmt.__class__.__name__}"
)
new_class = type(class_name, tuple(bases), class_dict)
state[class_name] = new_class
@ -223,7 +247,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
elif isinstance(target, ast.List):
return [get_current_value(elt) for elt in target.elts]
else:
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
raise InterpreterError(
"AugAssign not supported for {type(target)} targets."
)
current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
@ -232,7 +258,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
if isinstance(expression.op, ast.Add):
if isinstance(current_value, list):
if not isinstance(value_to_add, list):
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
raise InterpreterError(
f"Cannot add non-list value {value_to_add} to a list."
)
updated_value = current_value + value_to_add
else:
updated_value = current_value + value_to_add
@ -259,7 +287,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
elif isinstance(expression.op, ast.RShift):
updated_value = current_value >> value_to_add
else:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
raise InterpreterError(
f"Operation {type(expression.op).__name__} is not supported."
)
# Update the state
set_value(expression.target, updated_value, state, static_tools, custom_tools)
@ -311,7 +341,9 @@ def evaluate_binop(binop, state, static_tools, custom_tools):
elif isinstance(binop.op, ast.RShift):
return left_val >> right_val
else:
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
raise NotImplementedError(
f"Binary operation {type(binop.op).__name__} is not implemented."
)
def evaluate_assign(assign, state, static_tools, custom_tools):
@ -321,7 +353,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
set_value(target, result, state, static_tools, custom_tools)
else:
if len(assign.targets) != len(result):
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
raise InterpreterError(
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
)
expanded_values = []
for tgt in assign.targets:
if isinstance(tgt, ast.Starred):
@ -336,7 +370,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
def set_value(target, value, state, static_tools, custom_tools):
if isinstance(target, ast.Name):
if target.id in static_tools:
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
raise InterpreterError(
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
)
state[target.id] = value
elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple):
@ -399,9 +435,14 @@ def evaluate_call(call, state, static_tools, custom_tools):
else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
kwargs = {
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools)
for keyword in call.keywords
}
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
if (
isinstance(func, type) and len(func.__module__.split(".")) > 1
): # Check for user-defined classes
# Instantiate the class using its constructor
obj = func.__new__(func) # Create a new instance of the class
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
@ -441,7 +482,9 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
raise InterpreterError(
"You're trying to subscript a string with a string index, which is impossible"
)
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]
@ -453,11 +496,15 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
return value[index]
elif isinstance(value, (list, tuple)):
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
raise InterpreterError(
f"Index {index} out of bounds for list of length {len(value)}"
)
return value[int(index)]
elif isinstance(value, str):
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
raise InterpreterError(
f"Index {index} out of bounds for string of length {len(value)}"
)
return value[index]
elif index in value:
return value[index]
@ -483,7 +530,10 @@ def evaluate_name(name, state, static_tools, custom_tools):
def evaluate_condition(condition, state, static_tools, custom_tools):
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
comparators = [
evaluate_ast(c, state, static_tools, custom_tools)
for c in condition.comparators
]
ops = [type(op) for op in condition.ops]
result = True
@ -561,9 +611,13 @@ def evaluate_for(for_loop, state, static_tools, custom_tools):
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
def inner_evaluate(generators, index, current_state):
if index >= len(generators):
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
return [
evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)
]
generator = generators[index]
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
iter_value = evaluate_ast(
generator.iter, current_state, static_tools, custom_tools
)
result = []
for value in iter_value:
new_state = current_state.copy()
@ -572,7 +626,10 @@ def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
new_state[elem.id] = value[idx]
else:
new_state[generator.target.id] = value
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
for if_clause in generator.ifs
):
result.extend(inner_evaluate(generators, index + 1, new_state))
return result
@ -586,7 +643,9 @@ def evaluate_try(try_node, state, static_tools, custom_tools):
except Exception as e:
matched = False
for handler in try_node.handlers:
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
if handler.type is None or isinstance(
e, evaluate_ast(handler.type, state, static_tools, custom_tools)
):
matched = True
if handler.name:
state[handler.name] = e
@ -638,7 +697,9 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools):
def evaluate_with(with_node, state, static_tools, custom_tools):
contexts = []
for item in with_node.items:
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
context_expr = evaluate_ast(
item.context_expr, state, static_tools, custom_tools
)
if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id])
@ -661,7 +722,9 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name):
module_path = module_name.split(".")
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
@ -676,7 +739,9 @@ def import_modules(expression, state, authorized_imports):
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
else:
@ -691,9 +756,14 @@ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
for value in iter_value:
new_state = state.copy()
set_value(gen.target, value, new_state, static_tools, custom_tools)
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
for if_clause in gen.ifs
):
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
val = evaluate_ast(
dictcomp.value, new_state, static_tools, custom_tools
)
result[key] = val
return result
@ -744,7 +814,10 @@ def evaluate_ast(
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
return tuple(
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.UnaryOp):
@ -770,8 +843,13 @@ def evaluate_ast(
return evaluate_function_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
keys = [
evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys
]
values = [
evaluate_ast(v, state, static_tools, custom_tools)
for v in expression.values
]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
@ -788,10 +866,18 @@ def evaluate_ast(
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
return "".join(
[
str(evaluate_ast(v, state, static_tools, custom_tools))
for v in expression.values
]
)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
return [
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, static_tools, custom_tools)
@ -815,7 +901,9 @@ def evaluate_ast(
evaluate_ast(expression.upper, state, static_tools, custom_tools)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
evaluate_ast(expression.step, state, static_tools, custom_tools)
if expression.step is not None
else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
@ -834,17 +922,24 @@ def evaluate_ast(
elif isinstance(expression, ast.With):
return evaluate_with(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
return {
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
}
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
evaluate_ast(expression.value, state, static_tools, custom_tools)
if expression.value
else None
)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
def truncate_print_outputs(
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
) -> str:
if len(print_outputs) < max_len_outputs:
return print_outputs
else:
@ -895,8 +990,12 @@ def evaluate_python_code(
OPERATIONS_COUNT = 0
try:
for node in expression.body:
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
result = evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
return result
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
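As a usage note: `evaluate_python_code` returns the value of the last evaluated node and mirrors captured print output into `state["print_outputs"]`. A minimal sketch, assuming the signature shown in the tests below:

```python
from transformers.agents.python_interpreter import evaluate_python_code

state = {}
result = evaluate_python_code("x = 3\nx + 1", {"print": print}, state=state)
assert result == 4
print(state["print_outputs"])  # anything print()-ed by the code, truncated to MAX_LEN_OUTPUT
```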

View File

@ -26,7 +26,9 @@ class DuckDuckGoSearchTool(Tool):
name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
inputs = {
"query": {"type": "string", "description": "The search query to perform."}
}
output_type = "any"
def forward(self, query: str) -> str:
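The body of `forward` is elided by the diff; for reference, a hedged usage sketch (the import path is an assumption, and the `duckduckgo_search` package backing this tool must be installed):

```python
from agents.search import DuckDuckGoSearchTool  # assumed module path

search_tool = DuckDuckGoSearchTool()
results = search_tool("best bakeries in Lyon")  # Tool instances are callable
for result in results:
    # shape follows the description string above
    print(result["title"], "->", result["href"])
```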

View File

@ -26,7 +26,13 @@ from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
from huggingface_hub import (
create_repo,
get_collection,
hf_hub_download,
metadata_update,
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version
@ -73,7 +79,9 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
raise EnvironmentError(
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
)
except Exception:
return "model"
except Exception:
@ -158,7 +166,15 @@ class Tool:
"inputs": dict,
"output_type": str,
}
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
authorized_types = [
"string",
"integer",
"number",
"image",
"audio",
"any",
"boolean",
]
for attr, expected_type in required_attributes.items():
attr_value = getattr(self, attr, None)
@ -169,7 +185,9 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
)
for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
assert isinstance(
input_content, dict
), f"Input '{input_name}' should be a dictionary."
assert (
"type" in input_content and "description" in input_content
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
@ -251,7 +269,11 @@ class Tool:
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
f.write(
APP_FILE_TEMPLATE.format(
module_name=last_module, class_name=self.__class__.__name__
)
)
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
@ -343,7 +365,9 @@ class Tool:
custom_tool = config
tool_class = custom_tool["tool_class"]
tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
tool_class = get_class_from_dynamic_module(
tool_class, repo_id, token=token, **hub_kwargs
)
if len(tool_class.name) == 0:
tool_class.name = custom_tool["name"]
@ -420,7 +444,9 @@ class Tool:
with tempfile.TemporaryDirectory() as work_dir:
# Save all files.
self.save(work_dir)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
return upload_folder(
repo_id=repo_id,
commit_message=commit_message,
@ -432,7 +458,11 @@ class Tool:
@staticmethod
def from_space(
space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
space_id: str,
name: str,
description: str,
api_name: Optional[str] = None,
token: Optional[str] = None,
):
"""
Creates a [`Tool`] from a Space given its id on the Hub.
@ -485,7 +515,9 @@ class Tool:
self.client = Client(space_id, hf_token=token)
self.name = name
self.description = description
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
space_description = self.client.view_api(
return_format="dict", print_info=False
)["named_endpoints"]
# If api_name is not defined, take the first of the available APIs for this space
if api_name is None:
@ -498,7 +530,9 @@ class Tool:
try:
space_description_api = space_description[api_name]
except KeyError:
raise KeyError(f"Could not find specified {api_name=} among available api names.")
raise KeyError(
f"Could not find specified {api_name=} among available api names."
)
self.inputs = {}
for parameter in space_description_api["parameters"]:
@ -523,9 +557,11 @@ class Tool:
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name)
arg = temp_file.name
if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
arg
):
if (
isinstance(arg, (str, Path))
and Path(arg).exists()
and Path(arg).is_file()
) or is_http_url_like(arg):
arg = handle_file(arg)
return arg
@ -544,7 +580,9 @@ class Tool:
] # Sometimes the space also returns the generation seed, in which case the result is at index 0
return output
return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
return SpaceToolWrapper(
space_id, name, description, api_name=api_name, token=token
)
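A hedged usage sketch for `Tool.from_space`; the Space id below is illustrative only, and `gradio_client` must be installed:

```python
image_tool = Tool.from_space(
    "black-forest-labs/FLUX.1-dev",  # illustrative Space id
    name="image_generator",
    description="Generates an image from a text prompt.",
)
image = image_tool("A sunny beach")  # routed through the Space's API endpoint
```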
@staticmethod
def from_gradio(gradio_tool):
@ -561,7 +599,8 @@ class Tool:
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
for key, value in func_args
}
self.forward = self._gradio_tool.run
@ -603,7 +642,9 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
"""
def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
def get_tool_description_with_args(
tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) -> str:
compiled_template = compile_jinja_template(description_template)
rendered = compiled_template.render(
tool=tool,
@ -621,7 +662,10 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
raise ImportError(
"template requires jinja2>=3.1.0 to be installed. Your version is "
f"{jinja2.__version__}."
)
def raise_exception(message):
raise TemplateError(message)
@ -697,7 +741,9 @@ class PipelineTool(Tool):
if model is None:
if self.default_checkpoint is None:
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
raise ValueError(
"This tool does not implement a default checkpoint, you need to pass one."
)
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model
@ -720,15 +766,21 @@ class PipelineTool(Tool):
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
self.pre_processor = self.pre_processor_class.from_pretrained(
self.pre_processor, **self.hub_kwargs
)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
self.model = self.model_class.from_pretrained(
self.model, **self.model_kwargs, **self.hub_kwargs
)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
self.post_processor = self.post_processor_class.from_pretrained(
self.post_processor, **self.hub_kwargs
)
if self.device is None:
if self.device_map is not None:
@ -768,8 +820,12 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs)
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
tensor_inputs = {
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
}
non_tensor_inputs = {
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
}
encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
@ -790,7 +846,9 @@ def launch_gradio_demo(tool_class: Tool):
try:
import gradio as gr
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
raise ImportError(
"Gradio should be installed in order to launch a gradio demo."
)
tool = tool_class()
@ -807,11 +865,15 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs = []
for input_name, input_details in tool_class.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component)
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
tool_class.output_type
]
gradio_output = output_gradio_componentclass(label="Output")
gr.Interface(
@ -875,7 +937,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
f"code that you have checked."
)
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
return Tool.from_hub(
task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
)
def add_description(description):
@ -935,7 +999,9 @@ class EndpointClient:
payload["parameters"] = params
# Make API call
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
response = get_session().post(
self.endpoint_url, headers=self.headers, json=payload, data=data
)
# By default, parse the response for the user.
if output_image:
@ -972,7 +1038,9 @@ class ToolCollection:
def __init__(self, collection_slug: str, token: Optional[str] = None):
self._collection = get_collection(collection_slug, token=token)
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
self._hub_repo_ids = {
item.item_id for item in self._collection.items if item.item_type == "space"
}
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
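A hedged usage sketch for `ToolCollection`; the collection slug below is illustrative:

```python
collection = ToolCollection(
    "huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f"  # illustrative slug
)
for tool_instance in collection.tools:
    # every Space item in the collection has been loaded via Tool.from_hub
    print(tool_instance.name)
```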
@ -986,7 +1054,9 @@ def tool(tool_function: Callable) -> Tool:
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
class_name = f"{parameters['name'].capitalize()}Tool"
class SpecificTool(Tool):
@ -1000,9 +1070,9 @@ def tool(tool_function: Callable) -> Tool:
return tool_function(*args, **kwargs)
original_signature = inspect.signature(tool_function)
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
original_signature.parameters.values()
)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature
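Because `get_json_schema` builds the tool spec from the function's signature and docstring, a decorated function needs complete type hints, the return annotation enforced above, and per-argument docstring descriptions. A minimal sketch (import path assumed):

```python
from agents.tools import tool  # assumed import path in this package

@tool
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers.

    Args:
        a: The first factor.
        b: The second factor.
    """
    return a * b

print(multiply.name)    # "multiply"
print(multiply.inputs)  # argument schema derived from the hints and docstring
```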
@ -1049,7 +1119,10 @@ class Toolbox:
The template to use to describe the tools. If not provided, the default template will be used.
"""
return "\n".join(
[get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
[
get_tool_description_with_args(tool, tool_description_template)
for tool in self._tools.values()
]
)
def add_tool(self, tool: Tool):

View File

@ -16,32 +16,29 @@
# limitations under the License.
import json
import re
from typing import Tuple, Dict
from typing import Tuple, Dict, Union
from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments")
def is_pygments_available():
return _pygments_available
from rich.console import Console
console = Console()
LENGTH_TRUNCATE_REPORTS = 10000
def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS):
if len(content) < max_length:
return content
else:
return content[:max_length//2] + "\n..._(Content was truncated because too long)_...\n---" + content[-max_length//2:]
def parse_json_blob(json_blob: str) -> Dict[str, str]:
try:
first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
'\\"', "'"
)
json_data = json.loads(json_blob, strict=False)
return json_data
except json.JSONDecodeError as e:
@ -63,7 +60,12 @@ def parse_code_blob(code_blob: str) -> str:
try:
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
if match is None:
raise ValueError(
f"No match ground for regex pattern {pattern} in {code_blob=}."
)
return match.group(1).strip()
except Exception as e:
raise ValueError(
f"""
@ -77,7 +79,7 @@ Code:
)
def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
json_blob = json_blob.replace("```json", "").replace("```", "")
tool_call = parse_json_blob(json_blob)
if "action" in tool_call and "action_input" in tool_call:
@ -85,7 +87,25 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
elif "action" in tool_call:
return tool_call["action"], None
else:
missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
missing_keys = [
key for key in ["action", "action_input"] if key not in tool_call
]
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
console.print(f"[bold red]{error_msg}[/bold red]")
raise ValueError(error_msg)
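A sketch of the parsing contract above: the blob must carry an `action` key, plus `action_input` when the tool takes arguments:

```python
blob = 'Action:\n{"action": "final_answer", "action_input": {"answer": "42"}}'
tool_name, tool_input = parse_json_tool_call(blob)
assert tool_name == "final_answer"
assert tool_input == {"answer": "42"}
```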
MAX_LENGTH_TRUNCATE_CONTENT = 20000
def truncate_content(
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
if len(content) <= max_length:
return content
else:
return (
content[: max_length // 2]
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
+ content[-max_length // 2 :]
)
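The new `truncate_content` keeps the head and tail of the content and splices a marker in between. A quick behavioral sketch:

```python
long_text = "a" * 30_000
short = truncate_content(long_text)  # default max_length = 20000
assert len(short) < len(long_text)
assert "_This content has been truncated to stay below 20000 characters_" in short
```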

View File

@ -120,4 +120,4 @@ def get_weather_api(location: str, date_time: str) -> str:
raise ValueError("Conversion of `date_time` to datetime format failed, make sure to provide a string in format '%m/%d/%y %H:%M:%S'. Full trace:" + str(e))
temperature_celsius, risk_of_rain, wave_height = get_weather_report_at_coordinates((lon, lat), date_time)
return f"Weather report for {location}, {date_time}: Temperature will be {temperature_celsius}°C, risk of rain is {risk_of_rain*100:.0f}%, wave height is {wave_height}m."
```
```

poetry.lock (generated, 2748 lines changed)

File diff suppressed because it is too large

requirements.txt (new file, 271 lines)
View File

@ -0,0 +1,271 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiofiles==23.2.1
# via gradio
annotated-types==0.7.0
# via pydantic
anyio==4.7.0
# via
# gradio
# httpx
# starlette
appnope==0.1.4
# via ipykernel
asttokens==3.0.0
# via stack-data
beautifulsoup4==4.12.3
# via markdownify
certifi==2024.8.30
# via
# httpcore
# httpx
# requests
charset-normalizer==3.4.0
# via requests
click==8.1.7
# via
# duckduckgo-search
# typer
# uvicorn
comm==0.2.2
# via ipykernel
debugpy==1.8.10
# via ipykernel
decorator==5.1.1
# via ipython
diskcache==5.6.3
# via llama-cpp-python
duckduckgo-search==6.4.1
# via -r requirements.in
executing==2.1.0
# via stack-data
fastapi==0.115.6
# via gradio
ffmpy==0.4.0
# via gradio
filelock==3.16.1
# via
# huggingface-hub
# transformers
fsspec==2024.10.0
# via
# gradio-client
# huggingface-hub
gradio==5.8.0
# via -r requirements.in
gradio-client==1.5.1
# via gradio
h11==0.14.0
# via
# httpcore
# uvicorn
httpcore==1.0.7
# via httpx
httpx==0.28.1
# via
# gradio
# gradio-client
# safehttpx
huggingface-hub==0.26.5
# via
# gradio
# gradio-client
# tokenizers
# transformers
idna==3.10
# via
# anyio
# httpx
# requests
iniconfig==2.0.0
# via pytest
ipykernel==6.29.5
# via -r requirements.in
ipython==8.30.0
# via ipykernel
jedi==0.19.2
# via ipython
jinja2==3.1.4
# via
# -r requirements.in
# gradio
# llama-cpp-python
jupyter-client==8.6.3
# via ipykernel
jupyter-core==5.7.2
# via
# ipykernel
# jupyter-client
llama-cpp-python==0.3.5
# via -r requirements.in
markdown-it-py==3.0.0
# via rich
markdownify==0.14.1
# via -r requirements.in
markupsafe==2.1.5
# via
# gradio
# jinja2
matplotlib-inline==0.1.7
# via
# ipykernel
# ipython
mdurl==0.1.2
# via markdown-it-py
nest-asyncio==1.6.0
# via ipykernel
numpy==2.2.0
# via
# gradio
# llama-cpp-python
# pandas
# transformers
orjson==3.10.12
# via gradio
packaging==24.2
# via
# gradio
# gradio-client
# huggingface-hub
# ipykernel
# pytest
# transformers
pandas==2.2.3
# via
# -r requirements.in
# gradio
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pillow==11.0.0
# via
# -r requirements.in
# gradio
platformdirs==4.3.6
# via jupyter-core
pluggy==1.5.0
# via pytest
primp==0.8.3
# via duckduckgo-search
prompt-toolkit==3.0.48
# via ipython
psutil==6.1.0
# via ipykernel
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pydantic==2.10.3
# via
# fastapi
# gradio
pydantic-core==2.27.1
# via pydantic
pydub==0.25.1
# via gradio
pygments==2.18.0
# via
# ipython
# rich
pytest==8.3.4
# via -r requirements.in
python-dateutil==2.9.0.post0
# via
# jupyter-client
# pandas
python-dotenv==1.0.1
# via -r requirements.in
python-multipart==0.0.19
# via gradio
pytz==2024.2
# via pandas
pyyaml==6.0.2
# via
# gradio
# huggingface-hub
# transformers
pyzmq==26.2.0
# via
# ipykernel
# jupyter-client
regex==2024.11.6
# via transformers
requests==2.32.3
# via
# -r requirements.in
# huggingface-hub
# transformers
rich==13.9.4
# via
# -r requirements.in
# typer
ruff==0.8.2
# via gradio
safehttpx==0.1.6
# via gradio
safetensors==0.4.5
# via transformers
semantic-version==2.10.0
# via gradio
shellingham==1.5.4
# via typer
six==1.17.0
# via
# markdownify
# python-dateutil
sniffio==1.3.1
# via anyio
soupsieve==2.6
# via beautifulsoup4
stack-data==0.6.3
# via ipython
starlette==0.41.3
# via
# fastapi
# gradio
tokenizers==0.21.0
# via transformers
tomlkit==0.13.2
# via gradio
tornado==6.4.2
# via
# ipykernel
# jupyter-client
tqdm==4.67.1
# via
# huggingface-hub
# transformers
traitlets==5.14.3
# via
# comm
# ipykernel
# ipython
# jupyter-client
# jupyter-core
# matplotlib-inline
transformers==4.47.0
# via -r requirements.in
typer==0.15.1
# via gradio
typing-extensions==4.12.2
# via
# anyio
# fastapi
# gradio
# gradio-client
# huggingface-hub
# llama-cpp-python
# pydantic
# pydantic-core
# typer
tzdata==2024.2
# via pandas
urllib3==2.2.3
# via requests
uvicorn==0.32.1
# via gradio
wcwidth==0.2.13
# via prompt-toolkit
websockets==14.1
# via gradio-client

setup.py (deleted, 122 lines)
View File

@ -1,122 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import find_packages, setup
extras = {}
extras["quality"] = [
"black ~= 23.1", # hf-doc-builder has a hidden dependency on `black`
"hf-doc-builder >= 0.3.0",
"ruff ~= 0.6.4",
]
extras["docs"] = []
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
extras["test_dev"] = [
"datasets",
"diffusers",
"evaluate",
"torchdata>=0.8.0",
"torchpippy>=0.2.0",
"transformers",
"scipy",
"scikit-learn",
"tqdm",
"bitsandbytes",
"timm",
]
extras["testing"] = extras["test_prod"] + extras["test_dev"]
extras["deepspeed"] = ["deepspeed"]
extras["rich"] = ["rich"]
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
extras["sagemaker"] = [
"sagemaker", # boto3 is a required package in sagemaker
]
setup(
name="accelerate",
version="1.2.0.dev0",
description="Accelerate",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="deep learning",
license="Apache",
author="The HuggingFace team",
author_email="zach.mueller@huggingface.co",
url="https://github.com/huggingface/accelerate",
package_dir={"": "src"},
packages=find_packages("src"),
entry_points={
"console_scripts": [
"accelerate=accelerate.commands.accelerate_cli:main",
"accelerate-config=accelerate.commands.config:main",
"accelerate-estimate-memory=accelerate.commands.estimate:main",
"accelerate-launch=accelerate.commands.launch:main",
"accelerate-merge-weights=accelerate.commands.merge:main",
]
},
python_requires=">=3.9.0",
install_requires=[
"numpy>=1.17,<3.0.0",
"packaging>=20.0",
"psutil",
"pyyaml",
"torch>=1.10.0",
"huggingface_hub>=0.21.0",
"safetensors>=0.4.3",
],
extras_require=extras,
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
# Release checklist
# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):
# git checkout -b vXX.xx-release
# The -b is only necessary for creation (so remove it when doing a patch)
# 2. Change the version in __init__.py and setup.py to the proper value.
# 3. Commit these changes with the message: "Release: v<VERSION>"
# 4. Add a tag in git to mark the release:
# git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi'
# Push the tag and release commit to git: git push --tags origin vXX.xx-release
# 5. Run the following commands in the top-level directory:
# rm -rf dist
# rm -rf build
# python setup.py bdist_wheel
# python setup.py sdist
# 6. Upload the package to the pypi test server first:
# twine upload dist/* -r testpypi
# 7. Check that you can install it in a virtualenv by running:
# pip install accelerate
# pip uninstall accelerate
# pip install -i https://testpypi.python.org/pypi accelerate
# accelerate env
# accelerate test
# 8. Upload the final version to actual pypi:
# twine upload dist/* -r pypi
# 9. Add release notes to the tag in github once everything is looking hunky-dory.
# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to
# main.

View File

@ -19,19 +19,25 @@ import uuid
from pathlib import Path
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
from transformers.testing_utils import (
get_tests_dir,
require_soundfile,
require_torch,
require_vision,
)
from transformers.utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
)
import torch
from PIL import Image
if is_torch_available():
import torch
if is_soundfile_availble():
import soundfile as sf
if is_vision_available():
from PIL import Image
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()

View File

@ -19,17 +19,16 @@ import uuid
import pytest
from transformers.agents.agent_types import AgentText
from transformers.agents.agents import (
from agents.agent_types import AgentText
from agents.agents import (
AgentMaxIterationsError,
CodeAgent,
ManagedAgent,
ReactCodeAgent,
ReactJsonAgent,
CodeAgent,
JsonAgent,
Toolbox,
)
from transformers.agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import require_torch
from agents.default_tools import PythonInterpreterTool
def get_new_path(suffix="") -> str:
@ -149,19 +148,26 @@ print(result)
class AgentTests(unittest.TestCase):
def test_fake_code_agent(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
def test_fake_react_json_agent(self):
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert agent.logs[1]["observation"] == "7.2904"
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
assert (
agent.logs[1]["rationale"].strip()
== "Thought: I should multiply 2 by 3.6452. special_marker"
)
assert (
agent.logs[2]["llm_output"]
== """
@ -175,7 +181,9 @@ Action:
)
def test_fake_react_code_agent(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float)
assert output == 7.2904
@ -186,17 +194,19 @@ Action:
}
def test_react_code_agent_code_errors_show_offending_lines(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
JsonAgent(llm_engine=fake_react_json_llm, tools=[])
def test_react_fails_max_iterations(self):
agent = ReactCodeAgent(
agent = CodeAgent(
tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
max_iterations=5,
@ -208,51 +218,62 @@ Action:
@require_torch
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
agent = JsonAgent(
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert (
len(agent.toolbox.tools) == 7
) # added final_answer tool + 6 base tools (excluding interpreter)
def test_function_persistence_across_steps(self):
agent = ReactCodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
agent = CodeAgent(
tools=[],
llm_engine=fake_react_code_functiondef,
max_iterations=2,
additional_authorized_imports=["numpy"],
)
res = agent.run("ok")
assert res[0] == 0.5
def test_init_managed_agent(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = ReactCodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
manager_agent = CodeAgent(
tools=[],
llm_engine=fake_react_code_functiondef,
managed_agents=[managed_agent],
)
assert "You can also give requests to team members." not in agent.system_prompt
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
assert "You can also give requests to team members." in manager_agent.system_prompt
assert (
"You can also give requests to team members." in manager_agent.system_prompt
)

View File

@ -1,41 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("document_question_answering")
self.tool.setup()
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
document = dataset[0]["image"]
result = self.tool(document, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
document = dataset[0]["image"]
self.tool(document=document, question="When is the coffee break?")

View File

@ -87,7 +87,11 @@ class ExampleDifferenceTests(unittest.TestCase):
examples_path = Path("examples").resolve()
def one_complete_example(
self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None
self,
complete_file_name: str,
parser_only: bool,
secondary_filename: str = None,
special_strings: list = None,
):
"""
Tests a single `complete` example against all of the implemented `by_feature` scripts
@ -112,10 +116,15 @@ class ExampleDifferenceTests(unittest.TestCase):
with self.subTest(
tested_script=complete_file_name,
feature_script=item,
tested_section="main()" if parser_only else "training_function()",
tested_section="main()"
if parser_only
else "training_function()",
):
diff = compare_against_test(
self.examples_path / complete_file_name, item_path, parser_only, secondary_filename
self.examples_path / complete_file_name,
item_path,
parser_only,
secondary_filename,
)
diff = "\n".join(diff)
if special_strings is not None:
@ -140,8 +149,12 @@ class ExampleDifferenceTests(unittest.TestCase):
" " * 12,
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
]
self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)
self.one_complete_example(
"complete_cv_example.py", True, cv_path, special_strings
)
self.one_complete_example(
"complete_cv_example.py", False, cv_path, special_strings
)
@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})

View File

@ -47,9 +47,9 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def create_inputs(self):
inputs_text = {"answer": "Text input"}
inputs_image = {
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
(512, 512)
)
"answer": Image.open(
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
).resize((512, 512))
}
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}

View File

@ -1,42 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image_question_answering")
self.tool.setup()
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")

View File

@ -129,7 +129,9 @@ final_answer('This is the final answer.')
def test_streaming_agent_image_output(self):
def dummy_llm_engine(prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
return (
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
)
agent = ReactJsonAgent(
tools=[],
@ -138,7 +140,14 @@ final_answer('This is the final answer.')
)
# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
outputs = list(
stream_to_gradio(
agent,
task="Test task",
image=AgentImage(value="path.png"),
test_mode=True,
)
)
self.assertEqual(len(outputs), 2)
final_message = outputs[-1]

View File

@ -21,7 +21,10 @@ import pytest
from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
from transformers.agents.python_interpreter import (
InterpreterError,
evaluate_python_code,
)
from .test_tools_common import ToolTesterMixin
@ -57,7 +60,12 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
_inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
@ -91,7 +99,10 @@ class PythonInterpreterTester(unittest.TestCase):
code = "print = '3'"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {"print": print}, state={})
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
assert (
"Cannot assign to name 'print': doing this would erase the existing tool!"
in str(e)
)
def test_evaluate_call(self):
code = "y = add_two(x)"
@ -117,7 +128,9 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
@ -133,7 +146,9 @@ class PythonInterpreterTester(unittest.TestCase):
result = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
self.assertDictEqual(
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
)
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
@ -174,11 +189,15 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
evaluate_python_code(
code, {"min": min, "print": print, "round": round}, state=state
)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self):
@ -292,7 +311,16 @@ print(check_digits)
"""
state = {}
evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
code,
{
"range": range,
"print": print,
"sum": sum,
"enumerate": enumerate,
"int": int,
"str": str,
},
state,
)
def test_listcomp(self):
@ -325,7 +353,9 @@ print(check_digits)
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
result = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert result == {102: "b"}
code = """
@ -373,7 +403,9 @@ else:
best_city = "Manhattan"
best_city
"""
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Brooklyn"
code = """if d > e and a < b:
@ -384,7 +416,9 @@ else:
best_city = "Manhattan"
best_city
"""
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Sacramento"
def test_if_conditions(self):
@ -400,7 +434,9 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
@ -434,10 +470,14 @@ if char.isalpha():
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
def test_additional_imports(self):
code = "import numpy as np"
@ -554,7 +594,11 @@ cat_sound = cat.sound()
cat_str = str(cat)
"""
state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
evaluate_python_code(
code,
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
state=state,
)
# Assert results
assert state["dog1_sound"] == "The dog barks."
@ -588,7 +632,11 @@ except ValueError as e:
exception_message = str(e)
"""
state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
evaluate_python_code(
code,
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
state=state,
)
assert state["exception_message"] == "An error occurred"
def test_print(self):
@ -600,7 +648,9 @@ except ValueError as e:
def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
result = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
assert result is int
def test_tuple_id(self):
@ -731,7 +781,9 @@ def add_one(n, shift):
add_one(1, 1)
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
result = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result == 2
# test returning None
@ -742,7 +794,9 @@ def returns_none(a):
returns_none(1)
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
result = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result is None
def test_nested_for_loop(self):
@ -758,7 +812,9 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
result = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self):
@ -773,7 +829,9 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
result = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
assert np.array_equal(result, [-1, 5])
code = """
@ -785,7 +843,9 @@ print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
result = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert np.array_equal(result.values[0], [104, 1])
code = """import pandas as pd
@ -818,7 +878,9 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
result = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
assert round(result, 1) == 622395.4
def test_for(self):

View File

@ -1,36 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("speech_to_text")
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(np.ones(3000))
self.assertEqual(result, " Thank you.")
def test_exact_match_kwarg(self):
result = self.tool(audio=np.ones(3000))
self.assertEqual(result, " Thank you.")

View File

@ -1,50 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from transformers.utils import is_torch_available
if is_torch_available():
import torch
from transformers.testing_utils import require_torch
from .test_tools_common import ToolTesterMixin
@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text_to_speech")
self.tool.setup()
def test_exact_match_arg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
def test_exact_match_kwarg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)

View File

@ -20,7 +20,12 @@ import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
from transformers.agents.agent_types import (
AGENT_TYPE_MAPPING,
AgentAudio,
AgentImage,
AgentText,
)
from transformers.agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir

View File

@ -32,7 +32,9 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_kwarg(self):
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
result = self.tool(
text="Hey, what's up?", src_lang="English", tgt_lang="French"
)
self.assertEqual(result, "- Hé, comment ça va?")
def test_call(self):