Ruff formatting

This commit is contained in:
Aymeric 2024-12-11 16:16:18 +01:00
parent 851e177e71
commit 67deb6808f
28 changed files with 1153 additions and 3460 deletions

View File

@ -24,10 +24,25 @@ from transformers.utils import (
_import_structure = { _import_structure = {
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"], "agents": [
"Agent",
"CodeAgent",
"ManagedAgent",
"ReactAgent",
"CodeAgent",
"JsonAgent",
"Toolbox",
],
"llm_engine": ["HfApiEngine", "TransformersEngine"], "llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"], "monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"], "tools": [
"PipelineTool",
"Tool",
"ToolCollection",
"launch_gradio_demo",
"load_tool",
"tool",
],
} }
try: try:
@ -45,10 +60,25 @@ else:
_import_structure["translation"] = ["TranslationTool"] _import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox from .agents import (
Agent,
CodeAgent,
ManagedAgent,
ReactAgent,
CodeAgent,
JsonAgent,
Toolbox,
)
from .llm_engine import HfApiEngine, TransformersEngine from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool from .tools import (
PipelineTool,
Tool,
ToolCollection,
launch_gradio_demo,
load_tool,
tool,
)
try: try:
if not is_torch_available(): if not is_torch_available():
@ -66,4 +96,6 @@ if TYPE_CHECKING:
else: else:
import sys import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) sys.modules[__name__] = _LazyModule(
__name__, globals()["__file__"], _import_structure, module_spec=__spec__
)

View File

@ -19,7 +19,11 @@ import uuid
import numpy as np import numpy as np
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available from transformers.utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
)
import logging import logging
@ -108,7 +112,9 @@ class AgentImage(AgentType, ImageType):
elif isinstance(value, np.ndarray): elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value) self._tensor = torch.from_numpy(value)
else: else:
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}") raise TypeError(
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
)
def _ipython_display_(self, include=None, exclude=None): def _ipython_display_(self, include=None, exclude=None):
""" """
@ -243,7 +249,9 @@ if is_torch_available():
def handle_agent_inputs(*args, **kwargs): def handle_agent_inputs(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()} kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
}
return args, kwargs return args, kwargs

View File

@ -15,12 +15,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import time import time
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from rich.syntax import Syntax from rich.syntax import Syntax
from langfuse.decorators import langfuse_context, observe
from transformers.utils import is_torch_available from transformers.utils import is_torch_available
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
@ -51,6 +49,7 @@ from .tools import (
HUGGINGFACE_DEFAULT_TOOLS = {} HUGGINGFACE_DEFAULT_TOOLS = {}
class AgentError(Exception): class AgentError(Exception):
"""Base class for other agent-related exceptions""" """Base class for other agent-related exceptions"""
@ -60,7 +59,6 @@ class AgentError(Exception):
console.print(f"[bold red]{message}[/bold red]") console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError): class AgentParsingError(AgentError):
"""Exception raised for errors in parsing in the agent""" """Exception raised for errors in parsing in the agent"""
@ -84,12 +82,14 @@ class AgentGenerationError(AgentError):
pass pass
class AgentStep: class AgentStep:
pass pass
@dataclass @dataclass
class ActionStep(AgentStep): class ActionStep(AgentStep):
tool_call: str | None = None tool_call: Dict[str, str] | None = None
start_time: float | None = None start_time: float | None = None
step_end_time: float | None = None step_end_time: float | None = None
iteration: int | None = None iteration: int | None = None
@ -97,32 +97,43 @@ class ActionStep(AgentStep):
error: AgentError | None = None error: AgentError | None = None
step_duration: float | None = None step_duration: float | None = None
llm_output: str | None = None llm_output: str | None = None
observation: str | None = None
agent_memory: List[Dict[str, str]] | None = None
rationale: str | None = None
@dataclass @dataclass
class PlanningStep(AgentStep): class PlanningStep(AgentStep):
plan: str plan: str
facts: str facts: str
@dataclass @dataclass
class TaskStep(AgentStep): class TaskStep(AgentStep):
task: str task: str
@dataclass @dataclass
class SystemPromptStep(AgentStep): class SystemPromptStep(AgentStep):
system_prompt: str system_prompt: str
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: def format_prompt_with_tools(
toolbox: Toolbox, prompt_template: str, tool_description_template: str
) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "{{tool_names}}" in prompt: if "{{tool_names}}" in prompt:
prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()])) prompt = prompt.replace(
"{{tool_names}}",
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
)
return prompt return prompt
def show_agents_descriptions(managed_agents: list): def show_agents_descriptions(managed_agents: Dict):
managed_agents_descriptions = """ managed_agents_descriptions = """
You can also give requests to team members. You can also give requests to team members.
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request. Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
@ -133,16 +144,24 @@ Here is a list of the team members that you can call:"""
return managed_agents_descriptions return managed_agents_descriptions
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str: def format_prompt_with_managed_agents_descriptions(
prompt_template, managed_agents=None
) -> str:
if managed_agents is not None: if managed_agents is not None:
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)) return prompt_template.replace(
"<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)
)
else: else:
return prompt_template.replace("<<managed_agents_descriptions>>", "") return prompt_template.replace("<<managed_agents_descriptions>>", "")
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str: def format_prompt_with_imports(
prompt_template: str, authorized_imports: List[str]
) -> str:
if "<<authorized_imports>>" not in prompt_template: if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.") raise AgentError(
"Tag '<<authorized_imports>>' should be provided in the prompt."
)
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports)) return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
@ -150,7 +169,7 @@ class BaseAgent:
def __init__( def __init__(
self, self,
tools: Union[List[Tool], Toolbox], tools: Union[List[Tool], Toolbox],
llm_engine: Callable = None, llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None, tool_description_template: Optional[str] = None,
additional_args: Dict = {}, additional_args: Dict = {},
@ -159,10 +178,12 @@ class BaseAgent:
add_base_tools: bool = False, add_base_tools: bool = False,
verbose: bool = False, verbose: bool = False,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None, managed_agents: Optional[Dict] = None,
step_callbacks: Optional[List[Callable]] = None, step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True, monitor_metrics: bool = True,
): ):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_parser is None: if tool_parser is None:
@ -171,14 +192,16 @@ class BaseAgent:
self.llm_engine = llm_engine self.llm_engine = llm_engine
self.system_prompt_template = system_prompt self.system_prompt_template = system_prompt
self.tool_description_template = ( self.tool_description_template = (
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE tool_description_template
if tool_description_template
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) )
self.additional_args = additional_args self.additional_args = additional_args
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tool_parser = tool_parser self.tool_parser = tool_parser
self.grammar = grammar self.grammar = grammar
self.managed_agents = None self.managed_agents = {}
if managed_agents is not None: if managed_agents is not None:
self.managed_agents = {agent.name: agent for agent in managed_agents} self.managed_agents = {agent.name: agent for agent in managed_agents}
@ -186,9 +209,13 @@ class BaseAgent:
self._toolbox = tools self._toolbox = tools
if add_base_tools: if add_base_tools:
if not is_torch_available(): if not is_torch_available():
raise ImportError("Using the base tools requires torch to be installed.") raise ImportError(
"Using the base tools requires torch to be installed."
)
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent)) self._toolbox.add_base_tools(
add_python_interpreter=(self.__class__ == JsonAgent)
)
else: else:
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
self._toolbox.add_tool(FinalAnswerTool()) self._toolbox.add_tool(FinalAnswerTool())
@ -196,7 +223,9 @@ class BaseAgent:
self.system_prompt = format_prompt_with_tools( self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template self._toolbox, self.system_prompt_template, self.tool_description_template
) )
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
self.prompt_messages = None self.prompt_messages = None
self.logs = [] self.logs = []
self.task = None self.task = None
@ -222,15 +251,20 @@ class BaseAgent:
self.system_prompt_template, self.system_prompt_template,
self.tool_description_template, self.tool_description_template,
) )
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
if hasattr(self, "authorized_imports"): if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports( self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) self.system_prompt,
list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))),
) )
return self.system_prompt return self.system_prompt
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: def write_inner_memory_from_logs(
self, summary_mode: Optional[bool] = False
) -> List[Dict[str, str]]:
""" """
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM. that can be used as input to the LLM.
@ -253,7 +287,10 @@ class BaseAgent:
memory.append(thought_message) memory.append(thought_message)
if not summary_mode: if not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log.plan.strip()} thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[PLAN]:\n" + step_log.plan.strip(),
}
memory.append(thought_message) memory.append(thought_message)
elif isinstance(step_log, TaskStep): elif isinstance(step_log, TaskStep):
@ -265,13 +302,17 @@ class BaseAgent:
elif isinstance(step_log, ActionStep): elif isinstance(step_log, ActionStep):
if step_log.llm_output is not None and not summary_mode: if step_log.llm_output is not None and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log.llm_output.strip()} thought_message = {
"role": MessageRole.ASSISTANT,
"content": step_log.llm_output.strip(),
}
memory.append(thought_message) memory.append(thought_message)
if step_log.tool_call is not None and summary_mode: if step_log.tool_call is not None and summary_mode:
tool_call_message = { tool_call_message = {
"role": MessageRole.ASSISTANT, "role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log.tool_call).strip(), "content": f"[STEP {i} TOOL CALL]: "
+ str(step_log.tool_call).strip(),
} }
memory.append(tool_call_message) memory.append(tool_call_message)
@ -284,15 +325,21 @@ class BaseAgent:
) )
elif step_log.observation is not None: elif step_log.observation is not None:
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} tool_response_message = {
"role": MessageRole.TOOL_RESPONSE,
"content": message_content,
}
memory.append(tool_response_message) memory.append(tool_response_message)
return memory return memory
def get_succinct_logs(self): def get_succinct_logs(self):
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] return [
{key: value for key, value in log.items() if key != "agent_memory"}
for log in self.logs
]
def extract_action(self, llm_output: str, split_token: str) -> str: def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
""" """
Parse action from the LLM output Parse action from the LLM output
@ -312,54 +359,6 @@ class BaseAgent:
) )
return rationale.strip(), action.strip() return rationale.strip(), action.strip()
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
try:
if isinstance(arguments, str):
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
tool_description = get_tool_description_with_args(available_tools[tool_name])
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def run(self, **kwargs): def run(self, **kwargs):
"""To be implemented in the child class""" """To be implemented in the child class"""
raise NotImplementedError raise NotImplementedError
@ -382,8 +381,6 @@ class ReactAgent(BaseAgent):
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
if tool_description_template is None: if tool_description_template is None:
@ -423,8 +420,67 @@ class ReactAgent(BaseAgent):
console.print(f"[bold red]{error_msg}[/bold red]") console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg return error_msg
@observe def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs): """
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = self.toolbox.tools
if self.managed_agents is not None:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
try:
if isinstance(arguments, str):
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
tool_description = get_tool_description_with_args(
available_tools[tool_name]
)
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def step(self, log_entry: ActionStep):
"""To be implemented in children classes"""
pass
def run(
self,
task: str,
stream: bool = False,
reset: bool = True,
oneshot: bool = False,
**kwargs,
):
""" """
Runs the agent for the given task. Runs the agent for the given task.
@ -441,10 +497,11 @@ class ReactAgent(BaseAgent):
agent.run("What is the result of 2 power 3.7384?") agent.run("What is the result of 2 power 3.7384?")
``` ```
""" """
print("LANGFUSE REF:", langfuse_context.get_current_trace_url())
self.task = task self.task = task
if len(kwargs) > 0: if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.task += (
f"\nYou have been provided with these initial arguments: {str(kwargs)}."
)
self.state = kwargs.copy() self.state = kwargs.copy()
self.initialize_system_prompt() self.initialize_system_prompt()
@ -460,7 +517,7 @@ class ReactAgent(BaseAgent):
else: else:
self.logs.append(system_prompt_step) self.logs.append(system_prompt_step)
console.rule("[bold]New task", characters='=') console.rule("[bold]New task", characters="=")
console.print(self.task) console.print(self.task)
self.logs.append(TaskStep(task=task)) self.logs.append(TaskStep(task=task))
@ -489,8 +546,13 @@ class ReactAgent(BaseAgent):
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(iteration=iteration, start_time=step_start_time) step_log = ActionStep(iteration=iteration, start_time=step_start_time)
try: try:
if self.planning_interval is not None and iteration % self.planning_interval == 0: if (
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) self.planning_interval is not None
and iteration % self.planning_interval == 0
):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step") console.rule("[bold]New step")
self.step(step_log) self.step(step_log)
if step_log.final_answer is not None: if step_log.final_answer is not None:
@ -530,8 +592,13 @@ class ReactAgent(BaseAgent):
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(iteration=iteration, start_time=step_start_time) step_log = ActionStep(iteration=iteration, start_time=step_start_time)
try: try:
if self.planning_interval is not None and iteration % self.planning_interval == 0: if (
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) self.planning_interval is not None
and iteration % self.planning_interval == 0
):
self.planning_step(
task, is_first_step=(iteration == 0), iteration=iteration
)
console.rule("[bold]New step") console.rule("[bold]New step")
self.step(step_log) self.step(step_log)
if step_log.final_answer is not None: if step_log.final_answer is not None:
@ -559,7 +626,7 @@ class ReactAgent(BaseAgent):
return final_answer return final_answer
def planning_step(self, task, is_first_step: bool = False, iteration: int = None): def planning_step(self, task, is_first_step: bool, iteration: int):
""" """
Used periodically by the agent to plan the next steps to reach the objective. Used periodically by the agent to plan the next steps to reach the objective.
@ -569,7 +636,10 @@ class ReactAgent(BaseAgent):
iteration (`int`): The number of the current step, used as an indication for the LLM. iteration (`int`): The number of the current step, used as an indication for the LLM.
""" """
if is_first_step: if is_first_step:
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS} message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_FACTS,
}
message_prompt_task = { message_prompt_task = {
"role": MessageRole.USER, "role": MessageRole.USER,
"content": f"""Here is the task: "content": f"""Here is the task:
@ -589,15 +659,20 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format( "content": USER_PROMPT_PLAN.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), tool_descriptions=self._toolbox.show_tool_descriptions(
self.tool_description_template
),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
), ),
answer_facts=answer_facts, answer_facts=answer_facts,
), ),
} }
answer_plan = self.llm_engine( answer_plan = self.llm_engine(
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"] [message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=["<end_plan>"],
) )
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task: final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
@ -608,7 +683,9 @@ Now begin!""",
``` ```
{answer_facts} {answer_facts}
```""".strip() ```""".strip()
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Initial plan") console.rule("[orange]Initial plan")
console.print(final_plan_redaction) console.print(final_plan_redaction)
else: # update plan else: # update plan
@ -625,7 +702,9 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE, "content": USER_PROMPT_FACTS_UPDATE,
} }
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message]) facts_update = self.llm_engine(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
)
# Redact updated plan # Redact updated plan
plan_update_message = { plan_update_message = {
@ -636,25 +715,34 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format( "content": USER_PROMPT_PLAN_UPDATE.format(
task=task, task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), tool_descriptions=self._toolbox.show_tool_descriptions(
self.tool_description_template
),
managed_agents_descriptions=( managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" show_agents_descriptions(self.managed_agents)
if self.managed_agents is not None
else ""
), ),
facts_update=facts_update, facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration), remaining_steps=(self.max_iterations - iteration),
), ),
} }
plan_update = self.llm_engine( plan_update = self.llm_engine(
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"] [plan_update_message] + agent_memory + [plan_update_message_user],
stop_sequences=["<end_plan>"],
) )
# Log final facts and plan # Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update) final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
task=task, plan_update=plan_update
)
final_facts_redaction = f"""Here is the updated list of the facts that I know: final_facts_redaction = f"""Here is the updated list of the facts that I know:
``` ```
{facts_update} {facts_update}
```""" ```"""
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.rule("[orange]Updated plan") console.rule("[orange]Updated plan")
console.print(final_plan_redaction) console.print(final_plan_redaction)
@ -705,14 +793,20 @@ class JsonAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
if self.verbose: if self.verbose:
console.rule("[italic]Calling LLM engine with this last message:", align="left") console.rule(
"[italic]Calling LLM engine with this last message:", align="left"
)
console.print(self.prompt_messages[-1]) console.print(self.prompt_messages[-1])
console.rule() console.rule()
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.llm_engine( llm_output = self.llm_engine(
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args self.prompt_messages,
stop_sequences=["<end_action>", "Observation:"],
**additional_args,
) )
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
except Exception as e: except Exception as e:
@ -723,7 +817,9 @@ class JsonAgent(ReactAgent):
console.print(llm_output) console.print(llm_output)
# Parse # Parse
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") rationale, action = self.extract_action(
llm_output=llm_output, split_token="Action:"
)
try: try:
tool_name, arguments = self.tool_parser(action) tool_name, arguments = self.tool_parser(action)
@ -807,12 +903,18 @@ class CodeAgent(ReactAgent):
) )
self.python_evaluator = evaluate_python_code self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = (
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) additional_authorized_imports if additional_authorized_imports else []
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports)) )
self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)
)
self.system_prompt = self.system_prompt.replace(
"<<authorized_imports>>", str(self.authorized_imports)
)
self.custom_tools = {} self.custom_tools = {}
def step(self, log_entry: Dict[str, Any]): def step(self, log_entry: ActionStep):
""" """
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method. The errors are raised here, they are caught and logged in the run() method.
@ -825,14 +927,20 @@ class CodeAgent(ReactAgent):
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
if self.verbose: if self.verbose:
console.rule("[italic]Calling LLM engine with these last messages:", align="left") console.rule(
"[italic]Calling LLM engine with these last messages:", align="left"
)
console.print(self.prompt_messages[-2:]) console.print(self.prompt_messages[-2:])
console.rule() console.rule()
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.llm_engine( llm_output = self.llm_engine(
self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args self.prompt_messages,
stop_sequences=["<end_action>", "Observation:"],
**additional_args,
) )
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
except Exception as e: except Exception as e:
@ -840,13 +948,19 @@ class CodeAgent(ReactAgent):
if self.verbose: if self.verbose:
console.rule("[italic]Output message of the LLM:") console.rule("[italic]Output message of the LLM:")
console.print(Syntax(llm_output, lexer='markdown', background_color='default')) console.print(
Syntax(llm_output, lexer="markdown", background_color="default")
)
# Parse # Parse
try: try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") rationale, raw_code_action = self.extract_action(
llm_output=llm_output, split_token="Code:"
)
except Exception as e: except Exception as e:
console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") console.print(
f"Error in extracting action, trying to parse the whole output. Error trace: {e}"
)
rationale, raw_code_action = llm_output, llm_output rationale, raw_code_action = llm_output, llm_output
try: try:
@ -856,14 +970,17 @@ class CodeAgent(ReactAgent):
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.rationale = rationale log_entry.rationale = rationale
log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action} log_entry.tool_call = {
"tool_name": "code interpreter",
"tool_arguments": code_action,
}
# Execute # Execute
if self.verbose: if self.verbose:
console.rule("[italic]Agent thoughts") console.rule("[italic]Agent thoughts")
console.print(rationale) console.print(rationale)
console.rule("[bold]Agent is executing the code below:", align="left") console.rule("[bold]Agent is executing the code below:", align="left")
console.print(Syntax(code_action, lexer='python', background_color='default')) console.print(Syntax(code_action, lexer="python", background_color="default"))
console.rule("", align="left") console.rule("", align="left")
try: try:
@ -886,7 +1003,9 @@ class CodeAgent(ReactAgent):
if result is not None: if result is not None:
console.rule("Last output from code snippet:", align="left") console.rule("Last output from code snippet:", align="left")
console.print(str(result)) console.print(str(result))
observation += "Last output from code snippet:\n" + truncate_content(str(result)) observation += "Last output from code snippet:\n" + truncate_content(
str(result)
)
log_entry.observation = observation log_entry.observation = observation
except Exception as e: except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}" error_msg = f"Code execution failed due to the following error:\n{str(e)}"
@ -902,7 +1021,14 @@ class CodeAgent(ReactAgent):
class ManagedAgent: class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False): def __init__(
self,
agent,
name,
description,
additional_prompting=None,
provide_run_summary=False,
):
self.agent = agent self.agent = agent
self.name = name self.name = name
self.description = description self.description = description
@ -925,18 +1051,22 @@ Your final_answer WILL HAVE to contain these parts:
Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
<<additional_prompting>>""" {{additional_prompting}}"""
if self.additional_prompting: if self.additional_prompting:
full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip() full_task = full_task.replace(
"\n{{additional_prompting}}", self.additional_prompting
).strip()
else: else:
full_task = full_task.replace("\n<<additional_prompting>>", "").strip() full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
return full_task return full_task
def __call__(self, request, **kwargs): def __call__(self, request, **kwargs):
full_task = self.write_full_task(request) full_task = self.write_full_task(request)
output = self.agent.run(full_task, **kwargs) output = self.agent.run(full_task, **kwargs)
if self.provide_run_summary: if self.provide_run_summary:
answer = f"Here is the final answer from your managed agent '{self.name}':\n" answer = (
f"Here is the final answer from your managed agent '{self.name}':\n"
)
answer += str(output) answer += str(output)
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True): for message in self.agent.write_inner_memory_from_logs(summary_mode=True):

View File

@ -105,7 +105,9 @@ def get_remote_tools(logger, organization="huggingface-tools"):
tools = {} tools = {}
for space_info in spaces: for space_info in spaces:
repo_id = space_info.id repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space") resolved_config_file = hf_hub_download(
repo_id, TOOL_CONFIG_FILE, repo_type="space"
)
with open(resolved_config_file, encoding="utf-8") as reader: with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader) config = json.load(reader)
task = repo_id.split("/")[-1] task = repo_id.split("/")[-1]
@ -131,7 +133,9 @@ class PythonInterpreterTool(Tool):
if authorized_imports is None: if authorized_imports is None:
self.authorized_imports = list(set(LIST_SAFE_MODULES)) self.authorized_imports = list(set(LIST_SAFE_MODULES))
else: else:
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.authorized_imports = list(
set(LIST_SAFE_MODULES) | set(authorized_imports)
)
self.inputs = { self.inputs = {
"code": { "code": {
"type": "string", "type": "string",
@ -145,7 +149,11 @@ class PythonInterpreterTool(Tool):
def forward(self, code):
    """Run `code` through `evaluate_python_code` and return the result as a string.

    Args:
        code: Python source text to evaluate.

    Returns:
        `str()` of the value produced by `evaluate_python_code`.
    """
    # The evaluator gets the base Python tools plus this instance's authorized imports.
    evaluation_result = evaluate_python_code(
        code,
        static_tools=BASE_PYTHON_TOOLS,
        authorized_imports=self.authorized_imports,
    )
    return str(evaluation_result)
@ -153,16 +161,21 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool):
    """Tool whose `forward` hands the given answer back unchanged."""

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    # Single input, `answer`, declared with type "any".
    inputs = {
        "answer": {
            "type": "any",
            "description": "The final answer to the problem",
        }
    }
    output_type = "any"

    def forward(self, answer):
        """Return `answer` as-is."""
        return answer
class UserInputTool(Tool): class UserInputTool(Tool):
name = "user_input" name = "user_input"
description = "Asks for user's input on a specific question" description = "Asks for user's input on a specific question"
inputs = {"question": {"type": "string", "description": "The question to ask the user"}} inputs = {
"question": {"type": "string", "description": "The question to ask the user"}
}
output_type = "string" output_type = "string"
def forward(self, question): def forward(self, question):

View File

@ -18,6 +18,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import BaseAgent, AgentStep, ActionStep from .agents import BaseAgent, AgentStep, ActionStep
import gradio as gr import gradio as gr
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps""" """Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep): if isinstance(step_log, ActionStep):
@ -33,7 +34,9 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
content=str(content), content=str(content),
) )
if step_log.observation is not None: if step_log.observation is not None:
yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observation}\n```") yield gr.ChatMessage(
role="assistant", content=f"```\n{step_log.observation}\n```"
)
if step_log.error is not None: if step_log.error is not None:
yield gr.ChatMessage( yield gr.ChatMessage(
role="assistant", role="assistant",
@ -42,7 +45,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
) )
def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs): def stream_to_gradio(
agent,
task: str,
test_mode: bool = False,
reset_agent_memory: bool = False,
**kwargs,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
@ -52,7 +61,10 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
final_answer = step_log # Last log is the run's final_answer final_answer = step_log # Last log is the run's final_answer
if isinstance(final_answer, AgentText): if isinstance(final_answer, AgentText):
yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```") yield gr.ChatMessage(
role="assistant",
content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
)
elif isinstance(final_answer, AgentImage): elif isinstance(final_answer, AgentImage):
yield gr.ChatMessage( yield gr.ChatMessage(
role="assistant", role="assistant",
@ -67,8 +79,9 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo
yield gr.ChatMessage(role="assistant", content=str(final_answer)) yield gr.ChatMessage(role="assistant", content=str(final_answer))
class GradioUI(): class GradioUI:
"""A one-line interface to launch your agent in Gradio""" """A one-line interface to launch your agent in Gradio"""
def __init__(self, agent: BaseAgent): def __init__(self, agent: BaseAgent):
self.agent = agent self.agent = agent
@ -83,10 +96,17 @@ class GradioUI():
def run(self):
    """Build the Gradio chat interface for the agent and launch it."""
    robot_avatar = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"

    def stash_and_clear(s):
        # First output goes to the stored-message state, second empties the textbox.
        return s, ""

    with gr.Blocks() as demo:
        pending_message = gr.State([])
        conversation = gr.Chatbot(
            label="Agent",
            type="messages",
            avatar_images=(None, robot_avatar),
        )
        prompt_box = gr.Textbox(lines=1, label="Chat Message")

        # On submit: stash the text, then let the agent answer into the chatbot.
        submission = prompt_box.submit(
            stash_and_clear, [prompt_box], [pending_message, prompt_box]
        )
        submission.then(
            self.interact_with_agent, [pending_message, conversation], [conversation]
        )

    # NOTE(review): launch() is placed after the `with` block, following Gradio's
    # documented Blocks pattern — confirm this matches the intended nesting.
    demo.launch()

View File

@ -39,7 +39,9 @@ class MessageRole(str, Enum):
return [r.value for r in cls] return [r.value for r in cls]
def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}): def get_clean_message_list(
message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
):
""" """
Subsequent messages with the same role will be concatenated to a single message. Subsequent messages with the same role will be concatenated to a single message.
@ -54,12 +56,17 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
role = message["role"] role = message["role"]
if role not in MessageRole.roles(): if role not in MessageRole.roles():
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.") raise ValueError(
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
if role in role_conversions: if role in role_conversions:
message["role"] = role_conversions[role] message["role"] = role_conversions[role]
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]: if (
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
final_message_list[-1]["content"] += "\n=======\n" + message["content"] final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else: else:
final_message_list.append(message) final_message_list.append(message)
@ -81,8 +88,12 @@ class HfEngine:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
except Exception as e: except Exception as e:
logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.") logger.warning(
self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct") f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
)
self.tokenizer = AutoTokenizer.from_pretrained(
"HuggingFaceTB/SmolLM2-1.7B-Instruct"
)
def get_token_counts(self): def get_token_counts(self):
return { return {
@ -91,12 +102,18 @@ class HfEngine:
} }
def generate(
    self,
    messages: List[Dict[str, str]],
    stop_sequences: Optional[List[str]] = None,
    grammar: Optional[str] = None,
):
    """Placeholder for the engine-specific generation step.

    Args:
        messages: Chat messages to send to the model.
        stop_sequences: Optional strings passed along as stop sequences.
        grammar: Optional grammar specification.

    Raises:
        NotImplementedError: Always; this base implementation does nothing else.
    """
    raise NotImplementedError
def __call__( def __call__(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
) -> str: ) -> str:
"""Process the input messages and return the model's response. """Process the input messages and return the model's response.
@ -127,11 +144,15 @@ class HfEngine:
``` ```
""" """
if not isinstance(messages, List): if not isinstance(messages, List):
raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.") raise ValueError(
"Messages should be a list of dictionaries with 'role' and 'content' keys."
)
if stop_sequences is None: if stop_sequences is None:
stop_sequences = [] stop_sequences = []
response = self.generate(messages, stop_sequences, grammar) response = self.generate(messages, stop_sequences, grammar)
self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True)) self.last_input_token_count = len(
self.tokenizer.apply_chat_template(messages, tokenize=True)
)
self.last_output_token_count = len(self.tokenizer.encode(response)) self.last_output_token_count = len(self.tokenizer.encode(response))
# Remove stop sequences from LLM output # Remove stop sequences from LLM output
@ -175,18 +196,28 @@ class HfApiEngine(HfEngine):
self.max_tokens = max_tokens self.max_tokens = max_tokens
def generate(
    self,
    messages: List[Dict[str, str]],
    stop_sequences: Optional[List[str]] = None,
    grammar: Optional[str] = None,
) -> str:
    """Send `messages` to the Hugging Face Inference API and return the reply text.

    Args:
        messages: Chat messages; they are first passed through
            `get_clean_message_list` with `llama_role_conversions`.
        stop_sequences: Forwarded to the client as `stop`.
        grammar: When not None, forwarded to the client as `response_format`.

    Returns:
        The `content` of the first choice's message in the completion.
    """
    cleaned_messages = get_clean_message_list(
        messages, role_conversions=llama_role_conversions
    )

    request_options = {"stop": stop_sequences, "max_tokens": self.max_tokens}
    # `response_format` is only sent when a grammar was actually supplied.
    if grammar is not None:
        request_options["response_format"] = grammar

    completion = self.client.chat_completion(cleaned_messages, **request_options)
    return completion.choices[0].message.content
@ -207,7 +238,9 @@ class TransformersEngine(HfEngine):
max_length: int = 1500, max_length: int = 1500,
) -> str: ) -> str:
# Get clean message list # Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) messages = get_clean_message_list(
messages, role_conversions=llama_role_conversions
)
# Get LLM output # Get LLM output
if stop_sequences is not None and len(stop_sequences) > 0: if stop_sequences is not None and len(stop_sequences) > 0:

View File

@ -17,12 +17,14 @@
from .utils import console from .utils import console
class Monitor: class Monitor:
def __init__(self, tracked_llm_engine): def __init__(self, tracked_llm_engine):
self.step_durations = [] self.step_durations = []
self.tracked_llm_engine = tracked_llm_engine self.tracked_llm_engine = tracked_llm_engine
if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found": if (
getattr(self.tracked_llm_engine, "last_input_token_count", "Not found")
!= "Not found"
):
self.total_input_token_count = 0 self.total_input_token_count = 0
self.total_output_token_count = 0 self.total_output_token_count = 0
@ -33,103 +35,11 @@ class Monitor:
console.print(f"- Time taken: {step_duration:.2f} seconds") console.print(f"- Time taken: {step_duration:.2f} seconds")
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count self.total_input_token_count += (
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count self.tracked_llm_engine.last_input_token_count
)
self.total_output_token_count += (
self.tracked_llm_engine.last_output_token_count
)
console.print(f"- Input tokens: {self.total_input_token_count:,}") console.print(f"- Input tokens: {self.total_input_token_count:,}")
console.print(f"- Output tokens: {self.total_output_token_count:,}") console.print(f"- Output tokens: {self.total_output_token_count:,}")
from typing import Optional, Union, List, Any
import httpx
import logging
import os
from langfuse.client import Langfuse, StatefulTraceClient, StatefulSpanClient, StateType
class BaseTracker:
    """Tracker base class whose constructor and `call` hook both do nothing."""

    def __init__(self):
        # No state to set up.
        return None

    @classmethod
    def call(cls, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None
class LangfuseTracker(BaseTracker):
    """Tracker that records calls in Langfuse.

    The tracker either wraps a caller-supplied stateful trace/span client, or
    creates its own `Langfuse` client from explicit arguments and the
    `LANGFUSE_PUBLIC_KEY`, `LANGFUSE_SECRET_KEY` and `LANGFUSE_HOST`
    environment variables.
    """

    log = logging.getLogger("langfuse")

    def __init__(
        self,
        *,
        public_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        host: Optional[str] = None,
        debug: bool = False,
        stateful_client: Optional[
            Union[StatefulTraceClient, StatefulSpanClient]
        ] = None,
        update_stateful_client: bool = False,
        version: Optional[str] = None,
        session_id: Optional[str] = None,
        user_id: Optional[str] = None,
        trace_name: Optional[str] = None,
        release: Optional[str] = None,
        metadata: Optional[Any] = None,
        tags: Optional[List[str]] = None,
        threads: Optional[int] = None,
        flush_at: Optional[int] = None,
        flush_interval: Optional[int] = None,
        max_retries: Optional[int] = None,
        timeout: Optional[int] = None,
        enabled: Optional[bool] = None,
        httpx_client: Optional[httpx.Client] = None,
        sdk_integration: str = "default",
    ) -> None:
        """Store the trace attributes and set up the Langfuse connection.

        Explicit `public_key`, `secret_key` and `host` arguments take priority
        over the corresponding environment variables; `host` falls back to
        "https://cloud.langfuse.com". When `stateful_client` is a trace or span
        client it is reused and no new `Langfuse` client is created. Optional
        client settings are only forwarded to `Langfuse` when they are not None.
        """
        super().__init__()
        self.version = version
        self.session_id = session_id
        self.user_id = user_id
        self.trace_name = trace_name
        self.release = release
        self.metadata = metadata
        self.tags = tags

        self.root_span = None
        self.update_stateful_client = update_stateful_client
        self.langfuse = None

        # Reuse a caller-supplied trace: nothing else to build.
        if stateful_client and isinstance(stateful_client, StatefulTraceClient):
            self.trace = stateful_client
            self._task_manager = stateful_client.task_manager
            return

        # Reuse a caller-supplied span and derive a trace client from its trace id.
        if stateful_client and isinstance(stateful_client, StatefulSpanClient):
            self.root_span = stateful_client
            self.trace = StatefulTraceClient(
                stateful_client.client,
                stateful_client.trace_id,
                StateType.TRACE,
                stateful_client.trace_id,
                stateful_client.task_manager,
            )
            self._task_manager = stateful_client.task_manager
            return

        client_kwargs = {
            "public_key": public_key or os.environ.get("LANGFUSE_PUBLIC_KEY"),
            "secret_key": secret_key or os.environ.get("LANGFUSE_SECRET_KEY"),
            "host": host
            or os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com"),
            "debug": debug,
        }

        # Forward only the optional settings the caller actually provided.
        optional_settings = {
            "release": release,
            "threads": threads,
            "flush_at": flush_at,
            "flush_interval": flush_interval,
            "max_retries": max_retries,
            "timeout": timeout,
            "enabled": enabled,
            "httpx_client": httpx_client,
        }
        client_kwargs.update(
            {key: value for key, value in optional_settings.items() if value is not None}
        )
        client_kwargs["sdk_integration"] = sdk_integration

        self.langfuse = Langfuse(**client_kwargs)
        self.trace: Optional[StatefulTraceClient] = None
        self._task_manager = self.langfuse.task_manager

    def call(self, i, o, name=None, **kwargs):
        """Record one call with input `i`, output `o` and optional `name`.

        Extra keyword arguments are sent as the record's metadata. With an
        owned `Langfuse` client a trace is created. When the tracker wraps a
        caller-supplied stateful client there is no `Langfuse` instance, so the
        call is logged as an event on the supplied span (or on the trace).
        """
        if self.langfuse is not None:
            self.langfuse.trace(input=i, output=o, name=name, metadata=kwargs)
            return

        # self.langfuse stays None on both stateful-client paths of __init__;
        # calling self.langfuse.trace there would raise AttributeError.
        # NOTE(review): logging as an event on the supplied client is an assumed
        # equivalent — confirm against the Langfuse SDK version in use.
        parent = self.root_span if self.root_span is not None else self.trace
        parent.event(name=name, input=i, output=o, metadata=kwargs)

View File

@ -42,7 +42,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
return prompt_or_repo_id return prompt_or_repo_id
prompt_file = cached_file( prompt_file = cached_file(
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name} prompt_or_repo_id,
PROMPT_FILES[mode],
repo_type="dataset",
user_agent={"agent": agent_name},
) )
with open(prompt_file, "r", encoding="utf-8") as f: with open(prompt_file, "r", encoding="utf-8") as f:
return f.read() return f.read()

View File

@ -26,6 +26,7 @@ import pandas as pd
from .utils import truncate_content from .utils import truncate_content
class InterpreterError(ValueError): class InterpreterError(ValueError):
""" """
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
@ -38,7 +39,8 @@ class InterpreterError(ValueError):
# Every builtin that is a class deriving from BaseException, keyed by its name.
ERRORS = dict(
    (builtin_name, builtin_obj)
    for builtin_name, builtin_obj in (
        (attr, getattr(builtins, attr)) for attr in dir(builtins)
    )
    if isinstance(builtin_obj, type) and issubclass(builtin_obj, BaseException)
)
@ -92,7 +94,9 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools):
elif isinstance(expression.op, ast.Invert): elif isinstance(expression.op, ast.Invert):
return ~operand return ~operand
else: else:
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") raise InterpreterError(
f"Unary operation {expression.op.__class__.__name__} is not supported."
)
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools): def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
@ -102,7 +106,9 @@ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
new_state = state.copy() new_state = state.copy()
for arg, value in zip(args, values): for arg, value in zip(args, values):
new_state[arg] = value new_state[arg] = value
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools) return evaluate_ast(
lambda_expression.body, new_state, static_tools, custom_tools
)
return lambda_func return lambda_func
@ -120,7 +126,9 @@ def evaluate_while(while_loop, state, static_tools, custom_tools):
break break
iterations += 1 iterations += 1
if iterations > max_iterations: if iterations > max_iterations:
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded") raise InterpreterError(
f"Maximum number of {max_iterations} iterations in While loop exceeded"
)
return None return None
@ -128,7 +136,10 @@ def create_function(func_def, state, static_tools, custom_tools):
def new_func(*args, **kwargs): def new_func(*args, **kwargs):
func_state = state.copy() func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args] arg_names = [arg.arg for arg in func_def.args.args]
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults] default_values = [
evaluate_ast(d, state, static_tools, custom_tools)
for d in func_def.args.defaults
]
# Apply default values # Apply default values
defaults = dict(zip(arg_names[-len(default_values) :], default_values)) defaults = dict(zip(arg_names[-len(default_values) :], default_values))
@ -180,26 +191,39 @@ def create_class(class_name, class_bases, class_body):
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """Build a callable from a function definition node and register it.

    The callable produced by `create_function` is stored in `custom_tools`
    under the function's name and then returned.
    """
    function_name = func_def.name
    custom_tools[function_name] = create_function(
        func_def, state, static_tools, custom_tools
    )
    # Read the entry back from the mapping, as the value stored is what callers get.
    return custom_tools[function_name]
def evaluate_class_def(class_def, state, static_tools, custom_tools): def evaluate_class_def(class_def, state, static_tools, custom_tools):
class_name = class_def.name class_name = class_def.name
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases] bases = [
evaluate_ast(base, state, static_tools, custom_tools)
for base in class_def.bases
]
class_dict = {} class_dict = {}
for stmt in class_def.body: for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef): if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools) class_dict[stmt.name] = evaluate_function_def(
stmt, state, static_tools, custom_tools
)
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
for target in stmt.targets: for target in stmt.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools) class_dict[target.id] = evaluate_ast(
stmt.value, state, static_tools, custom_tools
)
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools) class_dict[target.attr] = evaluate_ast(
stmt.value, state, static_tools, custom_tools
)
else: else:
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") raise InterpreterError(
f"Unsupported statement in class body: {stmt.__class__.__name__}"
)
new_class = type(class_name, tuple(bases), class_dict) new_class = type(class_name, tuple(bases), class_dict)
state[class_name] = new_class state[class_name] = new_class
@ -223,7 +247,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
elif isinstance(target, ast.List): elif isinstance(target, ast.List):
return [get_current_value(elt) for elt in target.elts] return [get_current_value(elt) for elt in target.elts]
else: else:
raise InterpreterError("AugAssign not supported for {type(target)} targets.") raise InterpreterError(
"AugAssign not supported for {type(target)} targets."
)
current_value = get_current_value(expression.target) current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools) value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
@ -232,7 +258,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
if isinstance(expression.op, ast.Add): if isinstance(expression.op, ast.Add):
if isinstance(current_value, list): if isinstance(current_value, list):
if not isinstance(value_to_add, list): if not isinstance(value_to_add, list):
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") raise InterpreterError(
f"Cannot add non-list value {value_to_add} to a list."
)
updated_value = current_value + value_to_add updated_value = current_value + value_to_add
else: else:
updated_value = current_value + value_to_add updated_value = current_value + value_to_add
@ -259,7 +287,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools):
elif isinstance(expression.op, ast.RShift): elif isinstance(expression.op, ast.RShift):
updated_value = current_value >> value_to_add updated_value = current_value >> value_to_add
else: else:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") raise InterpreterError(
f"Operation {type(expression.op).__name__} is not supported."
)
# Update the state # Update the state
set_value(expression.target, updated_value, state, static_tools, custom_tools) set_value(expression.target, updated_value, state, static_tools, custom_tools)
@ -311,7 +341,9 @@ def evaluate_binop(binop, state, static_tools, custom_tools):
elif isinstance(binop.op, ast.RShift): elif isinstance(binop.op, ast.RShift):
return left_val >> right_val return left_val >> right_val
else: else:
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") raise NotImplementedError(
f"Binary operation {type(binop.op).__name__} is not implemented."
)
def evaluate_assign(assign, state, static_tools, custom_tools): def evaluate_assign(assign, state, static_tools, custom_tools):
@ -321,7 +353,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
set_value(target, result, state, static_tools, custom_tools) set_value(target, result, state, static_tools, custom_tools)
else: else:
if len(assign.targets) != len(result): if len(assign.targets) != len(result):
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") raise InterpreterError(
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
)
expanded_values = [] expanded_values = []
for tgt in assign.targets: for tgt in assign.targets:
if isinstance(tgt, ast.Starred): if isinstance(tgt, ast.Starred):
@ -336,7 +370,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
def set_value(target, value, state, static_tools, custom_tools): def set_value(target, value, state, static_tools, custom_tools):
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in static_tools: if target.id in static_tools:
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") raise InterpreterError(
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
)
state[target.id] = value state[target.id] = value
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple): if not isinstance(value, tuple):
@ -399,9 +435,14 @@ def evaluate_call(call, state, static_tools, custom_tools):
else: else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools)) args.append(evaluate_ast(arg, state, static_tools, custom_tools))
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords} kwargs = {
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools)
for keyword in call.keywords
}
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes if (
isinstance(func, type) and len(func.__module__.split(".")) > 1
): # Check for user-defined classes
# Instantiate the class using its constructor # Instantiate the class using its constructor
obj = func.__new__(func) # Create a new instance of the class obj = func.__new__(func) # Create a new instance of the class
if hasattr(obj, "__init__"): # Check if the class has an __init__ method if hasattr(obj, "__init__"): # Check if the class has an __init__ method
@ -441,7 +482,9 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
value = evaluate_ast(subscript.value, state, static_tools, custom_tools) value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(value, str) and isinstance(index, str): if isinstance(value, str) and isinstance(index, str):
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible") raise InterpreterError(
"You're trying to subscript a string with a string index, which is impossible"
)
if isinstance(value, pd.core.indexing._LocIndexer): if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj parent_object = value.obj
return parent_object.loc[index] return parent_object.loc[index]
@ -453,11 +496,15 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
return value[index] return value[index]
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") raise InterpreterError(
f"Index {index} out of bounds for list of length {len(value)}"
)
return value[int(index)] return value[int(index)]
elif isinstance(value, str): elif isinstance(value, str):
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") raise InterpreterError(
f"Index {index} out of bounds for string of length {len(value)}"
)
return value[index] return value[index]
elif index in value: elif index in value:
return value[index] return value[index]
@ -483,7 +530,10 @@ def evaluate_name(name, state, static_tools, custom_tools):
def evaluate_condition(condition, state, static_tools, custom_tools): def evaluate_condition(condition, state, static_tools, custom_tools):
left = evaluate_ast(condition.left, state, static_tools, custom_tools) left = evaluate_ast(condition.left, state, static_tools, custom_tools)
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators] comparators = [
evaluate_ast(c, state, static_tools, custom_tools)
for c in condition.comparators
]
ops = [type(op) for op in condition.ops] ops = [type(op) for op in condition.ops]
result = True result = True
@ -561,9 +611,13 @@ def evaluate_for(for_loop, state, static_tools, custom_tools):
def evaluate_listcomp(listcomp, state, static_tools, custom_tools): def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
def inner_evaluate(generators, index, current_state): def inner_evaluate(generators, index, current_state):
if index >= len(generators): if index >= len(generators):
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)] return [
evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)
]
generator = generators[index] generator = generators[index]
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools) iter_value = evaluate_ast(
generator.iter, current_state, static_tools, custom_tools
)
result = [] result = []
for value in iter_value: for value in iter_value:
new_state = current_state.copy() new_state = current_state.copy()
@ -572,7 +626,10 @@ def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
new_state[elem.id] = value[idx] new_state[elem.id] = value[idx]
else: else:
new_state[generator.target.id] = value new_state[generator.target.id] = value
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs): if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
for if_clause in generator.ifs
):
result.extend(inner_evaluate(generators, index + 1, new_state)) result.extend(inner_evaluate(generators, index + 1, new_state))
return result return result
@ -586,7 +643,9 @@ def evaluate_try(try_node, state, static_tools, custom_tools):
except Exception as e: except Exception as e:
matched = False matched = False
for handler in try_node.handlers: for handler in try_node.handlers:
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)): if handler.type is None or isinstance(
e, evaluate_ast(handler.type, state, static_tools, custom_tools)
):
matched = True matched = True
if handler.name: if handler.name:
state[handler.name] = e state[handler.name] = e
@ -638,7 +697,9 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools):
def evaluate_with(with_node, state, static_tools, custom_tools): def evaluate_with(with_node, state, static_tools, custom_tools):
contexts = [] contexts = []
for item in with_node.items: for item in with_node.items:
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools) context_expr = evaluate_ast(
item.context_expr, state, static_tools, custom_tools
)
if item.optional_vars: if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__() state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id]) contexts.append(state[item.optional_vars.id])
@ -661,7 +722,9 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
def import_modules(expression, state, authorized_imports): def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name): def check_module_authorized(module_name):
module_path = module_name.split(".") module_path = module_name.split(".")
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
return any(subpath in authorized_imports for subpath in module_subpaths) return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import): if isinstance(expression, ast.Import):
@ -676,7 +739,9 @@ def import_modules(expression, state, authorized_imports):
return None return None
elif isinstance(expression, ast.ImportFrom): elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module): if check_module_authorized(expression.module):
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
for alias in expression.names: for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name) state[alias.asname or alias.name] = getattr(module, alias.name)
else: else:
@ -691,9 +756,14 @@ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
for value in iter_value: for value in iter_value:
new_state = state.copy() new_state = state.copy()
set_value(gen.target, value, new_state, static_tools, custom_tools) set_value(gen.target, value, new_state, static_tools, custom_tools)
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs): if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools)
for if_clause in gen.ifs
):
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools) key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools) val = evaluate_ast(
dictcomp.value, new_state, static_tools, custom_tools
)
result[key] = val result[key] = val
return result return result
@ -744,7 +814,10 @@ def evaluate_ast(
# Constant -> just return the value # Constant -> just return the value
return expression.value return expression.value
elif isinstance(expression, ast.Tuple): elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts) return tuple(
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools) return evaluate_listcomp(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.UnaryOp): elif isinstance(expression, ast.UnaryOp):
@ -770,8 +843,13 @@ def evaluate_ast(
return evaluate_function_def(expression, state, static_tools, custom_tools) return evaluate_function_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Dict): elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values # Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys] keys = [
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values] evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys
]
values = [
evaluate_ast(v, state, static_tools, custom_tools)
for v in expression.values
]
return dict(zip(keys, values)) return dict(zip(keys, values))
elif isinstance(expression, ast.Expr): elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content # Expression -> evaluate the content
@ -788,10 +866,18 @@ def evaluate_ast(
elif hasattr(ast, "Index") and isinstance(expression, ast.Index): elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, static_tools, custom_tools) return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.JoinedStr): elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values]) return "".join(
[
str(evaluate_ast(v, state, static_tools, custom_tools))
for v in expression.values
]
)
elif isinstance(expression, ast.List): elif isinstance(expression, ast.List):
# List -> evaluate all elements # List -> evaluate all elements
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts] return [
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
]
elif isinstance(expression, ast.Name): elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state # Name -> pick up the value in the state
return evaluate_name(expression, state, static_tools, custom_tools) return evaluate_name(expression, state, static_tools, custom_tools)
@ -815,7 +901,9 @@ def evaluate_ast(
evaluate_ast(expression.upper, state, static_tools, custom_tools) evaluate_ast(expression.upper, state, static_tools, custom_tools)
if expression.upper is not None if expression.upper is not None
else None, else None,
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None, evaluate_ast(expression.step, state, static_tools, custom_tools)
if expression.step is not None
else None,
) )
elif isinstance(expression, ast.DictComp): elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, state, static_tools, custom_tools) return evaluate_dictcomp(expression, state, static_tools, custom_tools)
@ -834,17 +922,24 @@ def evaluate_ast(
elif isinstance(expression, ast.With): elif isinstance(expression, ast.With):
return evaluate_with(expression, state, static_tools, custom_tools) return evaluate_with(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Set): elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts} return {
evaluate_ast(elt, state, static_tools, custom_tools)
for elt in expression.elts
}
elif isinstance(expression, ast.Return): elif isinstance(expression, ast.Return):
raise ReturnException( raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None evaluate_ast(expression.value, state, static_tools, custom_tools)
if expression.value
else None
) )
else: else:
# For now we refuse anything else. Let's add things as we need them. # For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.") raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str: def truncate_print_outputs(
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
) -> str:
if len(print_outputs) < max_len_outputs: if len(print_outputs) < max_len_outputs:
return print_outputs return print_outputs
else: else:
@ -895,8 +990,12 @@ def evaluate_python_code(
OPERATIONS_COUNT = 0 OPERATIONS_COUNT = 0
try: try:
for node in expression.body: for node in expression.body:
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) result = evaluate_ast(
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) node, state, static_tools, custom_tools, authorized_imports
)
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
)
return result return result
except InterpreterError as e: except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)

View File

@ -26,7 +26,9 @@ class DuckDuckGoSearchTool(Tool):
name = "web_search" name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'.""" Each result has keys 'title', 'href' and 'body'."""
inputs = {"query": {"type": "string", "description": "The search query to perform."}} inputs = {
"query": {"type": "string", "description": "The search query to perform."}
}
output_type = "any" output_type = "any"
def forward(self, query: str) -> str: def forward(self, query: str) -> str:

View File

@ -26,7 +26,13 @@ from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder from huggingface_hub import (
create_repo,
get_collection,
hf_hub_download,
metadata_update,
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version from packaging import version
@ -73,7 +79,9 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model" return "model"
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.") raise EnvironmentError(
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
)
except Exception: except Exception:
return "model" return "model"
except Exception: except Exception:
@ -158,7 +166,15 @@ class Tool:
"inputs": dict, "inputs": dict,
"output_type": str, "output_type": str,
} }
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"] authorized_types = [
"string",
"integer",
"number",
"image",
"audio",
"any",
"boolean",
]
for attr, expected_type in required_attributes.items(): for attr, expected_type in required_attributes.items():
attr_value = getattr(self, attr, None) attr_value = getattr(self, attr, None)
@ -169,7 +185,9 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
) )
for input_name, input_content in self.inputs.items(): for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." assert isinstance(
input_content, dict
), f"Input '{input_name}' should be a dictionary."
assert ( assert (
"type" in input_content and "description" in input_content "type" in input_content and "description" in input_content
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
@ -251,7 +269,11 @@ class Tool:
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f: with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__)) f.write(
APP_FILE_TEMPLATE.format(
module_name=last_module, class_name=self.__class__.__name__
)
)
# Save requirements file # Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt") requirements_file = os.path.join(output_dir, "requirements.txt")
@ -343,7 +365,9 @@ class Tool:
custom_tool = config custom_tool = config
tool_class = custom_tool["tool_class"] tool_class = custom_tool["tool_class"]
tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs) tool_class = get_class_from_dynamic_module(
tool_class, repo_id, token=token, **hub_kwargs
)
if len(tool_class.name) == 0: if len(tool_class.name) == 0:
tool_class.name = custom_tool["name"] tool_class.name = custom_tool["name"]
@ -420,7 +444,9 @@ class Tool:
with tempfile.TemporaryDirectory() as work_dir: with tempfile.TemporaryDirectory() as work_dir:
# Save all files. # Save all files.
self.save(work_dir) self.save(work_dir)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
return upload_folder( return upload_folder(
repo_id=repo_id, repo_id=repo_id,
commit_message=commit_message, commit_message=commit_message,
@ -432,7 +458,11 @@ class Tool:
@staticmethod @staticmethod
def from_space( def from_space(
space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None space_id: str,
name: str,
description: str,
api_name: Optional[str] = None,
token: Optional[str] = None,
): ):
""" """
Creates a [`Tool`] from a Space given its id on the Hub. Creates a [`Tool`] from a Space given its id on the Hub.
@ -485,7 +515,9 @@ class Tool:
self.client = Client(space_id, hf_token=token) self.client = Client(space_id, hf_token=token)
self.name = name self.name = name
self.description = description self.description = description
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] space_description = self.client.view_api(
return_format="dict", print_info=False
)["named_endpoints"]
# If api_name is not defined, take the first of the available APIs for this space # If api_name is not defined, take the first of the available APIs for this space
if api_name is None: if api_name is None:
@ -498,7 +530,9 @@ class Tool:
try: try:
space_description_api = space_description[api_name] space_description_api = space_description[api_name]
except KeyError: except KeyError:
raise KeyError(f"Could not find specified {api_name=} among available api names.") raise KeyError(
f"Could not find specified {api_name=} among available api names."
)
self.inputs = {} self.inputs = {}
for parameter in space_description_api["parameters"]: for parameter in space_description_api["parameters"]:
@ -523,9 +557,11 @@ class Tool:
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name) arg.save(temp_file.name)
arg = temp_file.name arg = temp_file.name
if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like( if (
arg isinstance(arg, (str, Path))
): and Path(arg).exists()
and Path(arg).is_file()
) or is_http_url_like(arg):
arg = handle_file(arg) arg = handle_file(arg)
return arg return arg
@ -544,7 +580,9 @@ class Tool:
] # Sometime the space also returns the generation seed, in which case the result is at index 0 ] # Sometime the space also returns the generation seed, in which case the result is at index 0
return output return output
return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token) return SpaceToolWrapper(
space_id, name, description, api_name=api_name, token=token
)
@staticmethod @staticmethod
def from_gradio(gradio_tool): def from_gradio(gradio_tool):
@ -561,7 +599,8 @@ class Tool:
self._gradio_tool = _gradio_tool self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = { self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
for key, value in func_args
} }
self.forward = self._gradio_tool.run self.forward = self._gradio_tool.run
@ -603,7 +642,9 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
""" """
def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str: def get_tool_description_with_args(
tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) -> str:
compiled_template = compile_jinja_template(description_template) compiled_template = compile_jinja_template(description_template)
rendered = compiled_template.render( rendered = compiled_template.render(
tool=tool, tool=tool,
@ -621,7 +662,10 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.") raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"): if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.") raise ImportError(
"template requires jinja2>=3.1.0 to be installed. Your version is "
f"{jinja2.__version__}."
)
def raise_exception(message): def raise_exception(message):
raise TemplateError(message) raise TemplateError(message)
@ -697,7 +741,9 @@ class PipelineTool(Tool):
if model is None: if model is None:
if self.default_checkpoint is None: if self.default_checkpoint is None:
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.") raise ValueError(
"This tool does not implement a default checkpoint, you need to pass one."
)
model = self.default_checkpoint model = self.default_checkpoint
if pre_processor is None: if pre_processor is None:
pre_processor = model pre_processor = model
@ -720,15 +766,21 @@ class PipelineTool(Tool):
Instantiates the `pre_processor`, `model` and `post_processor` if necessary. Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
""" """
if isinstance(self.pre_processor, str): if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs) self.pre_processor = self.pre_processor_class.from_pretrained(
self.pre_processor, **self.hub_kwargs
)
if isinstance(self.model, str): if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs) self.model = self.model_class.from_pretrained(
self.model, **self.model_kwargs, **self.hub_kwargs
)
if self.post_processor is None: if self.post_processor is None:
self.post_processor = self.pre_processor self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str): elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs) self.post_processor = self.post_processor_class.from_pretrained(
self.post_processor, **self.hub_kwargs
)
if self.device is None: if self.device is None:
if self.device_map is not None: if self.device_map is not None:
@ -768,8 +820,12 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs) encoded_inputs = self.encode(*args, **kwargs)
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)} tensor_inputs = {
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)} k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
}
non_tensor_inputs = {
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
}
encoded_inputs = send_to_device(tensor_inputs, self.device) encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
@ -790,7 +846,9 @@ def launch_gradio_demo(tool_class: Tool):
try: try:
import gradio as gr import gradio as gr
except ImportError: except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.") raise ImportError(
"Gradio should be installed in order to launch a gradio demo."
)
tool = tool_class() tool = tool_class()
@ -807,11 +865,15 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs = [] gradio_inputs = []
for input_name, input_details in tool_class.inputs.items(): for input_name, input_details in tool_class.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
new_component = input_gradio_component_class(label=input_name) new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component) gradio_inputs.append(new_component)
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type] output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
tool_class.output_type
]
gradio_output = output_gradio_componentclass(label=input_name) gradio_output = output_gradio_componentclass(label=input_name)
gr.Interface( gr.Interface(
@ -875,7 +937,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the " f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
f"code that you have checked." f"code that you have checked."
) )
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs) return Tool.from_hub(
task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
)
def add_description(description): def add_description(description):
@ -935,7 +999,9 @@ class EndpointClient:
payload["parameters"] = params payload["parameters"] = params
# Make API call # Make API call
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data) response = get_session().post(
self.endpoint_url, headers=self.headers, json=payload, data=data
)
# By default, parse the response for the user. # By default, parse the response for the user.
if output_image: if output_image:
@ -972,7 +1038,9 @@ class ToolCollection:
def __init__(self, collection_slug: str, token: Optional[str] = None): def __init__(self, collection_slug: str, token: Optional[str] = None):
self._collection = get_collection(collection_slug, token=token) self._collection = get_collection(collection_slug, token=token)
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"} self._hub_repo_ids = {
item.item_id for item in self._collection.items if item.item_type == "space"
}
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
@ -986,7 +1054,9 @@ def tool(tool_function: Callable) -> Tool:
""" """
parameters = get_json_schema(tool_function)["function"] parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters: if "return" not in parameters:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
class_name = f"{parameters['name'].capitalize()}Tool" class_name = f"{parameters['name'].capitalize()}Tool"
class SpecificTool(Tool): class SpecificTool(Tool):
@ -1000,9 +1070,9 @@ def tool(tool_function: Callable) -> Tool:
return tool_function(*args, **kwargs) return tool_function(*args, **kwargs)
original_signature = inspect.signature(tool_function) original_signature = inspect.signature(tool_function)
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list( new_parameters = [
original_signature.parameters.values() inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
) ] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters) new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature SpecificTool.forward.__signature__ = new_signature
@ -1049,7 +1119,10 @@ class Toolbox:
The template to use to describe the tools. If not provided, the default template will be used. The template to use to describe the tools. If not provided, the default template will be used.
""" """
return "\n".join( return "\n".join(
[get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()] [
get_tool_description_with_args(tool, tool_description_template)
for tool in self._tools.values()
]
) )
def add_tool(self, tool: Tool): def add_tool(self, tool: Tool):

View File

@ -16,32 +16,29 @@
# limitations under the License. # limitations under the License.
import json import json
import re import re
from typing import Tuple, Dict from typing import Tuple, Dict, Union
from transformers.utils.import_utils import _is_package_available from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments") _pygments_available = _is_package_available("pygments")
def is_pygments_available(): def is_pygments_available():
return _pygments_available return _pygments_available
from rich.console import Console from rich.console import Console
console = Console() console = Console()
LENGTH_TRUNCATE_REPORTS = 10000
def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS):
if len(content) < max_length:
return content
else:
return content[:max_length//2] + "\n..._(Content was truncated because too long)_...\n---" + content[-max_length//2:]
def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:
first_accolade_index = json_blob.find("{") first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'") json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
'\\"', "'"
)
json_data = json.loads(json_blob, strict=False) json_data = json.loads(json_blob, strict=False)
return json_data return json_data
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
@ -63,7 +60,12 @@ def parse_code_blob(code_blob: str) -> str:
try: try:
pattern = r"```(?:py|python)?\n(.*?)\n```" pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL) match = re.search(pattern, code_blob, re.DOTALL)
if match is None:
raise ValueError(
f"No match ground for regex pattern {pattern} in {code_blob=}."
)
return match.group(1).strip() return match.group(1).strip()
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f""" f"""
@ -77,7 +79,7 @@ Code:
) )
def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
json_blob = json_blob.replace("```json", "").replace("```", "") json_blob = json_blob.replace("```json", "").replace("```", "")
tool_call = parse_json_blob(json_blob) tool_call = parse_json_blob(json_blob)
if "action" in tool_call and "action_input" in tool_call: if "action" in tool_call and "action_input" in tool_call:
@ -85,7 +87,25 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
elif "action" in tool_call: elif "action" in tool_call:
return tool_call["action"], None return tool_call["action"], None
else: else:
missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call] missing_keys = [
key for key in ["action", "action_input"] if key not in tool_call
]
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}" error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
console.print(f"[bold red]{error_msg}[/bold red]") console.print(f"[bold red]{error_msg}[/bold red]")
raise ValueError(error_msg) raise ValueError(error_msg)
# Default maximum number of characters of original content kept by truncate_content.
MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(
    content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
    """Shorten `content` by keeping its head and tail around a truncation notice.

    Args:
        content: The text to shorten.
        max_length: Maximum number of characters of `content` to keep. Content
            that is at most this long is returned unchanged.

    Returns:
        `content` itself when it fits, otherwise the first `max_length // 2`
        characters, a notice mentioning `max_length`, and the last
        `max_length - max_length // 2` characters.
    """
    if len(content) <= max_length:
        return content

    # Slice sizes derive from the `max_length` argument, not from the module
    # default, so a caller-supplied limit is actually honoured.
    limit = max(max_length, 0)
    head_length = limit // 2
    tail_length = limit - head_length
    # A zero-length tail must be spelled out: `content[-0:]` is the whole string.
    tail = content[-tail_length:] if tail_length else ""
    return (
        content[:head_length]
        + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
        + tail
    )

2748
poetry.lock generated

File diff suppressed because it is too large Load Diff

271
requirements.txt Normal file
View File

@ -0,0 +1,271 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
aiofiles==23.2.1
# via gradio
annotated-types==0.7.0
# via pydantic
anyio==4.7.0
# via
# gradio
# httpx
# starlette
appnope==0.1.4
# via ipykernel
asttokens==3.0.0
# via stack-data
beautifulsoup4==4.12.3
# via markdownify
certifi==2024.8.30
# via
# httpcore
# httpx
# requests
charset-normalizer==3.4.0
# via requests
click==8.1.7
# via
# duckduckgo-search
# typer
# uvicorn
comm==0.2.2
# via ipykernel
debugpy==1.8.10
# via ipykernel
decorator==5.1.1
# via ipython
diskcache==5.6.3
# via llama-cpp-python
duckduckgo-search==6.4.1
# via -r requirements.in
executing==2.1.0
# via stack-data
fastapi==0.115.6
# via gradio
ffmpy==0.4.0
# via gradio
filelock==3.16.1
# via
# huggingface-hub
# transformers
fsspec==2024.10.0
# via
# gradio-client
# huggingface-hub
gradio==5.8.0
# via -r requirements.in
gradio-client==1.5.1
# via gradio
h11==0.14.0
# via
# httpcore
# uvicorn
httpcore==1.0.7
# via httpx
httpx==0.28.1
# via
# gradio
# gradio-client
# safehttpx
huggingface-hub==0.26.5
# via
# gradio
# gradio-client
# tokenizers
# transformers
idna==3.10
# via
# anyio
# httpx
# requests
iniconfig==2.0.0
# via pytest
ipykernel==6.29.5
# via -r requirements.in
ipython==8.30.0
# via ipykernel
jedi==0.19.2
# via ipython
jinja2==3.1.4
# via
# -r requirements.in
# gradio
# llama-cpp-python
jupyter-client==8.6.3
# via ipykernel
jupyter-core==5.7.2
# via
# ipykernel
# jupyter-client
llama-cpp-python==0.3.5
# via -r requirements.in
markdown-it-py==3.0.0
# via rich
markdownify==0.14.1
# via -r requirements.in
markupsafe==2.1.5
# via
# gradio
# jinja2
matplotlib-inline==0.1.7
# via
# ipykernel
# ipython
mdurl==0.1.2
# via markdown-it-py
nest-asyncio==1.6.0
# via ipykernel
numpy==2.2.0
# via
# gradio
# llama-cpp-python
# pandas
# transformers
orjson==3.10.12
# via gradio
packaging==24.2
# via
# gradio
# gradio-client
# huggingface-hub
# ipykernel
# pytest
# transformers
pandas==2.2.3
# via
# -r requirements.in
# gradio
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pillow==11.0.0
# via
# -r requirements.in
# gradio
platformdirs==4.3.6
# via jupyter-core
pluggy==1.5.0
# via pytest
primp==0.8.3
# via duckduckgo-search
prompt-toolkit==3.0.48
# via ipython
psutil==6.1.0
# via ipykernel
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pydantic==2.10.3
# via
# fastapi
# gradio
pydantic-core==2.27.1
# via pydantic
pydub==0.25.1
# via gradio
pygments==2.18.0
# via
# ipython
# rich
pytest==8.3.4
# via -r requirements.in
python-dateutil==2.9.0.post0
# via
# jupyter-client
# pandas
python-dotenv==1.0.1
# via -r requirements.in
python-multipart==0.0.19
# via gradio
pytz==2024.2
# via pandas
pyyaml==6.0.2
# via
# gradio
# huggingface-hub
# transformers
pyzmq==26.2.0
# via
# ipykernel
# jupyter-client
regex==2024.11.6
# via transformers
requests==2.32.3
# via
# -r requirements.in
# huggingface-hub
# transformers
rich==13.9.4
# via
# -r requirements.in
# typer
ruff==0.8.2
# via gradio
safehttpx==0.1.6
# via gradio
safetensors==0.4.5
# via transformers
semantic-version==2.10.0
# via gradio
shellingham==1.5.4
# via typer
six==1.17.0
# via
# markdownify
# python-dateutil
sniffio==1.3.1
# via anyio
soupsieve==2.6
# via beautifulsoup4
stack-data==0.6.3
# via ipython
starlette==0.41.3
# via
# fastapi
# gradio
tokenizers==0.21.0
# via transformers
tomlkit==0.13.2
# via gradio
tornado==6.4.2
# via
# ipykernel
# jupyter-client
tqdm==4.67.1
# via
# huggingface-hub
# transformers
traitlets==5.14.3
# via
# comm
# ipykernel
# ipython
# jupyter-client
# jupyter-core
# matplotlib-inline
transformers==4.47.0
# via -r requirements.in
typer==0.15.1
# via gradio
typing-extensions==4.12.2
# via
# anyio
# fastapi
# gradio
# gradio-client
# huggingface-hub
# llama-cpp-python
# pydantic
# pydantic-core
# typer
tzdata==2024.2
# via pandas
urllib3==2.2.3
# via requests
uvicorn==0.32.1
# via gradio
wcwidth==0.2.13
# via prompt-toolkit
websockets==14.1
# via gradio-client

122
setup.py
View File

@ -1,122 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import find_packages, setup
# Optional dependency groups, exposed to installers through `extras_require`.
extras = {}
extras["quality"] = [
    "black ~= 23.1",  # hf-doc-builder has a hidden dependency on `black`
    "hf-doc-builder >= 0.3.0",
    "ruff ~= 0.6.4",
]
extras["docs"] = []
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
extras["test_dev"] = [
    "datasets",
    "diffusers",
    "evaluate",
    "torchdata>=0.8.0",
    "torchpippy>=0.2.0",
    "transformers",
    "scipy",
    "scikit-learn",
    "tqdm",
    "bitsandbytes",
    "timm",
]
# Aggregate groups are built from the groups defined above.
extras["testing"] = extras["test_prod"] + extras["test_dev"]
extras["deepspeed"] = ["deepspeed"]
extras["rich"] = ["rich"]
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

extras["sagemaker"] = [
    "sagemaker",  # boto3 is a required package in sagemaker
]

# Read the README inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setup(
    name="accelerate",
    version="1.2.0.dev0",
    description="Accelerate",
    long_description=long_description,
    long_description_content_type="text/markdown",
    keywords="deep learning",
    license="Apache",
    author="The HuggingFace team",
    author_email="zach.mueller@huggingface.co",
    url="https://github.com/huggingface/accelerate",
    package_dir={"": "src"},
    packages=find_packages("src"),
    entry_points={
        "console_scripts": [
            "accelerate=accelerate.commands.accelerate_cli:main",
            "accelerate-config=accelerate.commands.config:main",
            "accelerate-estimate-memory=accelerate.commands.estimate:main",
            "accelerate-launch=accelerate.commands.launch:main",
            "accelerate-merge-weights=accelerate.commands.merge:main",
        ]
    },
    python_requires=">=3.9.0",
    install_requires=[
        "numpy>=1.17,<3.0.0",
        "packaging>=20.0",
        "psutil",
        "pyyaml",
        "torch>=1.10.0",
        "huggingface_hub>=0.21.0",
        "safetensors>=0.4.3",
    ],
    extras_require=extras,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        # Kept in line with `python_requires` above, which excludes Python 3.8.
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
# Release checklist
# 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one):
# git checkout -b vXX.xx-release
# The -b is only necessary for creation (so remove it when doing a patch)
# 2. Change the version in __init__.py and setup.py to the proper value.
# 3. Commit these changes with the message: "Release: v<VERSION>"
# 4. Add a tag in git to mark the release:
# git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi'
# Push the tag and release commit to git: git push --tags origin vXX.xx-release
# 5. Run the following commands in the top-level directory:
# rm -rf dist
# rm -rf build
# python setup.py bdist_wheel
# python setup.py sdist
# 6. Upload the package to the pypi test server first:
# twine upload dist/* -r testpypi
# 7. Check that you can install it in a virtualenv by running:
# pip install accelerate
# pip uninstall accelerate
# pip install -i https://testpypi.python.org/pypi accelerate
# accelerate env
# accelerate test
# 8. Upload the final version to actual pypi:
# twine upload dist/* -r pypi
# 9. Add release notes to the tag in github once everything is looking hunky-dory.
# 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to
# main.

View File

@ -19,19 +19,25 @@ import uuid
from pathlib import Path from pathlib import Path
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision from transformers.testing_utils import (
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available get_tests_dir,
require_soundfile,
require_torch,
require_vision,
)
from transformers.utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
)
if is_torch_available():
import torch import torch
from PIL import Image
if is_soundfile_availble(): if is_soundfile_availble():
import soundfile as sf import soundfile as sf
if is_vision_available():
from PIL import Image
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()

View File

@ -19,17 +19,16 @@ import uuid
import pytest import pytest
from transformers.agents.agent_types import AgentText from agents.agent_types import AgentText
from transformers.agents.agents import ( from agents.agents import (
AgentMaxIterationsError, AgentMaxIterationsError,
CodeAgent, CodeAgent,
ManagedAgent, ManagedAgent,
ReactCodeAgent, CodeAgent,
ReactJsonAgent, JsonAgent,
Toolbox, Toolbox,
) )
from transformers.agents.default_tools import PythonInterpreterTool from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import require_torch
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:
@ -149,19 +148,26 @@ print(result)
class AgentTests(unittest.TestCase): class AgentTests(unittest.TestCase):
def test_fake_code_agent(self): def test_fake_code_agent(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot) agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
def test_fake_react_json_agent(self): def test_fake_react_json_agent(self):
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm) agent = JsonAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert output == "7.2904" assert output == "7.2904"
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert agent.logs[1]["observation"] == "7.2904" assert agent.logs[1]["observation"] == "7.2904"
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker" assert (
agent.logs[1]["rationale"].strip()
== "Thought: I should multiply 2 by 3.6452. special_marker"
)
assert ( assert (
agent.logs[2]["llm_output"] agent.logs[2]["llm_output"]
== """ == """
@ -175,7 +181,9 @@ Action:
) )
def test_fake_react_code_agent(self): def test_fake_react_code_agent(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm) agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, float) assert isinstance(output, float)
assert output == 7.2904 assert output == 7.2904
@ -186,17 +194,19 @@ Action:
} }
def test_react_code_agent_code_errors_show_offending_lines(self): def test_react_code_agent_code_errors_show_offending_lines(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error) agent = CodeAgent(
tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, AgentText)
assert output == "got an error" assert output == "got an error"
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self): def test_setup_agent_with_empty_toolbox(self):
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[]) JsonAgent(llm_engine=fake_react_json_llm, tools=[])
def test_react_fails_max_iterations(self): def test_react_fails_max_iterations(self):
agent = ReactCodeAgent( agent = CodeAgent(
tools=[PythonInterpreterTool()], tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_no_return, # use this callable because it never ends llm_engine=fake_code_llm_no_return, # use this callable because it never ends
max_iterations=5, max_iterations=5,
@ -208,51 +218,62 @@ Action:
@require_torch @require_torch
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
toolset_1 = [] toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 1 len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default ) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2) toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert ( assert (
len(agent.toolbox.tools) == 2 len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one ) # same as previous one, where toolset_3 is an instantiation of previous one
# check that add_base_tools will not interfere with existing tools # check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True) agent = JsonAgent(
tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True
)
assert "already exists in the toolbox" in str(e) assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents # check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter) assert (
len(agent.toolbox.tools) == 7
) # added final_answer tool + 6 base tools (excluding interpreter)
def test_function_persistence_across_steps(self): def test_function_persistence_across_steps(self):
agent = ReactCodeAgent( agent = CodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"] tools=[],
llm_engine=fake_react_code_functiondef,
max_iterations=2,
additional_authorized_imports=["numpy"],
) )
res = agent.run("ok") res = agent.run("ok")
assert res[0] == 0.5 assert res[0] == 0.5
def test_init_managed_agent(self): def test_init_managed_agent(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
assert managed_agent.name == "managed_agent" assert managed_agent.name == "managed_agent"
assert managed_agent.description == "Empty" assert managed_agent.description == "Empty"
def test_agent_description_gets_correctly_inserted_in_system_prompt(self): def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
manager_agent = ReactCodeAgent( manager_agent = CodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent] tools=[],
llm_engine=fake_react_code_functiondef,
managed_agents=[managed_agent],
) )
assert "You can also give requests to team members." not in agent.system_prompt assert "You can also give requests to team members." not in agent.system_prompt
assert "<<managed_agents_descriptions>>" not in agent.system_prompt assert "<<managed_agents_descriptions>>" not in agent.system_prompt
assert "You can also give requests to team members." in manager_agent.system_prompt assert (
"You can also give requests to team members." in manager_agent.system_prompt
)

View File

@ -1,41 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    """Exact-match checks for the tool loaded as "document_question_answering"."""

    # Question and expected answer shared by the positional and keyword variants.
    QUESTION = "When is the coffee break?"
    EXPECTED_ANSWER = "11-14 to 11:39 a.m."

    def setUp(self):
        self.tool = load_tool("document_question_answering")
        self.tool.setup()

    def _load_document(self):
        """Return the image of the first record of the example-documents test split."""
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        return dataset[0]["image"]

    def test_exact_match_arg(self):
        document = self._load_document()

        result = self.tool(document, self.QUESTION)
        self.assertEqual(result, self.EXPECTED_ANSWER)

    def test_exact_match_kwarg(self):
        document = self._load_document()

        # Same document and question as the positional test, so the answer is
        # checked too; previously the result was discarded and nothing was
        # asserted.
        result = self.tool(document=document, question=self.QUESTION)
        self.assertEqual(result, self.EXPECTED_ANSWER)

View File

@ -87,7 +87,11 @@ class ExampleDifferenceTests(unittest.TestCase):
examples_path = Path("examples").resolve() examples_path = Path("examples").resolve()
def one_complete_example( def one_complete_example(
self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None self,
complete_file_name: str,
parser_only: bool,
secondary_filename: str = None,
special_strings: list = None,
): ):
""" """
Tests a single `complete` example against all of the implemented `by_feature` scripts Tests a single `complete` example against all of the implemented `by_feature` scripts
@ -112,10 +116,15 @@ class ExampleDifferenceTests(unittest.TestCase):
with self.subTest( with self.subTest(
tested_script=complete_file_name, tested_script=complete_file_name,
feature_script=item, feature_script=item,
tested_section="main()" if parser_only else "training_function()", tested_section="main()"
if parser_only
else "training_function()",
): ):
diff = compare_against_test( diff = compare_against_test(
self.examples_path / complete_file_name, item_path, parser_only, secondary_filename self.examples_path / complete_file_name,
item_path,
parser_only,
secondary_filename,
) )
diff = "\n".join(diff) diff = "\n".join(diff)
if special_strings is not None: if special_strings is not None:
@ -140,8 +149,12 @@ class ExampleDifferenceTests(unittest.TestCase):
" " * 12, " " * 12,
" " * 8 + "for step, batch in enumerate(active_dataloader):\n", " " * 8 + "for step, batch in enumerate(active_dataloader):\n",
] ]
self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings) self.one_complete_example(
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings) "complete_cv_example.py", True, cv_path, special_strings
)
self.one_complete_example(
"complete_cv_example.py", False, cv_path, special_strings
)
@mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"}) @mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"})

View File

@ -47,9 +47,9 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def create_inputs(self): def create_inputs(self):
inputs_text = {"answer": "Text input"} inputs_text = {"answer": "Text input"}
inputs_image = { inputs_image = {
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize( "answer": Image.open(
(512, 512) Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
) ).resize((512, 512))
} }
inputs_audio = {"answer": torch.Tensor(np.ones(3000))} inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}

View File

@ -1,42 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    """Exact-match checks for the tool loaded as "image_question_answering"."""

    def setUp(self):
        self.tool = load_tool("image_question_answering")
        self.tool.setup()

    def test_exact_match_arg(self):
        # Positional call: image first, then the question.
        fixtures_dir = Path(get_tests_dir("fixtures/tests_samples/COCO"))
        image = Image.open(fixtures_dir / "000000039769.png")
        answer = self.tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(answer, "2")

    def test_exact_match_kwarg(self):
        # Keyword call with the same image and question.
        fixtures_dir = Path(get_tests_dir("fixtures/tests_samples/COCO"))
        image = Image.open(fixtures_dir / "000000039769.png")
        answer = self.tool(
            image=image,
            question="How many cats are sleeping on the couch?",
        )
        self.assertEqual(answer, "2")

View File

@ -129,7 +129,9 @@ final_answer('This is the final answer.')
def test_streaming_agent_image_output(self): def test_streaming_agent_image_output(self):
def dummy_llm_engine(prompt, **kwargs): def dummy_llm_engine(prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' return (
'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
)
agent = ReactJsonAgent( agent = ReactJsonAgent(
tools=[], tools=[],
@ -138,7 +140,14 @@ final_answer('This is the final answer.')
) )
# Use stream_to_gradio to capture the output # Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True)) outputs = list(
stream_to_gradio(
agent,
task="Test task",
image=AgentImage(value="path.png"),
test_mode=True,
)
)
self.assertEqual(len(outputs), 2) self.assertEqual(len(outputs), 2)
final_message = outputs[-1] final_message = outputs[-1]

View File

@ -21,7 +21,10 @@ import pytest
from transformers import load_tool from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code from transformers.agents.python_interpreter import (
InterpreterError,
evaluate_python_code,
)
from .test_tools_common import ToolTesterMixin from .test_tools_common import ToolTesterMixin
@ -57,7 +60,12 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()): for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"] input_type = expected_input["type"]
if isinstance(input_type, list): if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type]) _inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
else: else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
@ -91,7 +99,10 @@ class PythonInterpreterTester(unittest.TestCase):
code = "print = '3'" code = "print = '3'"
with pytest.raises(InterpreterError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {"print": print}, state={}) evaluate_python_code(code, {"print": print}, state={})
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e) assert (
"Cannot assign to name 'print': doing this would erase the existing tool!"
in str(e)
)
def test_evaluate_call(self): def test_evaluate_call(self):
code = "y = add_two(x)" code = "y = add_two(x)"
@ -117,7 +128,9 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5}) self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
def test_evaluate_expression(self): def test_evaluate_expression(self):
code = "x = 3\ny = 5" code = "x = 3\ny = 5"
@ -133,7 +146,9 @@ class PythonInterpreterTester(unittest.TestCase):
result = evaluate_python_code(code, {}, state=state) result = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == "This is x: 3." assert result == "This is x: 3."
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}) self.assertDictEqual(
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
)
def test_evaluate_if(self): def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5" code = "if x <= 3:\n y = 2\nelse:\n y = 5"
@ -174,11 +189,15 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3} state = {"x": 3}
result = evaluate_python_code(code, {"add_two": add_two}, state=state) result = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {} state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) evaluate_python_code(
code, {"min": min, "print": print, "round": round}, state=state
)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self): def test_subscript_string_with_string_index_raises_appropriate_error(self):
@ -292,7 +311,16 @@ print(check_digits)
""" """
state = {} state = {}
evaluate_python_code( evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state code,
{
"range": range,
"print": print,
"sum": sum,
"enumerate": enumerate,
"int": int,
"str": str,
},
state,
) )
def test_listcomp(self): def test_listcomp(self):
@ -325,7 +353,9 @@ print(check_digits)
assert result == {0: 0, 1: 1, 2: 4} assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) result = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert result == {102: "b"} assert result == {102: "b"}
code = """ code = """
@ -373,7 +403,9 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Brooklyn" assert result == "Brooklyn"
code = """if d > e and a < b: code = """if d > e and a < b:
@ -384,7 +416,9 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Sacramento" assert result == "Sacramento"
def test_if_conditions(self): def test_if_conditions(self):
@ -400,7 +434,9 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0 assert result == 2.0
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose" assert result == "lose"
@ -434,10 +470,14 @@ if char.isalpha():
# Test submodules are handled properly, thus not raising error # Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) result = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
def test_additional_imports(self): def test_additional_imports(self):
code = "import numpy as np" code = "import numpy as np"
@ -554,7 +594,11 @@ cat_sound = cat.sound()
cat_str = str(cat) cat_str = str(cat)
""" """
state = {} state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) evaluate_python_code(
code,
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
state=state,
)
# Assert results # Assert results
assert state["dog1_sound"] == "The dog barks." assert state["dog1_sound"] == "The dog barks."
@ -588,7 +632,11 @@ except ValueError as e:
exception_message = str(e) exception_message = str(e)
""" """
state = {} state = {}
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) evaluate_python_code(
code,
{"print": print, "len": len, "super": super, "str": str, "sum": sum},
state=state,
)
assert state["exception_message"] == "An error occurred" assert state["exception_message"] == "An error occurred"
def test_print(self): def test_print(self):
@ -600,7 +648,9 @@ except ValueError as e:
def test_types_as_objects(self): def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int" code = "type_a = float(2); type_b = str; type_c = int"
state = {} state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) result = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
assert result is int assert result is int
def test_tuple_id(self): def test_tuple_id(self):
@ -731,7 +781,9 @@ def add_one(n, shift):
add_one(1, 1) add_one(1, 1)
""" """
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) result = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result == 2 assert result == 2
# test returning None # test returning None
@ -742,7 +794,9 @@ def returns_none(a):
returns_none(1) returns_none(1)
""" """
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) result = evaluate_python_code(
code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state
)
assert result is None assert result is None
def test_nested_for_loop(self): def test_nested_for_loop(self):
@ -758,7 +812,9 @@ out = [i for sublist in all_res for i in sublist]
out[:10] out[:10]
""" """
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range}, state=state) result = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self): def test_pandas(self):
@ -773,7 +829,9 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1] parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
""" """
state = {} state = {}
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"]) result = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
assert np.array_equal(result, [-1, 5]) assert np.array_equal(result, [-1, 5])
code = """ code = """
@ -785,7 +843,9 @@ print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers # Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])] filtered_df = df.loc[df['AtomicNumber'].isin([104])]
""" """
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) result = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert np.array_equal(result.values[0], [104, 1]) assert np.array_equal(result.values[0], [104, 1])
code = """import pandas as pd code = """import pandas as pd
@ -818,7 +878,9 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
""" """
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"]) result = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
assert round(result, 1) == 622395.4 assert round(result, 1) == 622395.4
def test_for(self): def test_for(self):

View File

@ -1,36 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("speech_to_text")
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(np.ones(3000))
self.assertEqual(result, " Thank you.")
def test_exact_match_kwarg(self):
result = self.tool(audio=np.ones(3000))
self.assertEqual(result, " Thank you.")

View File

@ -1,50 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from transformers.utils import is_torch_available
if is_torch_available():
import torch
from transformers.testing_utils import require_torch
from .test_tools_common import ToolTesterMixin
@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text_to_speech")
self.tool.setup()
def test_exact_match_arg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
def test_exact_match_kwarg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)

View File

@ -20,7 +20,12 @@ import numpy as np
import pytest import pytest
from transformers import is_torch_available, is_vision_available from transformers import is_torch_available, is_vision_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText from transformers.agents.agent_types import (
AGENT_TYPE_MAPPING,
AgentAudio,
AgentImage,
AgentText,
)
from transformers.agents.tools import Tool, tool from transformers.agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir from transformers.testing_utils import get_tests_dir

View File

@ -32,7 +32,9 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
self.assertEqual(result, "- Hé, comment ça va?") self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_kwarg(self): def test_exact_match_kwarg(self):
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French") result = self.tool(
text="Hey, what's up?", src_lang="English", tgt_lang="French"
)
self.assertEqual(result, "- Hé, comment ça va?") self.assertEqual(result, "- Hé, comment ça va?")
def test_call(self): def test_call(self):