Ruff formatting
This commit is contained in:
		
							parent
							
								
									851e177e71
								
							
						
					
					
						commit
						67deb6808f
					
				|  | @ -24,10 +24,25 @@ from transformers.utils import ( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| _import_structure = { | _import_structure = { | ||||||
|     "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "CodeAgent", "JsonAgent", "Toolbox"], |     "agents": [ | ||||||
|  |         "Agent", | ||||||
|  |         "CodeAgent", | ||||||
|  |         "ManagedAgent", | ||||||
|  |         "ReactAgent", | ||||||
|  |         "CodeAgent", | ||||||
|  |         "JsonAgent", | ||||||
|  |         "Toolbox", | ||||||
|  |     ], | ||||||
|     "llm_engine": ["HfApiEngine", "TransformersEngine"], |     "llm_engine": ["HfApiEngine", "TransformersEngine"], | ||||||
|     "monitoring": ["stream_to_gradio"], |     "monitoring": ["stream_to_gradio"], | ||||||
|     "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"], |     "tools": [ | ||||||
|  |         "PipelineTool", | ||||||
|  |         "Tool", | ||||||
|  |         "ToolCollection", | ||||||
|  |         "launch_gradio_demo", | ||||||
|  |         "load_tool", | ||||||
|  |         "tool", | ||||||
|  |     ], | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|  | @ -45,10 +60,25 @@ else: | ||||||
|     _import_structure["translation"] = ["TranslationTool"] |     _import_structure["translation"] = ["TranslationTool"] | ||||||
| 
 | 
 | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, CodeAgent, JsonAgent, Toolbox |     from .agents import ( | ||||||
|  |         Agent, | ||||||
|  |         CodeAgent, | ||||||
|  |         ManagedAgent, | ||||||
|  |         ReactAgent, | ||||||
|  |         CodeAgent, | ||||||
|  |         JsonAgent, | ||||||
|  |         Toolbox, | ||||||
|  |     ) | ||||||
|     from .llm_engine import HfApiEngine, TransformersEngine |     from .llm_engine import HfApiEngine, TransformersEngine | ||||||
|     from .monitoring import stream_to_gradio |     from .monitoring import stream_to_gradio | ||||||
|     from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool |     from .tools import ( | ||||||
|  |         PipelineTool, | ||||||
|  |         Tool, | ||||||
|  |         ToolCollection, | ||||||
|  |         launch_gradio_demo, | ||||||
|  |         load_tool, | ||||||
|  |         tool, | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         if not is_torch_available(): |         if not is_torch_available(): | ||||||
|  | @ -66,4 +96,6 @@ if TYPE_CHECKING: | ||||||
| else: | else: | ||||||
|     import sys |     import sys | ||||||
| 
 | 
 | ||||||
|     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) |     sys.modules[__name__] = _LazyModule( | ||||||
|  |         __name__, globals()["__file__"], _import_structure, module_spec=__spec__ | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  | @ -19,7 +19,11 @@ import uuid | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available | from transformers.utils import ( | ||||||
|  |     is_soundfile_availble, | ||||||
|  |     is_torch_available, | ||||||
|  |     is_vision_available, | ||||||
|  | ) | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -108,7 +112,9 @@ class AgentImage(AgentType, ImageType): | ||||||
|         elif isinstance(value, np.ndarray): |         elif isinstance(value, np.ndarray): | ||||||
|             self._tensor = torch.from_numpy(value) |             self._tensor = torch.from_numpy(value) | ||||||
|         else: |         else: | ||||||
|             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}") |             raise TypeError( | ||||||
|  |                 f"Unsupported type for {self.__class__.__name__}: {type(value)}" | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     def _ipython_display_(self, include=None, exclude=None): |     def _ipython_display_(self, include=None, exclude=None): | ||||||
|         """ |         """ | ||||||
|  | @ -159,7 +165,7 @@ class AgentImage(AgentType, ImageType): | ||||||
| 
 | 
 | ||||||
|             return self._path |             return self._path | ||||||
| 
 | 
 | ||||||
|     def save(self, output_bytes, format : str = None, **params): |     def save(self, output_bytes, format: str = None, **params): | ||||||
|         """ |         """ | ||||||
|         Saves the image to a file. |         Saves the image to a file. | ||||||
|         Args: |         Args: | ||||||
|  | @ -243,7 +249,9 @@ if is_torch_available(): | ||||||
| 
 | 
 | ||||||
| def handle_agent_inputs(*args, **kwargs): | def handle_agent_inputs(*args, **kwargs): | ||||||
|     args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] |     args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] | ||||||
|     kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()} |     kwargs = { | ||||||
|  |         k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items() | ||||||
|  |     } | ||||||
|     return args, kwargs |     return args, kwargs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										372
									
								
								agents/agents.py
								
								
								
								
							
							
						
						
									
										372
									
								
								agents/agents.py
								
								
								
								
							|  | @ -15,12 +15,10 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import time | import time | ||||||
| from typing import Any, Callable, Dict, List, Optional, Union | from typing import Any, Callable, Dict, List, Optional, Union, Tuple | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from rich.syntax import Syntax | from rich.syntax import Syntax | ||||||
| 
 | 
 | ||||||
| from langfuse.decorators import langfuse_context, observe |  | ||||||
| 
 |  | ||||||
| from transformers.utils import is_torch_available | from transformers.utils import is_torch_available | ||||||
| 
 | 
 | ||||||
| from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content | from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content | ||||||
|  | @ -51,6 +49,7 @@ from .tools import ( | ||||||
| 
 | 
 | ||||||
| HUGGINGFACE_DEFAULT_TOOLS = {} | HUGGINGFACE_DEFAULT_TOOLS = {} | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class AgentError(Exception): | class AgentError(Exception): | ||||||
|     """Base class for other agent-related exceptions""" |     """Base class for other agent-related exceptions""" | ||||||
| 
 | 
 | ||||||
|  | @ -60,7 +59,6 @@ class AgentError(Exception): | ||||||
|         console.print(f"[bold red]{message}[/bold red]") |         console.print(f"[bold red]{message}[/bold red]") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| class AgentParsingError(AgentError): | class AgentParsingError(AgentError): | ||||||
|     """Exception raised for errors in parsing in the agent""" |     """Exception raised for errors in parsing in the agent""" | ||||||
| 
 | 
 | ||||||
|  | @ -84,12 +82,14 @@ class AgentGenerationError(AgentError): | ||||||
| 
 | 
 | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class AgentStep: | class AgentStep: | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ActionStep(AgentStep): | class ActionStep(AgentStep): | ||||||
|     tool_call: str | None = None |     tool_call: Dict[str, str] | None = None | ||||||
|     start_time: float | None = None |     start_time: float | None = None | ||||||
|     step_end_time: float | None = None |     step_end_time: float | None = None | ||||||
|     iteration: int | None = None |     iteration: int | None = None | ||||||
|  | @ -97,32 +97,43 @@ class ActionStep(AgentStep): | ||||||
|     error: AgentError | None = None |     error: AgentError | None = None | ||||||
|     step_duration: float | None = None |     step_duration: float | None = None | ||||||
|     llm_output: str | None = None |     llm_output: str | None = None | ||||||
|  |     observation: str | None = None | ||||||
|  |     agent_memory: List[Dict[str, str]] | None = None | ||||||
|  |     rationale: str | None = None | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class PlanningStep(AgentStep): | class PlanningStep(AgentStep): | ||||||
|     plan: str |     plan: str | ||||||
|     facts: str |     facts: str | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class TaskStep(AgentStep): | class TaskStep(AgentStep): | ||||||
|     task: str |     task: str | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class SystemPromptStep(AgentStep): | class SystemPromptStep(AgentStep): | ||||||
|     system_prompt: str |     system_prompt: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: | def format_prompt_with_tools( | ||||||
|  |     toolbox: Toolbox, prompt_template: str, tool_description_template: str | ||||||
|  | ) -> str: | ||||||
|     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) |     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) | ||||||
|     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) |     prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) | ||||||
| 
 | 
 | ||||||
|     if "{{tool_names}}" in prompt: |     if "{{tool_names}}" in prompt: | ||||||
|         prompt = prompt.replace("{{tool_names}}", ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()])) |         prompt = prompt.replace( | ||||||
|  |             "{{tool_names}}", | ||||||
|  |             ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]), | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     return prompt |     return prompt | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def show_agents_descriptions(managed_agents: list): | def show_agents_descriptions(managed_agents: Dict): | ||||||
|     managed_agents_descriptions = """ |     managed_agents_descriptions = """ | ||||||
| You can also give requests to team members. | You can also give requests to team members. | ||||||
| Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request. | Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request. | ||||||
|  | @ -133,16 +144,24 @@ Here is a list of the team members that you can call:""" | ||||||
|     return managed_agents_descriptions |     return managed_agents_descriptions | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str: | def format_prompt_with_managed_agents_descriptions( | ||||||
|  |     prompt_template, managed_agents=None | ||||||
|  | ) -> str: | ||||||
|     if managed_agents is not None: |     if managed_agents is not None: | ||||||
|         return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents)) |         return prompt_template.replace( | ||||||
|  |             "<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents) | ||||||
|  |         ) | ||||||
|     else: |     else: | ||||||
|         return prompt_template.replace("<<managed_agents_descriptions>>", "") |         return prompt_template.replace("<<managed_agents_descriptions>>", "") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str: | def format_prompt_with_imports( | ||||||
|  |     prompt_template: str, authorized_imports: List[str] | ||||||
|  | ) -> str: | ||||||
|     if "<<authorized_imports>>" not in prompt_template: |     if "<<authorized_imports>>" not in prompt_template: | ||||||
|         raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.") |         raise AgentError( | ||||||
|  |             "Tag '<<authorized_imports>>' should be provided in the prompt." | ||||||
|  |         ) | ||||||
|     return prompt_template.replace("<<authorized_imports>>", str(authorized_imports)) |     return prompt_template.replace("<<authorized_imports>>", str(authorized_imports)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -150,7 +169,7 @@ class BaseAgent: | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         tools: Union[List[Tool], Toolbox], |         tools: Union[List[Tool], Toolbox], | ||||||
|         llm_engine: Callable = None, |         llm_engine: Optional[Callable[[List[Dict[str, str]]], str]] = None, | ||||||
|         system_prompt: Optional[str] = None, |         system_prompt: Optional[str] = None, | ||||||
|         tool_description_template: Optional[str] = None, |         tool_description_template: Optional[str] = None, | ||||||
|         additional_args: Dict = {}, |         additional_args: Dict = {}, | ||||||
|  | @ -159,10 +178,12 @@ class BaseAgent: | ||||||
|         add_base_tools: bool = False, |         add_base_tools: bool = False, | ||||||
|         verbose: bool = False, |         verbose: bool = False, | ||||||
|         grammar: Optional[Dict[str, str]] = None, |         grammar: Optional[Dict[str, str]] = None, | ||||||
|         managed_agents: Optional[List] = None, |         managed_agents: Optional[Dict] = None, | ||||||
|         step_callbacks: Optional[List[Callable]] = None, |         step_callbacks: Optional[List[Callable]] = None, | ||||||
|         monitor_metrics: bool = True, |         monitor_metrics: bool = True, | ||||||
|     ): |     ): | ||||||
|  |         if llm_engine is None: | ||||||
|  |             llm_engine = HfApiEngine() | ||||||
|         if system_prompt is None: |         if system_prompt is None: | ||||||
|             system_prompt = CODE_SYSTEM_PROMPT |             system_prompt = CODE_SYSTEM_PROMPT | ||||||
|         if tool_parser is None: |         if tool_parser is None: | ||||||
|  | @ -171,14 +192,16 @@ class BaseAgent: | ||||||
|         self.llm_engine = llm_engine |         self.llm_engine = llm_engine | ||||||
|         self.system_prompt_template = system_prompt |         self.system_prompt_template = system_prompt | ||||||
|         self.tool_description_template = ( |         self.tool_description_template = ( | ||||||
|             tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE |             tool_description_template | ||||||
|  |             if tool_description_template | ||||||
|  |             else DEFAULT_TOOL_DESCRIPTION_TEMPLATE | ||||||
|         ) |         ) | ||||||
|         self.additional_args = additional_args |         self.additional_args = additional_args | ||||||
|         self.max_iterations = max_iterations |         self.max_iterations = max_iterations | ||||||
|         self.tool_parser = tool_parser |         self.tool_parser = tool_parser | ||||||
|         self.grammar = grammar |         self.grammar = grammar | ||||||
| 
 | 
 | ||||||
|         self.managed_agents = None |         self.managed_agents = {} | ||||||
|         if managed_agents is not None: |         if managed_agents is not None: | ||||||
|             self.managed_agents = {agent.name: agent for agent in managed_agents} |             self.managed_agents = {agent.name: agent for agent in managed_agents} | ||||||
| 
 | 
 | ||||||
|  | @ -186,9 +209,13 @@ class BaseAgent: | ||||||
|             self._toolbox = tools |             self._toolbox = tools | ||||||
|             if add_base_tools: |             if add_base_tools: | ||||||
|                 if not is_torch_available(): |                 if not is_torch_available(): | ||||||
|                     raise ImportError("Using the base tools requires torch to be installed.") |                     raise ImportError( | ||||||
|  |                         "Using the base tools requires torch to be installed." | ||||||
|  |                     ) | ||||||
| 
 | 
 | ||||||
|                 self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == JsonAgent)) |                 self._toolbox.add_base_tools( | ||||||
|  |                     add_python_interpreter=(self.__class__ == JsonAgent) | ||||||
|  |                 ) | ||||||
|         else: |         else: | ||||||
|             self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) |             self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) | ||||||
|         self._toolbox.add_tool(FinalAnswerTool()) |         self._toolbox.add_tool(FinalAnswerTool()) | ||||||
|  | @ -196,7 +223,9 @@ class BaseAgent: | ||||||
|         self.system_prompt = format_prompt_with_tools( |         self.system_prompt = format_prompt_with_tools( | ||||||
|             self._toolbox, self.system_prompt_template, self.tool_description_template |             self._toolbox, self.system_prompt_template, self.tool_description_template | ||||||
|         ) |         ) | ||||||
|         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) |         self.system_prompt = format_prompt_with_managed_agents_descriptions( | ||||||
|  |             self.system_prompt, self.managed_agents | ||||||
|  |         ) | ||||||
|         self.prompt_messages = None |         self.prompt_messages = None | ||||||
|         self.logs = [] |         self.logs = [] | ||||||
|         self.task = None |         self.task = None | ||||||
|  | @ -222,15 +251,20 @@ class BaseAgent: | ||||||
|             self.system_prompt_template, |             self.system_prompt_template, | ||||||
|             self.tool_description_template, |             self.tool_description_template, | ||||||
|         ) |         ) | ||||||
|         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents) |         self.system_prompt = format_prompt_with_managed_agents_descriptions( | ||||||
|  |             self.system_prompt, self.managed_agents | ||||||
|  |         ) | ||||||
|         if hasattr(self, "authorized_imports"): |         if hasattr(self, "authorized_imports"): | ||||||
|             self.system_prompt = format_prompt_with_imports( |             self.system_prompt = format_prompt_with_imports( | ||||||
|                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) |                 self.system_prompt, | ||||||
|  |                 list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))), | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         return self.system_prompt |         return self.system_prompt | ||||||
| 
 | 
 | ||||||
|     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: |     def write_inner_memory_from_logs( | ||||||
|  |         self, summary_mode: Optional[bool] = False | ||||||
|  |     ) -> List[Dict[str, str]]: | ||||||
|         """ |         """ | ||||||
|         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages |         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages | ||||||
|         that can be used as input to the LLM. |         that can be used as input to the LLM. | ||||||
|  | @ -253,7 +287,10 @@ class BaseAgent: | ||||||
|                 memory.append(thought_message) |                 memory.append(thought_message) | ||||||
| 
 | 
 | ||||||
|                 if not summary_mode: |                 if not summary_mode: | ||||||
|                     thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log.plan.strip()} |                     thought_message = { | ||||||
|  |                         "role": MessageRole.ASSISTANT, | ||||||
|  |                         "content": "[PLAN]:\n" + step_log.plan.strip(), | ||||||
|  |                     } | ||||||
|                     memory.append(thought_message) |                     memory.append(thought_message) | ||||||
| 
 | 
 | ||||||
|             elif isinstance(step_log, TaskStep): |             elif isinstance(step_log, TaskStep): | ||||||
|  | @ -265,13 +302,17 @@ class BaseAgent: | ||||||
| 
 | 
 | ||||||
|             elif isinstance(step_log, ActionStep): |             elif isinstance(step_log, ActionStep): | ||||||
|                 if step_log.llm_output is not None and not summary_mode: |                 if step_log.llm_output is not None and not summary_mode: | ||||||
|                     thought_message = {"role": MessageRole.ASSISTANT, "content": step_log.llm_output.strip()} |                     thought_message = { | ||||||
|  |                         "role": MessageRole.ASSISTANT, | ||||||
|  |                         "content": step_log.llm_output.strip(), | ||||||
|  |                     } | ||||||
|                     memory.append(thought_message) |                     memory.append(thought_message) | ||||||
| 
 | 
 | ||||||
|                 if step_log.tool_call is not None and summary_mode: |                 if step_log.tool_call is not None and summary_mode: | ||||||
|                     tool_call_message = { |                     tool_call_message = { | ||||||
|                         "role": MessageRole.ASSISTANT, |                         "role": MessageRole.ASSISTANT, | ||||||
|                         "content": f"[STEP {i} TOOL CALL]: " + str(step_log.tool_call).strip(), |                         "content": f"[STEP {i} TOOL CALL]: " | ||||||
|  |                         + str(step_log.tool_call).strip(), | ||||||
|                     } |                     } | ||||||
|                     memory.append(tool_call_message) |                     memory.append(tool_call_message) | ||||||
| 
 | 
 | ||||||
|  | @ -284,15 +325,21 @@ class BaseAgent: | ||||||
|                         ) |                         ) | ||||||
|                     elif step_log.observation is not None: |                     elif step_log.observation is not None: | ||||||
|                         message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" |                         message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log.observation}" | ||||||
|                     tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} |                     tool_response_message = { | ||||||
|  |                         "role": MessageRole.TOOL_RESPONSE, | ||||||
|  |                         "content": message_content, | ||||||
|  |                     } | ||||||
|                     memory.append(tool_response_message) |                     memory.append(tool_response_message) | ||||||
| 
 | 
 | ||||||
|         return memory |         return memory | ||||||
| 
 | 
 | ||||||
|     def get_succinct_logs(self): |     def get_succinct_logs(self): | ||||||
|         return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs] |         return [ | ||||||
|  |             {key: value for key, value in log.items() if key != "agent_memory"} | ||||||
|  |             for log in self.logs | ||||||
|  |         ] | ||||||
| 
 | 
 | ||||||
|     def extract_action(self, llm_output: str, split_token: str) -> str: |     def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]: | ||||||
|         """ |         """ | ||||||
|         Parse action from the LLM output |         Parse action from the LLM output | ||||||
| 
 | 
 | ||||||
|  | @ -312,54 +359,6 @@ class BaseAgent: | ||||||
|             ) |             ) | ||||||
|         return rationale.strip(), action.strip() |         return rationale.strip(), action.strip() | ||||||
| 
 | 
 | ||||||
|     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: |  | ||||||
|         """ |  | ||||||
|         Execute tool with the provided input and returns the result. |  | ||||||
|         This method replaces arguments with the actual values from the state if they refer to state variables. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). |  | ||||||
|             arguments (Dict[str, str]): Arguments passed to the Tool. |  | ||||||
|         """ |  | ||||||
|         available_tools = self.toolbox.tools |  | ||||||
|         if self.managed_agents is not None: |  | ||||||
|             available_tools = {**available_tools, **self.managed_agents} |  | ||||||
|         if tool_name not in available_tools: |  | ||||||
|             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." |  | ||||||
|             console.print(f"[bold red]{error_msg}") |  | ||||||
|             raise AgentExecutionError(error_msg) |  | ||||||
| 
 |  | ||||||
|         try: |  | ||||||
|             if isinstance(arguments, str): |  | ||||||
|                 observation = available_tools[tool_name](arguments) |  | ||||||
|             elif isinstance(arguments, dict): |  | ||||||
|                 for key, value in arguments.items(): |  | ||||||
|                     if isinstance(value, str) and value in self.state: |  | ||||||
|                         arguments[key] = self.state[value] |  | ||||||
|                 observation = available_tools[tool_name](**arguments) |  | ||||||
|             else: |  | ||||||
|                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." |  | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |  | ||||||
|             return observation |  | ||||||
|         except Exception as e: |  | ||||||
|             if tool_name in self.toolbox.tools: |  | ||||||
|                 tool_description = get_tool_description_with_args(available_tools[tool_name]) |  | ||||||
|                 error_msg = ( |  | ||||||
|                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" |  | ||||||
|                     f"As a reminder, this tool's description is the following:\n{tool_description}" |  | ||||||
|                 ) |  | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |  | ||||||
|             elif tool_name in self.managed_agents: |  | ||||||
|                 error_msg = ( |  | ||||||
|                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" |  | ||||||
|                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" |  | ||||||
|                 ) |  | ||||||
|                 console.print(f"[bold red]{error_msg}") |  | ||||||
|                 raise AgentExecutionError(error_msg) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     def run(self, **kwargs): |     def run(self, **kwargs): | ||||||
|         """To be implemented in the child class""" |         """To be implemented in the child class""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  | @ -382,8 +381,6 @@ class ReactAgent(BaseAgent): | ||||||
|         planning_interval: Optional[int] = None, |         planning_interval: Optional[int] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         if llm_engine is None: |  | ||||||
|             llm_engine = HfApiEngine() |  | ||||||
|         if system_prompt is None: |         if system_prompt is None: | ||||||
|             system_prompt = CODE_SYSTEM_PROMPT |             system_prompt = CODE_SYSTEM_PROMPT | ||||||
|         if tool_description_template is None: |         if tool_description_template is None: | ||||||
|  | @ -423,8 +420,67 @@ class ReactAgent(BaseAgent): | ||||||
|             console.print(f"[bold red]{error_msg}[/bold red]") |             console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|             return error_msg |             return error_msg | ||||||
| 
 | 
 | ||||||
|     @observe |     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: | ||||||
|     def run(self, task: str, stream: bool = False, reset: bool = True, oneshot: bool = False, **kwargs): |         """ | ||||||
|  |         Execute tool with the provided input and returns the result. | ||||||
|  |         This method replaces arguments with the actual values from the state if they refer to state variables. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). | ||||||
|  |             arguments (Dict[str, str]): Arguments passed to the Tool. | ||||||
|  |         """ | ||||||
|  |         available_tools = self.toolbox.tools | ||||||
|  |         if self.managed_agents is not None: | ||||||
|  |             available_tools = {**available_tools, **self.managed_agents} | ||||||
|  |         if tool_name not in available_tools: | ||||||
|  |             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." | ||||||
|  |             console.print(f"[bold red]{error_msg}") | ||||||
|  |             raise AgentExecutionError(error_msg) | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             if isinstance(arguments, str): | ||||||
|  |                 observation = available_tools[tool_name](arguments) | ||||||
|  |             elif isinstance(arguments, dict): | ||||||
|  |                 for key, value in arguments.items(): | ||||||
|  |                     if isinstance(value, str) and value in self.state: | ||||||
|  |                         arguments[key] = self.state[value] | ||||||
|  |                 observation = available_tools[tool_name](**arguments) | ||||||
|  |             else: | ||||||
|  |                 error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." | ||||||
|  |                 console.print(f"[bold red]{error_msg}") | ||||||
|  |                 raise AgentExecutionError(error_msg) | ||||||
|  |             return observation | ||||||
|  |         except Exception as e: | ||||||
|  |             if tool_name in self.toolbox.tools: | ||||||
|  |                 tool_description = get_tool_description_with_args( | ||||||
|  |                     available_tools[tool_name] | ||||||
|  |                 ) | ||||||
|  |                 error_msg = ( | ||||||
|  |                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" | ||||||
|  |                     f"As a reminder, this tool's description is the following:\n{tool_description}" | ||||||
|  |                 ) | ||||||
|  |                 console.print(f"[bold red]{error_msg}") | ||||||
|  |                 raise AgentExecutionError(error_msg) | ||||||
|  |             elif tool_name in self.managed_agents: | ||||||
|  |                 error_msg = ( | ||||||
|  |                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" | ||||||
|  |                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" | ||||||
|  |                 ) | ||||||
|  |                 console.print(f"[bold red]{error_msg}") | ||||||
|  |                 raise AgentExecutionError(error_msg) | ||||||
|  | 
 | ||||||
|  |     def step(self, log_entry: ActionStep): | ||||||
|  |         """To be implemented in children classes""" | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def run( | ||||||
|  |         self, | ||||||
|  |         task: str, | ||||||
|  |         stream: bool = False, | ||||||
|  |         reset: bool = True, | ||||||
|  |         oneshot: bool = False, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|         """ |         """ | ||||||
|         Runs the agent for the given task. |         Runs the agent for the given task. | ||||||
| 
 | 
 | ||||||
|  | @ -441,10 +497,11 @@ class ReactAgent(BaseAgent): | ||||||
|         agent.run("What is the result of 2 power 3.7384?") |         agent.run("What is the result of 2 power 3.7384?") | ||||||
|         ``` |         ``` | ||||||
|         """ |         """ | ||||||
|         print("LANGFUSE REF:", langfuse_context.get_current_trace_url()) |  | ||||||
|         self.task = task |         self.task = task | ||||||
|         if len(kwargs) > 0: |         if len(kwargs) > 0: | ||||||
|             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." |             self.task += ( | ||||||
|  |                 f"\nYou have been provided with these initial arguments: {str(kwargs)}." | ||||||
|  |             ) | ||||||
|         self.state = kwargs.copy() |         self.state = kwargs.copy() | ||||||
| 
 | 
 | ||||||
|         self.initialize_system_prompt() |         self.initialize_system_prompt() | ||||||
|  | @ -460,7 +517,7 @@ class ReactAgent(BaseAgent): | ||||||
|             else: |             else: | ||||||
|                 self.logs.append(system_prompt_step) |                 self.logs.append(system_prompt_step) | ||||||
| 
 | 
 | ||||||
|         console.rule("[bold]New task", characters='=') |         console.rule("[bold]New task", characters="=") | ||||||
|         console.print(self.task) |         console.print(self.task) | ||||||
|         self.logs.append(TaskStep(task=task)) |         self.logs.append(TaskStep(task=task)) | ||||||
| 
 | 
 | ||||||
|  | @ -489,8 +546,13 @@ class ReactAgent(BaseAgent): | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(iteration=iteration, start_time=step_start_time) |             step_log = ActionStep(iteration=iteration, start_time=step_start_time) | ||||||
|             try: |             try: | ||||||
|                 if self.planning_interval is not None and iteration % self.planning_interval == 0: |                 if ( | ||||||
|                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) |                     self.planning_interval is not None | ||||||
|  |                     and iteration % self.planning_interval == 0 | ||||||
|  |                 ): | ||||||
|  |                     self.planning_step( | ||||||
|  |                         task, is_first_step=(iteration == 0), iteration=iteration | ||||||
|  |                     ) | ||||||
|                 console.rule("[bold]New step") |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) |                 self.step(step_log) | ||||||
|                 if step_log.final_answer is not None: |                 if step_log.final_answer is not None: | ||||||
|  | @ -530,8 +592,13 @@ class ReactAgent(BaseAgent): | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(iteration=iteration, start_time=step_start_time) |             step_log = ActionStep(iteration=iteration, start_time=step_start_time) | ||||||
|             try: |             try: | ||||||
|                 if self.planning_interval is not None and iteration % self.planning_interval == 0: |                 if ( | ||||||
|                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) |                     self.planning_interval is not None | ||||||
|  |                     and iteration % self.planning_interval == 0 | ||||||
|  |                 ): | ||||||
|  |                     self.planning_step( | ||||||
|  |                         task, is_first_step=(iteration == 0), iteration=iteration | ||||||
|  |                     ) | ||||||
|                 console.rule("[bold]New step") |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) |                 self.step(step_log) | ||||||
|                 if step_log.final_answer is not None: |                 if step_log.final_answer is not None: | ||||||
|  | @ -559,7 +626,7 @@ class ReactAgent(BaseAgent): | ||||||
| 
 | 
 | ||||||
|         return final_answer |         return final_answer | ||||||
| 
 | 
 | ||||||
|     def planning_step(self, task, is_first_step: bool = False, iteration: int = None): |     def planning_step(self, task, is_first_step: bool, iteration: int): | ||||||
|         """ |         """ | ||||||
|         Used periodically by the agent to plan the next steps to reach the objective. |         Used periodically by the agent to plan the next steps to reach the objective. | ||||||
| 
 | 
 | ||||||
|  | @ -569,7 +636,10 @@ class ReactAgent(BaseAgent): | ||||||
|             iteration (`int`): The number of the current step, used as an indication for the LLM. |             iteration (`int`): The number of the current step, used as an indication for the LLM. | ||||||
|         """ |         """ | ||||||
|         if is_first_step: |         if is_first_step: | ||||||
|             message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS} |             message_prompt_facts = { | ||||||
|  |                 "role": MessageRole.SYSTEM, | ||||||
|  |                 "content": SYSTEM_PROMPT_FACTS, | ||||||
|  |             } | ||||||
|             message_prompt_task = { |             message_prompt_task = { | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": f"""Here is the task: |                 "content": f"""Here is the task: | ||||||
|  | @ -589,15 +659,20 @@ Now begin!""", | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": USER_PROMPT_PLAN.format( |                 "content": USER_PROMPT_PLAN.format( | ||||||
|                     task=task, |                     task=task, | ||||||
|                     tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), |                     tool_descriptions=self._toolbox.show_tool_descriptions( | ||||||
|  |                         self.tool_description_template | ||||||
|  |                     ), | ||||||
|                     managed_agents_descriptions=( |                     managed_agents_descriptions=( | ||||||
|                         show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" |                         show_agents_descriptions(self.managed_agents) | ||||||
|  |                         if self.managed_agents is not None | ||||||
|  |                         else "" | ||||||
|                     ), |                     ), | ||||||
|                     answer_facts=answer_facts, |                     answer_facts=answer_facts, | ||||||
|                 ), |                 ), | ||||||
|             } |             } | ||||||
|             answer_plan = self.llm_engine( |             answer_plan = self.llm_engine( | ||||||
|                 [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"] |                 [message_system_prompt_plan, message_user_prompt_plan], | ||||||
|  |                 stop_sequences=["<end_plan>"], | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task: |             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task: | ||||||
|  | @ -608,7 +683,9 @@ Now begin!""", | ||||||
| ``` | ``` | ||||||
| {answer_facts} | {answer_facts} | ||||||
| ```""".strip() | ```""".strip() | ||||||
|             self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) |             self.logs.append( | ||||||
|  |                 PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) | ||||||
|  |             ) | ||||||
|             console.rule("[orange]Initial plan") |             console.rule("[orange]Initial plan") | ||||||
|             console.print(final_plan_redaction) |             console.print(final_plan_redaction) | ||||||
|         else:  # update plan |         else:  # update plan | ||||||
|  | @ -625,7 +702,9 @@ Now begin!""", | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": USER_PROMPT_FACTS_UPDATE, |                 "content": USER_PROMPT_FACTS_UPDATE, | ||||||
|             } |             } | ||||||
|             facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message]) |             facts_update = self.llm_engine( | ||||||
|  |                 [facts_update_system_prompt] + agent_memory + [facts_update_message] | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|             # Redact updated plan |             # Redact updated plan | ||||||
|             plan_update_message = { |             plan_update_message = { | ||||||
|  | @ -636,25 +715,34 @@ Now begin!""", | ||||||
|                 "role": MessageRole.USER, |                 "role": MessageRole.USER, | ||||||
|                 "content": USER_PROMPT_PLAN_UPDATE.format( |                 "content": USER_PROMPT_PLAN_UPDATE.format( | ||||||
|                     task=task, |                     task=task, | ||||||
|                     tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template), |                     tool_descriptions=self._toolbox.show_tool_descriptions( | ||||||
|  |                         self.tool_description_template | ||||||
|  |                     ), | ||||||
|                     managed_agents_descriptions=( |                     managed_agents_descriptions=( | ||||||
|                         show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else "" |                         show_agents_descriptions(self.managed_agents) | ||||||
|  |                         if self.managed_agents is not None | ||||||
|  |                         else "" | ||||||
|                     ), |                     ), | ||||||
|                     facts_update=facts_update, |                     facts_update=facts_update, | ||||||
|                     remaining_steps=(self.max_iterations - iteration), |                     remaining_steps=(self.max_iterations - iteration), | ||||||
|                 ), |                 ), | ||||||
|             } |             } | ||||||
|             plan_update = self.llm_engine( |             plan_update = self.llm_engine( | ||||||
|                 [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"] |                 [plan_update_message] + agent_memory + [plan_update_message_user], | ||||||
|  |                 stop_sequences=["<end_plan>"], | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             # Log final facts and plan |             # Log final facts and plan | ||||||
|             final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update) |             final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format( | ||||||
|  |                 task=task, plan_update=plan_update | ||||||
|  |             ) | ||||||
|             final_facts_redaction = f"""Here is the updated list of the facts that I know: |             final_facts_redaction = f"""Here is the updated list of the facts that I know: | ||||||
| ``` | ``` | ||||||
| {facts_update} | {facts_update} | ||||||
| ```""" | ```""" | ||||||
|             self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)) |             self.logs.append( | ||||||
|  |                 PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction) | ||||||
|  |             ) | ||||||
|             console.rule("[orange]Updated plan") |             console.rule("[orange]Updated plan") | ||||||
|             console.print(final_plan_redaction) |             console.print(final_plan_redaction) | ||||||
| 
 | 
 | ||||||
|  | @ -705,14 +793,20 @@ class JsonAgent(ReactAgent): | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("[italic]Calling LLM engine with this last message:", align="left") |             console.rule( | ||||||
|  |                 "[italic]Calling LLM engine with this last message:", align="left" | ||||||
|  |             ) | ||||||
|             console.print(self.prompt_messages[-1]) |             console.print(self.prompt_messages[-1]) | ||||||
|             console.rule() |             console.rule() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} |             additional_args = ( | ||||||
|  |                 {"grammar": self.grammar} if self.grammar is not None else {} | ||||||
|  |             ) | ||||||
|             llm_output = self.llm_engine( |             llm_output = self.llm_engine( | ||||||
|                 self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args |                 self.prompt_messages, | ||||||
|  |                 stop_sequences=["<end_action>", "Observation:"], | ||||||
|  |                 **additional_args, | ||||||
|             ) |             ) | ||||||
|             log_entry.llm_output = llm_output |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|  | @ -723,7 +817,9 @@ class JsonAgent(ReactAgent): | ||||||
|             console.print(llm_output) |             console.print(llm_output) | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") |         rationale, action = self.extract_action( | ||||||
|  |             llm_output=llm_output, split_token="Action:" | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             tool_name, arguments = self.tool_parser(action) |             tool_name, arguments = self.tool_parser(action) | ||||||
|  | @ -807,12 +903,18 @@ class CodeAgent(ReactAgent): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         self.python_evaluator = evaluate_python_code |         self.python_evaluator = evaluate_python_code | ||||||
|         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] |         self.additional_authorized_imports = ( | ||||||
|         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) |             additional_authorized_imports if additional_authorized_imports else [] | ||||||
|         self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports)) |         ) | ||||||
|  |         self.authorized_imports = list( | ||||||
|  |             set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) | ||||||
|  |         ) | ||||||
|  |         self.system_prompt = self.system_prompt.replace( | ||||||
|  |             "<<authorized_imports>>", str(self.authorized_imports) | ||||||
|  |         ) | ||||||
|         self.custom_tools = {} |         self.custom_tools = {} | ||||||
| 
 | 
 | ||||||
|     def step(self, log_entry: Dict[str, Any]): |     def step(self, log_entry: ActionStep): | ||||||
|         """ |         """ | ||||||
|         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. |         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. | ||||||
|         The errors are raised here, they are caught and logged in the run() method. |         The errors are raised here, they are caught and logged in the run() method. | ||||||
|  | @ -825,14 +927,20 @@ class CodeAgent(ReactAgent): | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("[italic]Calling LLM engine with these last messages:", align="left") |             console.rule( | ||||||
|  |                 "[italic]Calling LLM engine with these last messages:", align="left" | ||||||
|  |             ) | ||||||
|             console.print(self.prompt_messages[-2:]) |             console.print(self.prompt_messages[-2:]) | ||||||
|             console.rule() |             console.rule() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             additional_args = {"grammar": self.grammar} if self.grammar is not None else {} |             additional_args = ( | ||||||
|  |                 {"grammar": self.grammar} if self.grammar is not None else {} | ||||||
|  |             ) | ||||||
|             llm_output = self.llm_engine( |             llm_output = self.llm_engine( | ||||||
|                 self.prompt_messages, stop_sequences=["<end_action>", "Observation:"], **additional_args |                 self.prompt_messages, | ||||||
|  |                 stop_sequences=["<end_action>", "Observation:"], | ||||||
|  |                 **additional_args, | ||||||
|             ) |             ) | ||||||
|             log_entry.llm_output = llm_output |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|  | @ -840,13 +948,19 @@ class CodeAgent(ReactAgent): | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("[italic]Output message of the LLM:") |             console.rule("[italic]Output message of the LLM:") | ||||||
|             console.print(Syntax(llm_output, lexer='markdown', background_color='default')) |             console.print( | ||||||
|  |                 Syntax(llm_output, lexer="markdown", background_color="default") | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         try: |         try: | ||||||
|             rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") |             rationale, raw_code_action = self.extract_action( | ||||||
|  |                 llm_output=llm_output, split_token="Code:" | ||||||
|  |             ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") |             console.print( | ||||||
|  |                 f"Error in extracting action, trying to parse the whole output. Error trace: {e}" | ||||||
|  |             ) | ||||||
|             rationale, raw_code_action = llm_output, llm_output |             rationale, raw_code_action = llm_output, llm_output | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -856,14 +970,17 @@ class CodeAgent(ReactAgent): | ||||||
|             raise AgentParsingError(error_msg) |             raise AgentParsingError(error_msg) | ||||||
| 
 | 
 | ||||||
|         log_entry.rationale = rationale |         log_entry.rationale = rationale | ||||||
|         log_entry.tool_call = {"tool_name": "code interpreter", "tool_arguments": code_action} |         log_entry.tool_call = { | ||||||
|  |             "tool_name": "code interpreter", | ||||||
|  |             "tool_arguments": code_action, | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         # Execute |         # Execute | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("[italic]Agent thoughts") |             console.rule("[italic]Agent thoughts") | ||||||
|             console.print(rationale) |             console.print(rationale) | ||||||
|         console.rule("[bold]Agent is executing the code below:", align="left") |         console.rule("[bold]Agent is executing the code below:", align="left") | ||||||
|         console.print(Syntax(code_action, lexer='python', background_color='default')) |         console.print(Syntax(code_action, lexer="python", background_color="default")) | ||||||
|         console.rule("", align="left") |         console.rule("", align="left") | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|  | @ -886,7 +1003,9 @@ class CodeAgent(ReactAgent): | ||||||
|             if result is not None: |             if result is not None: | ||||||
|                 console.rule("Last output from code snippet:", align="left") |                 console.rule("Last output from code snippet:", align="left") | ||||||
|                 console.print(str(result)) |                 console.print(str(result)) | ||||||
|                 observation += "Last output from code snippet:\n" + truncate_content(str(result)) |                 observation += "Last output from code snippet:\n" + truncate_content( | ||||||
|  |                     str(result) | ||||||
|  |                 ) | ||||||
|             log_entry.observation = observation |             log_entry.observation = observation | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             error_msg = f"Code execution failed due to the following error:\n{str(e)}" |             error_msg = f"Code execution failed due to the following error:\n{str(e)}" | ||||||
|  | @ -902,7 +1021,14 @@ class CodeAgent(ReactAgent): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ManagedAgent: | class ManagedAgent: | ||||||
|     def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False): |     def __init__( | ||||||
|  |         self, | ||||||
|  |         agent, | ||||||
|  |         name, | ||||||
|  |         description, | ||||||
|  |         additional_prompting=None, | ||||||
|  |         provide_run_summary=False, | ||||||
|  |     ): | ||||||
|         self.agent = agent |         self.agent = agent | ||||||
|         self.name = name |         self.name = name | ||||||
|         self.description = description |         self.description = description | ||||||
|  | @ -925,18 +1051,22 @@ Your final_answer WILL HAVE to contain these parts: | ||||||
| 
 | 
 | ||||||
| Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. | Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost. | ||||||
| And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. | And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback. | ||||||
| <<additional_prompting>>""" | {{additional_prompting}}""" | ||||||
|         if self.additional_prompting: |         if self.additional_prompting: | ||||||
|             full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip() |             full_task = full_task.replace( | ||||||
|  |                 "\n{{additional_prompting}}", self.additional_prompting | ||||||
|  |             ).strip() | ||||||
|         else: |         else: | ||||||
|             full_task = full_task.replace("\n<<additional_prompting>>", "").strip() |             full_task = full_task.replace("\n{{additional_prompting}}", "").strip() | ||||||
|         return full_task |         return full_task | ||||||
| 
 | 
 | ||||||
|     def __call__(self, request, **kwargs): |     def __call__(self, request, **kwargs): | ||||||
|         full_task = self.write_full_task(request) |         full_task = self.write_full_task(request) | ||||||
|         output = self.agent.run(full_task, **kwargs) |         output = self.agent.run(full_task, **kwargs) | ||||||
|         if self.provide_run_summary: |         if self.provide_run_summary: | ||||||
|             answer = f"Here is the final answer from your managed agent '{self.name}':\n" |             answer = ( | ||||||
|  |                 f"Here is the final answer from your managed agent '{self.name}':\n" | ||||||
|  |             ) | ||||||
|             answer += str(output) |             answer += str(output) | ||||||
|             answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" |             answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" | ||||||
|             for message in self.agent.write_inner_memory_from_logs(summary_mode=True): |             for message in self.agent.write_inner_memory_from_logs(summary_mode=True): | ||||||
|  |  | ||||||
|  | @ -105,7 +105,9 @@ def get_remote_tools(logger, organization="huggingface-tools"): | ||||||
|     tools = {} |     tools = {} | ||||||
|     for space_info in spaces: |     for space_info in spaces: | ||||||
|         repo_id = space_info.id |         repo_id = space_info.id | ||||||
|         resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space") |         resolved_config_file = hf_hub_download( | ||||||
|  |             repo_id, TOOL_CONFIG_FILE, repo_type="space" | ||||||
|  |         ) | ||||||
|         with open(resolved_config_file, encoding="utf-8") as reader: |         with open(resolved_config_file, encoding="utf-8") as reader: | ||||||
|             config = json.load(reader) |             config = json.load(reader) | ||||||
|         task = repo_id.split("/")[-1] |         task = repo_id.split("/")[-1] | ||||||
|  | @ -131,7 +133,9 @@ class PythonInterpreterTool(Tool): | ||||||
|         if authorized_imports is None: |         if authorized_imports is None: | ||||||
|             self.authorized_imports = list(set(LIST_SAFE_MODULES)) |             self.authorized_imports = list(set(LIST_SAFE_MODULES)) | ||||||
|         else: |         else: | ||||||
|             self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) |             self.authorized_imports = list( | ||||||
|  |                 set(LIST_SAFE_MODULES) | set(authorized_imports) | ||||||
|  |             ) | ||||||
|         self.inputs = { |         self.inputs = { | ||||||
|             "code": { |             "code": { | ||||||
|                 "type": "string", |                 "type": "string", | ||||||
|  | @ -145,7 +149,11 @@ class PythonInterpreterTool(Tool): | ||||||
| 
 | 
 | ||||||
|     def forward(self, code): |     def forward(self, code): | ||||||
|         output = str( |         output = str( | ||||||
|             evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports) |             evaluate_python_code( | ||||||
|  |                 code, | ||||||
|  |                 static_tools=BASE_PYTHON_TOOLS, | ||||||
|  |                 authorized_imports=self.authorized_imports, | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|         return output |         return output | ||||||
| 
 | 
 | ||||||
|  | @ -153,16 +161,21 @@ class PythonInterpreterTool(Tool): | ||||||
| class FinalAnswerTool(Tool): | class FinalAnswerTool(Tool): | ||||||
|     name = "final_answer" |     name = "final_answer" | ||||||
|     description = "Provides a final answer to the given problem." |     description = "Provides a final answer to the given problem." | ||||||
|     inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} |     inputs = { | ||||||
|  |         "answer": {"type": "any", "description": "The final answer to the problem"} | ||||||
|  |     } | ||||||
|     output_type = "any" |     output_type = "any" | ||||||
| 
 | 
 | ||||||
|     def forward(self, answer): |     def forward(self, answer): | ||||||
|         return answer |         return answer | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class UserInputTool(Tool): | class UserInputTool(Tool): | ||||||
|     name = "user_input" |     name = "user_input" | ||||||
|     description = "Asks for user's input on a specific question" |     description = "Asks for user's input on a specific question" | ||||||
|     inputs = {"question": {"type": "string", "description": "The question to ask the user"}} |     inputs = { | ||||||
|  |         "question": {"type": "string", "description": "The question to ask the user"} | ||||||
|  |     } | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|     def forward(self, question): |     def forward(self, question): | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText | ||||||
| from .agents import BaseAgent, AgentStep, ActionStep | from .agents import BaseAgent, AgentStep, ActionStep | ||||||
| import gradio as gr | import gradio as gr | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||||
|     """Extract ChatMessage objects from agent steps""" |     """Extract ChatMessage objects from agent steps""" | ||||||
|     if isinstance(step_log, ActionStep): |     if isinstance(step_log, ActionStep): | ||||||
|  | @ -33,7 +34,9 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||||
|                 content=str(content), |                 content=str(content), | ||||||
|             ) |             ) | ||||||
|         if step_log.observation is not None: |         if step_log.observation is not None: | ||||||
|             yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observation}\n```") |             yield gr.ChatMessage( | ||||||
|  |                 role="assistant", content=f"```\n{step_log.observation}\n```" | ||||||
|  |             ) | ||||||
|         if step_log.error is not None: |         if step_log.error is not None: | ||||||
|             yield gr.ChatMessage( |             yield gr.ChatMessage( | ||||||
|                 role="assistant", |                 role="assistant", | ||||||
|  | @ -42,7 +45,13 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memory: bool=False, **kwargs): | def stream_to_gradio( | ||||||
|  |     agent, | ||||||
|  |     task: str, | ||||||
|  |     test_mode: bool = False, | ||||||
|  |     reset_agent_memory: bool = False, | ||||||
|  |     **kwargs, | ||||||
|  | ): | ||||||
|     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" |     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" | ||||||
| 
 | 
 | ||||||
|     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): |     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): | ||||||
|  | @ -52,7 +61,10 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo | ||||||
|     final_answer = step_log  # Last log is the run's final_answer |     final_answer = step_log  # Last log is the run's final_answer | ||||||
| 
 | 
 | ||||||
|     if isinstance(final_answer, AgentText): |     if isinstance(final_answer, AgentText): | ||||||
|         yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```") |         yield gr.ChatMessage( | ||||||
|  |             role="assistant", | ||||||
|  |             content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```", | ||||||
|  |         ) | ||||||
|     elif isinstance(final_answer, AgentImage): |     elif isinstance(final_answer, AgentImage): | ||||||
|         yield gr.ChatMessage( |         yield gr.ChatMessage( | ||||||
|             role="assistant", |             role="assistant", | ||||||
|  | @ -67,8 +79,9 @@ def stream_to_gradio(agent, task: str, test_mode: bool = False, reset_agent_memo | ||||||
|         yield gr.ChatMessage(role="assistant", content=str(final_answer)) |         yield gr.ChatMessage(role="assistant", content=str(final_answer)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class GradioUI(): | class GradioUI: | ||||||
|     """A one-line interface to launch your agent in Gradio""" |     """A one-line interface to launch your agent in Gradio""" | ||||||
|  | 
 | ||||||
|     def __init__(self, agent: BaseAgent): |     def __init__(self, agent: BaseAgent): | ||||||
|         self.agent = agent |         self.agent = agent | ||||||
| 
 | 
 | ||||||
|  | @ -83,10 +96,17 @@ class GradioUI(): | ||||||
|     def run(self): |     def run(self): | ||||||
|         with gr.Blocks() as demo: |         with gr.Blocks() as demo: | ||||||
|             stored_message = gr.State([]) |             stored_message = gr.State([]) | ||||||
|             chatbot = gr.Chatbot(label="Agent", |             chatbot = gr.Chatbot( | ||||||
|  |                 label="Agent", | ||||||
|                 type="messages", |                 type="messages", | ||||||
|                                 avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png")) |                 avatar_images=( | ||||||
|  |                     None, | ||||||
|  |                     "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png", | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|             text_input = gr.Textbox(lines=1, label="Chat Message") |             text_input = gr.Textbox(lines=1, label="Chat Message") | ||||||
|             text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(self.interact_with_agent, [stored_message, chatbot], [chatbot]) |             text_input.submit( | ||||||
|  |                 lambda s: (s, ""), [text_input], [stored_message, text_input] | ||||||
|  |             ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot]) | ||||||
| 
 | 
 | ||||||
|         demo.launch() |         demo.launch() | ||||||
|  | @ -39,7 +39,9 @@ class MessageRole(str, Enum): | ||||||
|         return [r.value for r in cls] |         return [r.value for r in cls] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}): | def get_clean_message_list( | ||||||
|  |     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} | ||||||
|  | ): | ||||||
|     """ |     """ | ||||||
|     Subsequent messages with the same role will be concatenated to a single message. |     Subsequent messages with the same role will be concatenated to a single message. | ||||||
| 
 | 
 | ||||||
|  | @ -54,12 +56,17 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: | ||||||
| 
 | 
 | ||||||
|         role = message["role"] |         role = message["role"] | ||||||
|         if role not in MessageRole.roles(): |         if role not in MessageRole.roles(): | ||||||
|             raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.") |             raise ValueError( | ||||||
|  |                 f"Incorrect role {role}, only {MessageRole.roles()} are supported for now." | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         if role in role_conversions: |         if role in role_conversions: | ||||||
|             message["role"] = role_conversions[role] |             message["role"] = role_conversions[role] | ||||||
| 
 | 
 | ||||||
|         if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]: |         if ( | ||||||
|  |             len(final_message_list) > 0 | ||||||
|  |             and message["role"] == final_message_list[-1]["role"] | ||||||
|  |         ): | ||||||
|             final_message_list[-1]["content"] += "\n=======\n" + message["content"] |             final_message_list[-1]["content"] += "\n=======\n" + message["content"] | ||||||
|         else: |         else: | ||||||
|             final_message_list.append(message) |             final_message_list.append(message) | ||||||
|  | @ -81,8 +88,12 @@ class HfEngine: | ||||||
|         try: |         try: | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.") |             logger.warning( | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct") |                 f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead." | ||||||
|  |             ) | ||||||
|  |             self.tokenizer = AutoTokenizer.from_pretrained( | ||||||
|  |                 "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     def get_token_counts(self): |     def get_token_counts(self): | ||||||
|         return { |         return { | ||||||
|  | @ -91,12 +102,18 @@ class HfEngine: | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     def generate( |     def generate( | ||||||
|         self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None |         self, | ||||||
|  |         messages: List[Dict[str, str]], | ||||||
|  |         stop_sequences: Optional[List[str]] = None, | ||||||
|  |         grammar: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|     def __call__( |     def __call__( | ||||||
|         self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None |         self, | ||||||
|  |         messages: List[Dict[str, str]], | ||||||
|  |         stop_sequences: Optional[List[str]] = None, | ||||||
|  |         grammar: Optional[str] = None, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         """Process the input messages and return the model's response. |         """Process the input messages and return the model's response. | ||||||
| 
 | 
 | ||||||
|  | @ -127,11 +144,15 @@ class HfEngine: | ||||||
|             ``` |             ``` | ||||||
|         """ |         """ | ||||||
|         if not isinstance(messages, List): |         if not isinstance(messages, List): | ||||||
|             raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.") |             raise ValueError( | ||||||
|  |                 "Messages should be a list of dictionaries with 'role' and 'content' keys." | ||||||
|  |             ) | ||||||
|         if stop_sequences is None: |         if stop_sequences is None: | ||||||
|             stop_sequences = [] |             stop_sequences = [] | ||||||
|         response = self.generate(messages, stop_sequences, grammar) |         response = self.generate(messages, stop_sequences, grammar) | ||||||
|         self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True)) |         self.last_input_token_count = len( | ||||||
|  |             self.tokenizer.apply_chat_template(messages, tokenize=True) | ||||||
|  |         ) | ||||||
|         self.last_output_token_count = len(self.tokenizer.encode(response)) |         self.last_output_token_count = len(self.tokenizer.encode(response)) | ||||||
| 
 | 
 | ||||||
|         # Remove stop sequences from LLM output |         # Remove stop sequences from LLM output | ||||||
|  | @ -175,18 +196,28 @@ class HfApiEngine(HfEngine): | ||||||
|         self.max_tokens = max_tokens |         self.max_tokens = max_tokens | ||||||
| 
 | 
 | ||||||
|     def generate( |     def generate( | ||||||
|         self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None |         self, | ||||||
|  |         messages: List[Dict[str, str]], | ||||||
|  |         stop_sequences: Optional[List[str]] = None, | ||||||
|  |         grammar: Optional[str] = None, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         # Get clean message list |         # Get clean message list | ||||||
|         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) |         messages = get_clean_message_list( | ||||||
|  |             messages, role_conversions=llama_role_conversions | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # Send messages to the Hugging Face Inference API |         # Send messages to the Hugging Face Inference API | ||||||
|         if grammar is not None: |         if grammar is not None: | ||||||
|             response = self.client.chat_completion( |             response = self.client.chat_completion( | ||||||
|                 messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar |                 messages, | ||||||
|  |                 stop=stop_sequences, | ||||||
|  |                 max_tokens=self.max_tokens, | ||||||
|  |                 response_format=grammar, | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens) |             response = self.client.chat_completion( | ||||||
|  |                 messages, stop=stop_sequences, max_tokens=self.max_tokens | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         response = response.choices[0].message.content |         response = response.choices[0].message.content | ||||||
|         return response |         return response | ||||||
|  | @ -207,7 +238,9 @@ class TransformersEngine(HfEngine): | ||||||
|         max_length: int = 1500, |         max_length: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         # Get clean message list |         # Get clean message list | ||||||
|         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) |         messages = get_clean_message_list( | ||||||
|  |             messages, role_conversions=llama_role_conversions | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # Get LLM output |         # Get LLM output | ||||||
|         if stop_sequences is not None and len(stop_sequences) > 0: |         if stop_sequences is not None and len(stop_sequences) > 0: | ||||||
|  |  | ||||||
|  | @ -17,12 +17,14 @@ | ||||||
| from .utils import console | from .utils import console | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| class Monitor: | class Monitor: | ||||||
|     def __init__(self, tracked_llm_engine): |     def __init__(self, tracked_llm_engine): | ||||||
|         self.step_durations = [] |         self.step_durations = [] | ||||||
|         self.tracked_llm_engine = tracked_llm_engine |         self.tracked_llm_engine = tracked_llm_engine | ||||||
|         if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found": |         if ( | ||||||
|  |             getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") | ||||||
|  |             != "Not found" | ||||||
|  |         ): | ||||||
|             self.total_input_token_count = 0 |             self.total_input_token_count = 0 | ||||||
|             self.total_output_token_count = 0 |             self.total_output_token_count = 0 | ||||||
| 
 | 
 | ||||||
|  | @ -33,103 +35,11 @@ class Monitor: | ||||||
|         console.print(f"- Time taken: {step_duration:.2f} seconds") |         console.print(f"- Time taken: {step_duration:.2f} seconds") | ||||||
| 
 | 
 | ||||||
|         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: |         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: | ||||||
|             self.total_input_token_count += self.tracked_llm_engine.last_input_token_count |             self.total_input_token_count += ( | ||||||
|             self.total_output_token_count += self.tracked_llm_engine.last_output_token_count |                 self.tracked_llm_engine.last_input_token_count | ||||||
|  |             ) | ||||||
|  |             self.total_output_token_count += ( | ||||||
|  |                 self.tracked_llm_engine.last_output_token_count | ||||||
|  |             ) | ||||||
|             console.print(f"- Input tokens: {self.total_input_token_count:,}") |             console.print(f"- Input tokens: {self.total_input_token_count:,}") | ||||||
|             console.print(f"- Output tokens: {self.total_output_token_count:,}") |             console.print(f"- Output tokens: {self.total_output_token_count:,}") | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| from typing import Optional, Union, List, Any |  | ||||||
| import httpx |  | ||||||
| import logging |  | ||||||
| import os |  | ||||||
| from langfuse.client import Langfuse, StatefulTraceClient, StatefulSpanClient, StateType |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class BaseTracker: |  | ||||||
|     def __init__(self): |  | ||||||
|         pass |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def call(cls, *args, **kwargs): |  | ||||||
|         pass |  | ||||||
| 
 |  | ||||||
| class LangfuseTracker(BaseTracker): |  | ||||||
|     log = logging.getLogger("langfuse") |  | ||||||
| 
 |  | ||||||
|     def __init__(self, *, public_key: Optional[str] = None, secret_key: Optional[str] = None, |  | ||||||
|                  host: Optional[str] = None, debug: bool = False, stateful_client: Optional[ |  | ||||||
|                 Union[StatefulTraceClient, StatefulSpanClient] |  | ||||||
|             ] = None, update_stateful_client: bool = False, version: Optional[str] = None, |  | ||||||
|                  session_id: Optional[str] = None, user_id: Optional[str] = None, trace_name: Optional[str] = None, |  | ||||||
|                  release: Optional[str] = None, metadata: Optional[Any] = None, tags: Optional[List[str]] = None, |  | ||||||
|                  threads: Optional[int] = None, flush_at: Optional[int] = None, flush_interval: Optional[int] = None, |  | ||||||
|                  max_retries: Optional[int] = None, timeout: Optional[int] = None, enabled: Optional[bool] = None, |  | ||||||
|                  httpx_client: Optional[httpx.Client] = None, sdk_integration: str = "default") -> None: |  | ||||||
|         super().__init__() |  | ||||||
|         self.version = version |  | ||||||
|         self.session_id = session_id |  | ||||||
|         self.user_id = user_id |  | ||||||
|         self.trace_name = trace_name |  | ||||||
|         self.release = release |  | ||||||
|         self.metadata = metadata |  | ||||||
|         self.tags = tags |  | ||||||
| 
 |  | ||||||
|         self.root_span = None |  | ||||||
|         self.update_stateful_client = update_stateful_client |  | ||||||
|         self.langfuse = None |  | ||||||
| 
 |  | ||||||
|         prio_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY") |  | ||||||
|         prio_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY") |  | ||||||
|         prio_host = host or os.environ.get( |  | ||||||
|             "LANGFUSE_HOST", "https://cloud.langfuse.com" |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         if stateful_client and isinstance(stateful_client, StatefulTraceClient): |  | ||||||
|             self.trace = stateful_client |  | ||||||
|             self._task_manager = stateful_client.task_manager |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         elif stateful_client and isinstance(stateful_client, StatefulSpanClient): |  | ||||||
|             self.root_span = stateful_client |  | ||||||
|             self.trace = StatefulTraceClient( |  | ||||||
|                 stateful_client.client, |  | ||||||
|                 stateful_client.trace_id, |  | ||||||
|                 StateType.TRACE, |  | ||||||
|                 stateful_client.trace_id, |  | ||||||
|                 stateful_client.task_manager, |  | ||||||
|             ) |  | ||||||
|             self._task_manager = stateful_client.task_manager |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         args = { |  | ||||||
|             "public_key": prio_public_key, |  | ||||||
|             "secret_key": prio_secret_key, |  | ||||||
|             "host": prio_host, |  | ||||||
|             "debug": debug, |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if release is not None: |  | ||||||
|             args["release"] = release |  | ||||||
|         if threads is not None: |  | ||||||
|             args["threads"] = threads |  | ||||||
|         if flush_at is not None: |  | ||||||
|             args["flush_at"] = flush_at |  | ||||||
|         if flush_interval is not None: |  | ||||||
|             args["flush_interval"] = flush_interval |  | ||||||
|         if max_retries is not None: |  | ||||||
|             args["max_retries"] = max_retries |  | ||||||
|         if timeout is not None: |  | ||||||
|             args["timeout"] = timeout |  | ||||||
|         if enabled is not None: |  | ||||||
|             args["enabled"] = enabled |  | ||||||
|         if httpx_client is not None: |  | ||||||
|             args["httpx_client"] = httpx_client |  | ||||||
|         args["sdk_integration"] = sdk_integration |  | ||||||
| 
 |  | ||||||
|         self.langfuse = Langfuse(**args) |  | ||||||
|         self.trace: Optional[StatefulTraceClient] = None |  | ||||||
|         self._task_manager = self.langfuse.task_manager |  | ||||||
| 
 |  | ||||||
|     def call(self, i, o, name=None, **kwargs): |  | ||||||
|         self.langfuse.trace(input=i, output=o, name=name, metadata=kwargs) |  | ||||||
|  | @ -42,7 +42,10 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): | ||||||
|         return prompt_or_repo_id |         return prompt_or_repo_id | ||||||
| 
 | 
 | ||||||
|     prompt_file = cached_file( |     prompt_file = cached_file( | ||||||
|         prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name} |         prompt_or_repo_id, | ||||||
|  |         PROMPT_FILES[mode], | ||||||
|  |         repo_type="dataset", | ||||||
|  |         user_agent={"agent": agent_name}, | ||||||
|     ) |     ) | ||||||
|     with open(prompt_file, "r", encoding="utf-8") as f: |     with open(prompt_file, "r", encoding="utf-8") as f: | ||||||
|         return f.read() |         return f.read() | ||||||
|  |  | ||||||
|  | @ -26,6 +26,7 @@ import pandas as pd | ||||||
| 
 | 
 | ||||||
| from .utils import truncate_content | from .utils import truncate_content | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| class InterpreterError(ValueError): | class InterpreterError(ValueError): | ||||||
|     """ |     """ | ||||||
|     An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported |     An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported | ||||||
|  | @ -38,7 +39,8 @@ class InterpreterError(ValueError): | ||||||
| ERRORS = { | ERRORS = { | ||||||
|     name: getattr(builtins, name) |     name: getattr(builtins, name) | ||||||
|     for name in dir(builtins) |     for name in dir(builtins) | ||||||
|     if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException) |     if isinstance(getattr(builtins, name), type) | ||||||
|  |     and issubclass(getattr(builtins, name), BaseException) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -92,7 +94,9 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools): | ||||||
|     elif isinstance(expression.op, ast.Invert): |     elif isinstance(expression.op, ast.Invert): | ||||||
|         return ~operand |         return ~operand | ||||||
|     else: |     else: | ||||||
|         raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") |         raise InterpreterError( | ||||||
|  |             f"Unary operation {expression.op.__class__.__name__} is not supported." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate_lambda(lambda_expression, state, static_tools, custom_tools): | def evaluate_lambda(lambda_expression, state, static_tools, custom_tools): | ||||||
|  | @ -102,7 +106,9 @@ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools): | ||||||
|         new_state = state.copy() |         new_state = state.copy() | ||||||
|         for arg, value in zip(args, values): |         for arg, value in zip(args, values): | ||||||
|             new_state[arg] = value |             new_state[arg] = value | ||||||
|         return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools) |         return evaluate_ast( | ||||||
|  |             lambda_expression.body, new_state, static_tools, custom_tools | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     return lambda_func |     return lambda_func | ||||||
| 
 | 
 | ||||||
|  | @ -120,7 +126,9 @@ def evaluate_while(while_loop, state, static_tools, custom_tools): | ||||||
|                 break |                 break | ||||||
|         iterations += 1 |         iterations += 1 | ||||||
|         if iterations > max_iterations: |         if iterations > max_iterations: | ||||||
|             raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded") |             raise InterpreterError( | ||||||
|  |                 f"Maximum number of {max_iterations} iterations in While loop exceeded" | ||||||
|  |             ) | ||||||
|     return None |     return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -128,7 +136,10 @@ def create_function(func_def, state, static_tools, custom_tools): | ||||||
|     def new_func(*args, **kwargs): |     def new_func(*args, **kwargs): | ||||||
|         func_state = state.copy() |         func_state = state.copy() | ||||||
|         arg_names = [arg.arg for arg in func_def.args.args] |         arg_names = [arg.arg for arg in func_def.args.args] | ||||||
|         default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults] |         default_values = [ | ||||||
|  |             evaluate_ast(d, state, static_tools, custom_tools) | ||||||
|  |             for d in func_def.args.defaults | ||||||
|  |         ] | ||||||
| 
 | 
 | ||||||
|         # Apply default values |         # Apply default values | ||||||
|         defaults = dict(zip(arg_names[-len(default_values) :], default_values)) |         defaults = dict(zip(arg_names[-len(default_values) :], default_values)) | ||||||
|  | @ -180,26 +191,39 @@ def create_class(class_name, class_bases, class_body): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate_function_def(func_def, state, static_tools, custom_tools): | def evaluate_function_def(func_def, state, static_tools, custom_tools): | ||||||
|     custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools) |     custom_tools[func_def.name] = create_function( | ||||||
|  |         func_def, state, static_tools, custom_tools | ||||||
|  |     ) | ||||||
|     return custom_tools[func_def.name] |     return custom_tools[func_def.name] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate_class_def(class_def, state, static_tools, custom_tools): | def evaluate_class_def(class_def, state, static_tools, custom_tools): | ||||||
|     class_name = class_def.name |     class_name = class_def.name | ||||||
|     bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases] |     bases = [ | ||||||
|  |         evaluate_ast(base, state, static_tools, custom_tools) | ||||||
|  |         for base in class_def.bases | ||||||
|  |     ] | ||||||
|     class_dict = {} |     class_dict = {} | ||||||
| 
 | 
 | ||||||
|     for stmt in class_def.body: |     for stmt in class_def.body: | ||||||
|         if isinstance(stmt, ast.FunctionDef): |         if isinstance(stmt, ast.FunctionDef): | ||||||
|             class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools) |             class_dict[stmt.name] = evaluate_function_def( | ||||||
|  |                 stmt, state, static_tools, custom_tools | ||||||
|  |             ) | ||||||
|         elif isinstance(stmt, ast.Assign): |         elif isinstance(stmt, ast.Assign): | ||||||
|             for target in stmt.targets: |             for target in stmt.targets: | ||||||
|                 if isinstance(target, ast.Name): |                 if isinstance(target, ast.Name): | ||||||
|                     class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools) |                     class_dict[target.id] = evaluate_ast( | ||||||
|  |                         stmt.value, state, static_tools, custom_tools | ||||||
|  |                     ) | ||||||
|                 elif isinstance(target, ast.Attribute): |                 elif isinstance(target, ast.Attribute): | ||||||
|                     class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools) |                     class_dict[target.attr] = evaluate_ast( | ||||||
|  |                         stmt.value, state, static_tools, custom_tools | ||||||
|  |                     ) | ||||||
|         else: |         else: | ||||||
|             raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") |             raise InterpreterError( | ||||||
|  |                 f"Unsupported statement in class body: {stmt.__class__.__name__}" | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     new_class = type(class_name, tuple(bases), class_dict) |     new_class = type(class_name, tuple(bases), class_dict) | ||||||
|     state[class_name] = new_class |     state[class_name] = new_class | ||||||
|  | @ -223,7 +247,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools): | ||||||
|         elif isinstance(target, ast.List): |         elif isinstance(target, ast.List): | ||||||
|             return [get_current_value(elt) for elt in target.elts] |             return [get_current_value(elt) for elt in target.elts] | ||||||
|         else: |         else: | ||||||
|             raise InterpreterError("AugAssign not supported for {type(target)} targets.") |             raise InterpreterError( | ||||||
|  |                 "AugAssign not supported for {type(target)} targets." | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     current_value = get_current_value(expression.target) |     current_value = get_current_value(expression.target) | ||||||
|     value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools) |     value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools) | ||||||
|  | @ -232,7 +258,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools): | ||||||
|     if isinstance(expression.op, ast.Add): |     if isinstance(expression.op, ast.Add): | ||||||
|         if isinstance(current_value, list): |         if isinstance(current_value, list): | ||||||
|             if not isinstance(value_to_add, list): |             if not isinstance(value_to_add, list): | ||||||
|                 raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") |                 raise InterpreterError( | ||||||
|  |                     f"Cannot add non-list value {value_to_add} to a list." | ||||||
|  |                 ) | ||||||
|             updated_value = current_value + value_to_add |             updated_value = current_value + value_to_add | ||||||
|         else: |         else: | ||||||
|             updated_value = current_value + value_to_add |             updated_value = current_value + value_to_add | ||||||
|  | @ -259,7 +287,9 @@ def evaluate_augassign(expression, state, static_tools, custom_tools): | ||||||
|     elif isinstance(expression.op, ast.RShift): |     elif isinstance(expression.op, ast.RShift): | ||||||
|         updated_value = current_value >> value_to_add |         updated_value = current_value >> value_to_add | ||||||
|     else: |     else: | ||||||
|         raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") |         raise InterpreterError( | ||||||
|  |             f"Operation {type(expression.op).__name__} is not supported." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     # Update the state |     # Update the state | ||||||
|     set_value(expression.target, updated_value, state, static_tools, custom_tools) |     set_value(expression.target, updated_value, state, static_tools, custom_tools) | ||||||
|  | @ -311,7 +341,9 @@ def evaluate_binop(binop, state, static_tools, custom_tools): | ||||||
|     elif isinstance(binop.op, ast.RShift): |     elif isinstance(binop.op, ast.RShift): | ||||||
|         return left_val >> right_val |         return left_val >> right_val | ||||||
|     else: |     else: | ||||||
|         raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") |         raise NotImplementedError( | ||||||
|  |             f"Binary operation {type(binop.op).__name__} is not implemented." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate_assign(assign, state, static_tools, custom_tools): | def evaluate_assign(assign, state, static_tools, custom_tools): | ||||||
|  | @ -321,7 +353,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools): | ||||||
|         set_value(target, result, state, static_tools, custom_tools) |         set_value(target, result, state, static_tools, custom_tools) | ||||||
|     else: |     else: | ||||||
|         if len(assign.targets) != len(result): |         if len(assign.targets) != len(result): | ||||||
|             raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") |             raise InterpreterError( | ||||||
|  |                 f"Assign failed: expected {len(result)} values but got {len(assign.targets)}." | ||||||
|  |             ) | ||||||
|         expanded_values = [] |         expanded_values = [] | ||||||
|         for tgt in assign.targets: |         for tgt in assign.targets: | ||||||
|             if isinstance(tgt, ast.Starred): |             if isinstance(tgt, ast.Starred): | ||||||
|  | @ -336,7 +370,9 @@ def evaluate_assign(assign, state, static_tools, custom_tools): | ||||||
| def set_value(target, value, state, static_tools, custom_tools): | def set_value(target, value, state, static_tools, custom_tools): | ||||||
|     if isinstance(target, ast.Name): |     if isinstance(target, ast.Name): | ||||||
|         if target.id in static_tools: |         if target.id in static_tools: | ||||||
|             raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") |             raise InterpreterError( | ||||||
|  |                 f"Cannot assign to name '{target.id}': doing this would erase the existing tool!" | ||||||
|  |             ) | ||||||
|         state[target.id] = value |         state[target.id] = value | ||||||
|     elif isinstance(target, ast.Tuple): |     elif isinstance(target, ast.Tuple): | ||||||
|         if not isinstance(value, tuple): |         if not isinstance(value, tuple): | ||||||
|  | @ -399,9 +435,14 @@ def evaluate_call(call, state, static_tools, custom_tools): | ||||||
|         else: |         else: | ||||||
|             args.append(evaluate_ast(arg, state, static_tools, custom_tools)) |             args.append(evaluate_ast(arg, state, static_tools, custom_tools)) | ||||||
| 
 | 
 | ||||||
|     kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords} |     kwargs = { | ||||||
|  |         keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) | ||||||
|  |         for keyword in call.keywords | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     if isinstance(func, type) and len(func.__module__.split(".")) > 1:  # Check for user-defined classes |     if ( | ||||||
|  |         isinstance(func, type) and len(func.__module__.split(".")) > 1 | ||||||
|  |     ):  # Check for user-defined classes | ||||||
|         # Instantiate the class using its constructor |         # Instantiate the class using its constructor | ||||||
|         obj = func.__new__(func)  # Create a new instance of the class |         obj = func.__new__(func)  # Create a new instance of the class | ||||||
|         if hasattr(obj, "__init__"):  # Check if the class has an __init__ method |         if hasattr(obj, "__init__"):  # Check if the class has an __init__ method | ||||||
|  | @ -441,7 +482,9 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools): | ||||||
|     value = evaluate_ast(subscript.value, state, static_tools, custom_tools) |     value = evaluate_ast(subscript.value, state, static_tools, custom_tools) | ||||||
| 
 | 
 | ||||||
|     if isinstance(value, str) and isinstance(index, str): |     if isinstance(value, str) and isinstance(index, str): | ||||||
|         raise InterpreterError("You're trying to subscript a string with a string index, which is impossible") |         raise InterpreterError( | ||||||
|  |             "You're trying to subscript a string with a string index, which is impossible" | ||||||
|  |         ) | ||||||
|     if isinstance(value, pd.core.indexing._LocIndexer): |     if isinstance(value, pd.core.indexing._LocIndexer): | ||||||
|         parent_object = value.obj |         parent_object = value.obj | ||||||
|         return parent_object.loc[index] |         return parent_object.loc[index] | ||||||
|  | @ -453,11 +496,15 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools): | ||||||
|         return value[index] |         return value[index] | ||||||
|     elif isinstance(value, (list, tuple)): |     elif isinstance(value, (list, tuple)): | ||||||
|         if not (-len(value) <= index < len(value)): |         if not (-len(value) <= index < len(value)): | ||||||
|             raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") |             raise InterpreterError( | ||||||
|  |                 f"Index {index} out of bounds for list of length {len(value)}" | ||||||
|  |             ) | ||||||
|         return value[int(index)] |         return value[int(index)] | ||||||
|     elif isinstance(value, str): |     elif isinstance(value, str): | ||||||
|         if not (-len(value) <= index < len(value)): |         if not (-len(value) <= index < len(value)): | ||||||
|             raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") |             raise InterpreterError( | ||||||
|  |                 f"Index {index} out of bounds for string of length {len(value)}" | ||||||
|  |             ) | ||||||
|         return value[index] |         return value[index] | ||||||
|     elif index in value: |     elif index in value: | ||||||
|         return value[index] |         return value[index] | ||||||
|  | @ -483,7 +530,10 @@ def evaluate_name(name, state, static_tools, custom_tools): | ||||||
| 
 | 
 | ||||||
| def evaluate_condition(condition, state, static_tools, custom_tools): | def evaluate_condition(condition, state, static_tools, custom_tools): | ||||||
|     left = evaluate_ast(condition.left, state, static_tools, custom_tools) |     left = evaluate_ast(condition.left, state, static_tools, custom_tools) | ||||||
|     comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators] |     comparators = [ | ||||||
|  |         evaluate_ast(c, state, static_tools, custom_tools) | ||||||
|  |         for c in condition.comparators | ||||||
|  |     ] | ||||||
|     ops = [type(op) for op in condition.ops] |     ops = [type(op) for op in condition.ops] | ||||||
| 
 | 
 | ||||||
|     result = True |     result = True | ||||||
|  | @ -561,9 +611,13 @@ def evaluate_for(for_loop, state, static_tools, custom_tools): | ||||||
| def evaluate_listcomp(listcomp, state, static_tools, custom_tools): | def evaluate_listcomp(listcomp, state, static_tools, custom_tools): | ||||||
|     def inner_evaluate(generators, index, current_state): |     def inner_evaluate(generators, index, current_state): | ||||||
|         if index >= len(generators): |         if index >= len(generators): | ||||||
|             return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)] |             return [ | ||||||
|  |                 evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools) | ||||||
|  |             ] | ||||||
|         generator = generators[index] |         generator = generators[index] | ||||||
|         iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools) |         iter_value = evaluate_ast( | ||||||
|  |             generator.iter, current_state, static_tools, custom_tools | ||||||
|  |         ) | ||||||
|         result = [] |         result = [] | ||||||
|         for value in iter_value: |         for value in iter_value: | ||||||
|             new_state = current_state.copy() |             new_state = current_state.copy() | ||||||
|  | @ -572,7 +626,10 @@ def evaluate_listcomp(listcomp, state, static_tools, custom_tools): | ||||||
|                     new_state[elem.id] = value[idx] |                     new_state[elem.id] = value[idx] | ||||||
|             else: |             else: | ||||||
|                 new_state[generator.target.id] = value |                 new_state[generator.target.id] = value | ||||||
|             if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs): |             if all( | ||||||
|  |                 evaluate_ast(if_clause, new_state, static_tools, custom_tools) | ||||||
|  |                 for if_clause in generator.ifs | ||||||
|  |             ): | ||||||
|                 result.extend(inner_evaluate(generators, index + 1, new_state)) |                 result.extend(inner_evaluate(generators, index + 1, new_state)) | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|  | @ -586,7 +643,9 @@ def evaluate_try(try_node, state, static_tools, custom_tools): | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         matched = False |         matched = False | ||||||
|         for handler in try_node.handlers: |         for handler in try_node.handlers: | ||||||
|             if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)): |             if handler.type is None or isinstance( | ||||||
|  |                 e, evaluate_ast(handler.type, state, static_tools, custom_tools) | ||||||
|  |             ): | ||||||
|                 matched = True |                 matched = True | ||||||
|                 if handler.name: |                 if handler.name: | ||||||
|                     state[handler.name] = e |                     state[handler.name] = e | ||||||
|  | @ -638,7 +697,9 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools): | ||||||
| def evaluate_with(with_node, state, static_tools, custom_tools): | def evaluate_with(with_node, state, static_tools, custom_tools): | ||||||
|     contexts = [] |     contexts = [] | ||||||
|     for item in with_node.items: |     for item in with_node.items: | ||||||
|         context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools) |         context_expr = evaluate_ast( | ||||||
|  |             item.context_expr, state, static_tools, custom_tools | ||||||
|  |         ) | ||||||
|         if item.optional_vars: |         if item.optional_vars: | ||||||
|             state[item.optional_vars.id] = context_expr.__enter__() |             state[item.optional_vars.id] = context_expr.__enter__() | ||||||
|             contexts.append(state[item.optional_vars.id]) |             contexts.append(state[item.optional_vars.id]) | ||||||
|  | @ -661,7 +722,9 @@ def evaluate_with(with_node, state, static_tools, custom_tools): | ||||||
| def import_modules(expression, state, authorized_imports): | def import_modules(expression, state, authorized_imports): | ||||||
|     def check_module_authorized(module_name): |     def check_module_authorized(module_name): | ||||||
|         module_path = module_name.split(".") |         module_path = module_name.split(".") | ||||||
|         module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)] |         module_subpaths = [ | ||||||
|  |             ".".join(module_path[:i]) for i in range(1, len(module_path) + 1) | ||||||
|  |         ] | ||||||
|         return any(subpath in authorized_imports for subpath in module_subpaths) |         return any(subpath in authorized_imports for subpath in module_subpaths) | ||||||
| 
 | 
 | ||||||
|     if isinstance(expression, ast.Import): |     if isinstance(expression, ast.Import): | ||||||
|  | @ -676,7 +739,9 @@ def import_modules(expression, state, authorized_imports): | ||||||
|         return None |         return None | ||||||
|     elif isinstance(expression, ast.ImportFrom): |     elif isinstance(expression, ast.ImportFrom): | ||||||
|         if check_module_authorized(expression.module): |         if check_module_authorized(expression.module): | ||||||
|             module = __import__(expression.module, fromlist=[alias.name for alias in expression.names]) |             module = __import__( | ||||||
|  |                 expression.module, fromlist=[alias.name for alias in expression.names] | ||||||
|  |             ) | ||||||
|             for alias in expression.names: |             for alias in expression.names: | ||||||
|                 state[alias.asname or alias.name] = getattr(module, alias.name) |                 state[alias.asname or alias.name] = getattr(module, alias.name) | ||||||
|         else: |         else: | ||||||
|  | @ -691,9 +756,14 @@ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools): | ||||||
|         for value in iter_value: |         for value in iter_value: | ||||||
|             new_state = state.copy() |             new_state = state.copy() | ||||||
|             set_value(gen.target, value, new_state, static_tools, custom_tools) |             set_value(gen.target, value, new_state, static_tools, custom_tools) | ||||||
|             if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs): |             if all( | ||||||
|  |                 evaluate_ast(if_clause, new_state, static_tools, custom_tools) | ||||||
|  |                 for if_clause in gen.ifs | ||||||
|  |             ): | ||||||
|                 key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools) |                 key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools) | ||||||
|                 val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools) |                 val = evaluate_ast( | ||||||
|  |                     dictcomp.value, new_state, static_tools, custom_tools | ||||||
|  |                 ) | ||||||
|                 result[key] = val |                 result[key] = val | ||||||
|     return result |     return result | ||||||
| 
 | 
 | ||||||
|  | @ -744,7 +814,10 @@ def evaluate_ast( | ||||||
|         # Constant -> just return the value |         # Constant -> just return the value | ||||||
|         return expression.value |         return expression.value | ||||||
|     elif isinstance(expression, ast.Tuple): |     elif isinstance(expression, ast.Tuple): | ||||||
|         return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts) |         return tuple( | ||||||
|  |             evaluate_ast(elt, state, static_tools, custom_tools) | ||||||
|  |             for elt in expression.elts | ||||||
|  |         ) | ||||||
|     elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): |     elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): | ||||||
|         return evaluate_listcomp(expression, state, static_tools, custom_tools) |         return evaluate_listcomp(expression, state, static_tools, custom_tools) | ||||||
|     elif isinstance(expression, ast.UnaryOp): |     elif isinstance(expression, ast.UnaryOp): | ||||||
|  | @ -770,8 +843,13 @@ def evaluate_ast( | ||||||
|         return evaluate_function_def(expression, state, static_tools, custom_tools) |         return evaluate_function_def(expression, state, static_tools, custom_tools) | ||||||
|     elif isinstance(expression, ast.Dict): |     elif isinstance(expression, ast.Dict): | ||||||
|         # Dict -> evaluate all keys and values |         # Dict -> evaluate all keys and values | ||||||
|         keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys] |         keys = [ | ||||||
|         values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values] |             evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys | ||||||
|  |         ] | ||||||
|  |         values = [ | ||||||
|  |             evaluate_ast(v, state, static_tools, custom_tools) | ||||||
|  |             for v in expression.values | ||||||
|  |         ] | ||||||
|         return dict(zip(keys, values)) |         return dict(zip(keys, values)) | ||||||
|     elif isinstance(expression, ast.Expr): |     elif isinstance(expression, ast.Expr): | ||||||
|         # Expression -> evaluate the content |         # Expression -> evaluate the content | ||||||
|  | @ -788,10 +866,18 @@ def evaluate_ast( | ||||||
|     elif hasattr(ast, "Index") and isinstance(expression, ast.Index): |     elif hasattr(ast, "Index") and isinstance(expression, ast.Index): | ||||||
|         return evaluate_ast(expression.value, state, static_tools, custom_tools) |         return evaluate_ast(expression.value, state, static_tools, custom_tools) | ||||||
|     elif isinstance(expression, ast.JoinedStr): |     elif isinstance(expression, ast.JoinedStr): | ||||||
|         return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values]) |         return "".join( | ||||||
|  |             [ | ||||||
|  |                 str(evaluate_ast(v, state, static_tools, custom_tools)) | ||||||
|  |                 for v in expression.values | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|     elif isinstance(expression, ast.List): |     elif isinstance(expression, ast.List): | ||||||
|         # List -> evaluate all elements |         # List -> evaluate all elements | ||||||
|         return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts] |         return [ | ||||||
|  |             evaluate_ast(elt, state, static_tools, custom_tools) | ||||||
|  |             for elt in expression.elts | ||||||
|  |         ] | ||||||
|     elif isinstance(expression, ast.Name): |     elif isinstance(expression, ast.Name): | ||||||
|         # Name -> pick up the value in the state |         # Name -> pick up the value in the state | ||||||
|         return evaluate_name(expression, state, static_tools, custom_tools) |         return evaluate_name(expression, state, static_tools, custom_tools) | ||||||
|  | @ -815,7 +901,9 @@ def evaluate_ast( | ||||||
|             evaluate_ast(expression.upper, state, static_tools, custom_tools) |             evaluate_ast(expression.upper, state, static_tools, custom_tools) | ||||||
|             if expression.upper is not None |             if expression.upper is not None | ||||||
|             else None, |             else None, | ||||||
|             evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None, |             evaluate_ast(expression.step, state, static_tools, custom_tools) | ||||||
|  |             if expression.step is not None | ||||||
|  |             else None, | ||||||
|         ) |         ) | ||||||
|     elif isinstance(expression, ast.DictComp): |     elif isinstance(expression, ast.DictComp): | ||||||
|         return evaluate_dictcomp(expression, state, static_tools, custom_tools) |         return evaluate_dictcomp(expression, state, static_tools, custom_tools) | ||||||
|  | @ -834,17 +922,24 @@ def evaluate_ast( | ||||||
|     elif isinstance(expression, ast.With): |     elif isinstance(expression, ast.With): | ||||||
|         return evaluate_with(expression, state, static_tools, custom_tools) |         return evaluate_with(expression, state, static_tools, custom_tools) | ||||||
|     elif isinstance(expression, ast.Set): |     elif isinstance(expression, ast.Set): | ||||||
|         return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts} |         return { | ||||||
|  |             evaluate_ast(elt, state, static_tools, custom_tools) | ||||||
|  |             for elt in expression.elts | ||||||
|  |         } | ||||||
|     elif isinstance(expression, ast.Return): |     elif isinstance(expression, ast.Return): | ||||||
|         raise ReturnException( |         raise ReturnException( | ||||||
|             evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None |             evaluate_ast(expression.value, state, static_tools, custom_tools) | ||||||
|  |             if expression.value | ||||||
|  |             else None | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         # For now we refuse anything else. Let's add things as we need them. |         # For now we refuse anything else. Let's add things as we need them. | ||||||
|         raise InterpreterError(f"{expression.__class__.__name__} is not supported.") |         raise InterpreterError(f"{expression.__class__.__name__} is not supported.") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str: | def truncate_print_outputs( | ||||||
|  |     print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT | ||||||
|  | ) -> str: | ||||||
|     if len(print_outputs) < max_len_outputs: |     if len(print_outputs) < max_len_outputs: | ||||||
|         return print_outputs |         return print_outputs | ||||||
|     else: |     else: | ||||||
|  | @ -895,8 +990,12 @@ def evaluate_python_code( | ||||||
|     OPERATIONS_COUNT = 0 |     OPERATIONS_COUNT = 0 | ||||||
|     try: |     try: | ||||||
|         for node in expression.body: |         for node in expression.body: | ||||||
|             result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) |             result = evaluate_ast( | ||||||
|         state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) |                 node, state, static_tools, custom_tools, authorized_imports | ||||||
|  |             ) | ||||||
|  |         state["print_outputs"] = truncate_content( | ||||||
|  |             PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT | ||||||
|  |         ) | ||||||
|         return result |         return result | ||||||
|     except InterpreterError as e: |     except InterpreterError as e: | ||||||
|         msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) |         msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) | ||||||
|  |  | ||||||
|  | @ -26,7 +26,9 @@ class DuckDuckGoSearchTool(Tool): | ||||||
|     name = "web_search" |     name = "web_search" | ||||||
|     description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. |     description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. | ||||||
|     Each result has keys 'title', 'href' and 'body'.""" |     Each result has keys 'title', 'href' and 'body'.""" | ||||||
|     inputs = {"query": {"type": "string", "description": "The search query to perform."}} |     inputs = { | ||||||
|  |         "query": {"type": "string", "description": "The search query to perform."} | ||||||
|  |     } | ||||||
|     output_type = "any" |     output_type = "any" | ||||||
| 
 | 
 | ||||||
|     def forward(self, query: str) -> str: |     def forward(self, query: str) -> str: | ||||||
|  |  | ||||||
							
								
								
									
										141
									
								
								agents/tools.py
								
								
								
								
							
							
						
						
									
										141
									
								
								agents/tools.py
								
								
								
								
							|  | @ -26,7 +26,13 @@ from functools import lru_cache, wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Any, Callable, Dict, List, Optional, Union | from typing import Any, Callable, Dict, List, Optional, Union | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder | from huggingface_hub import ( | ||||||
|  |     create_repo, | ||||||
|  |     get_collection, | ||||||
|  |     hf_hub_download, | ||||||
|  |     metadata_update, | ||||||
|  |     upload_folder, | ||||||
|  | ) | ||||||
| from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session | ||||||
| from packaging import version | from packaging import version | ||||||
| 
 | 
 | ||||||
|  | @ -73,7 +79,9 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs): | ||||||
|             hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) |             hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) | ||||||
|             return "model" |             return "model" | ||||||
|         except RepositoryNotFoundError: |         except RepositoryNotFoundError: | ||||||
|             raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.") |             raise EnvironmentError( | ||||||
|  |                 f"`{repo_id}` does not seem to be a valid repo identifier on the Hub." | ||||||
|  |             ) | ||||||
|         except Exception: |         except Exception: | ||||||
|             return "model" |             return "model" | ||||||
|     except Exception: |     except Exception: | ||||||
|  | @ -158,7 +166,15 @@ class Tool: | ||||||
|             "inputs": dict, |             "inputs": dict, | ||||||
|             "output_type": str, |             "output_type": str, | ||||||
|         } |         } | ||||||
|         authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"] |         authorized_types = [ | ||||||
|  |             "string", | ||||||
|  |             "integer", | ||||||
|  |             "number", | ||||||
|  |             "image", | ||||||
|  |             "audio", | ||||||
|  |             "any", | ||||||
|  |             "boolean", | ||||||
|  |         ] | ||||||
| 
 | 
 | ||||||
|         for attr, expected_type in required_attributes.items(): |         for attr, expected_type in required_attributes.items(): | ||||||
|             attr_value = getattr(self, attr, None) |             attr_value = getattr(self, attr, None) | ||||||
|  | @ -169,7 +185,9 @@ class Tool: | ||||||
|                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." |                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." | ||||||
|                 ) |                 ) | ||||||
|         for input_name, input_content in self.inputs.items(): |         for input_name, input_content in self.inputs.items(): | ||||||
|             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." |             assert isinstance( | ||||||
|  |                 input_content, dict | ||||||
|  |             ), f"Input '{input_name}' should be a dictionary." | ||||||
|             assert ( |             assert ( | ||||||
|                 "type" in input_content and "description" in input_content |                 "type" in input_content and "description" in input_content | ||||||
|             ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." |             ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." | ||||||
|  | @ -251,7 +269,11 @@ class Tool: | ||||||
|         # Save app file |         # Save app file | ||||||
|         app_file = os.path.join(output_dir, "app.py") |         app_file = os.path.join(output_dir, "app.py") | ||||||
|         with open(app_file, "w", encoding="utf-8") as f: |         with open(app_file, "w", encoding="utf-8") as f: | ||||||
|             f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__)) |             f.write( | ||||||
|  |                 APP_FILE_TEMPLATE.format( | ||||||
|  |                     module_name=last_module, class_name=self.__class__.__name__ | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         # Save requirements file |         # Save requirements file | ||||||
|         requirements_file = os.path.join(output_dir, "requirements.txt") |         requirements_file = os.path.join(output_dir, "requirements.txt") | ||||||
|  | @ -343,7 +365,9 @@ class Tool: | ||||||
|             custom_tool = config |             custom_tool = config | ||||||
| 
 | 
 | ||||||
|         tool_class = custom_tool["tool_class"] |         tool_class = custom_tool["tool_class"] | ||||||
|         tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs) |         tool_class = get_class_from_dynamic_module( | ||||||
|  |             tool_class, repo_id, token=token, **hub_kwargs | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if len(tool_class.name) == 0: |         if len(tool_class.name) == 0: | ||||||
|             tool_class.name = custom_tool["name"] |             tool_class.name = custom_tool["name"] | ||||||
|  | @ -420,7 +444,9 @@ class Tool: | ||||||
|         with tempfile.TemporaryDirectory() as work_dir: |         with tempfile.TemporaryDirectory() as work_dir: | ||||||
|             # Save all files. |             # Save all files. | ||||||
|             self.save(work_dir) |             self.save(work_dir) | ||||||
|             logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") |             logger.info( | ||||||
|  |                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" | ||||||
|  |             ) | ||||||
|             return upload_folder( |             return upload_folder( | ||||||
|                 repo_id=repo_id, |                 repo_id=repo_id, | ||||||
|                 commit_message=commit_message, |                 commit_message=commit_message, | ||||||
|  | @ -432,7 +458,11 @@ class Tool: | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_space( |     def from_space( | ||||||
|         space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None |         space_id: str, | ||||||
|  |         name: str, | ||||||
|  |         description: str, | ||||||
|  |         api_name: Optional[str] = None, | ||||||
|  |         token: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Creates a [`Tool`] from a Space given its id on the Hub. |         Creates a [`Tool`] from a Space given its id on the Hub. | ||||||
|  | @ -485,7 +515,9 @@ class Tool: | ||||||
|                 self.client = Client(space_id, hf_token=token) |                 self.client = Client(space_id, hf_token=token) | ||||||
|                 self.name = name |                 self.name = name | ||||||
|                 self.description = description |                 self.description = description | ||||||
|                 space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] |                 space_description = self.client.view_api( | ||||||
|  |                     return_format="dict", print_info=False | ||||||
|  |                 )["named_endpoints"] | ||||||
| 
 | 
 | ||||||
|                 # If api_name is not defined, take the first of the available APIs for this space |                 # If api_name is not defined, take the first of the available APIs for this space | ||||||
|                 if api_name is None: |                 if api_name is None: | ||||||
|  | @ -498,7 +530,9 @@ class Tool: | ||||||
|                 try: |                 try: | ||||||
|                     space_description_api = space_description[api_name] |                     space_description_api = space_description[api_name] | ||||||
|                 except KeyError: |                 except KeyError: | ||||||
|                     raise KeyError(f"Could not find specified {api_name=} among available api names.") |                     raise KeyError( | ||||||
|  |                         f"Could not find specified {api_name=} among available api names." | ||||||
|  |                     ) | ||||||
| 
 | 
 | ||||||
|                 self.inputs = {} |                 self.inputs = {} | ||||||
|                 for parameter in space_description_api["parameters"]: |                 for parameter in space_description_api["parameters"]: | ||||||
|  | @ -523,9 +557,11 @@ class Tool: | ||||||
|                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) |                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) | ||||||
|                     arg.save(temp_file.name) |                     arg.save(temp_file.name) | ||||||
|                     arg = temp_file.name |                     arg = temp_file.name | ||||||
|                 if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like( |                 if ( | ||||||
|                     arg |                     isinstance(arg, (str, Path)) | ||||||
|                 ): |                     and Path(arg).exists() | ||||||
|  |                     and Path(arg).is_file() | ||||||
|  |                 ) or is_http_url_like(arg): | ||||||
|                     arg = handle_file(arg) |                     arg = handle_file(arg) | ||||||
|                 return arg |                 return arg | ||||||
| 
 | 
 | ||||||
|  | @ -544,7 +580,9 @@ class Tool: | ||||||
|                     ]  # Sometime the space also returns the generation seed, in which case the result is at index 0 |                     ]  # Sometime the space also returns the generation seed, in which case the result is at index 0 | ||||||
|                 return output |                 return output | ||||||
| 
 | 
 | ||||||
|         return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token) |         return SpaceToolWrapper( | ||||||
|  |             space_id, name, description, api_name=api_name, token=token | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_gradio(gradio_tool): |     def from_gradio(gradio_tool): | ||||||
|  | @ -561,7 +599,8 @@ class Tool: | ||||||
|                 self._gradio_tool = _gradio_tool |                 self._gradio_tool = _gradio_tool | ||||||
|                 func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) |                 func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) | ||||||
|                 self.inputs = { |                 self.inputs = { | ||||||
|                     key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args |                     key: {"type": CONVERSION_DICT[value.annotation], "description": ""} | ||||||
|  |                     for key, value in func_args | ||||||
|                 } |                 } | ||||||
|                 self.forward = self._gradio_tool.run |                 self.forward = self._gradio_tool.run | ||||||
| 
 | 
 | ||||||
|  | @ -603,7 +642,9 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str: | def get_tool_description_with_args( | ||||||
|  |     tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE | ||||||
|  | ) -> str: | ||||||
|     compiled_template = compile_jinja_template(description_template) |     compiled_template = compile_jinja_template(description_template) | ||||||
|     rendered = compiled_template.render( |     rendered = compiled_template.render( | ||||||
|         tool=tool, |         tool=tool, | ||||||
|  | @ -621,7 +662,10 @@ def compile_jinja_template(template): | ||||||
|         raise ImportError("template requires jinja2 to be installed.") |         raise ImportError("template requires jinja2 to be installed.") | ||||||
| 
 | 
 | ||||||
|     if version.parse(jinja2.__version__) < version.parse("3.1.0"): |     if version.parse(jinja2.__version__) < version.parse("3.1.0"): | ||||||
|         raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.") |         raise ImportError( | ||||||
|  |             "template requires jinja2>=3.1.0 to be installed. Your version is " | ||||||
|  |             f"{jinja2.__version__}." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def raise_exception(message): |     def raise_exception(message): | ||||||
|         raise TemplateError(message) |         raise TemplateError(message) | ||||||
|  | @ -697,7 +741,9 @@ class PipelineTool(Tool): | ||||||
| 
 | 
 | ||||||
|         if model is None: |         if model is None: | ||||||
|             if self.default_checkpoint is None: |             if self.default_checkpoint is None: | ||||||
|                 raise ValueError("This tool does not implement a default checkpoint, you need to pass one.") |                 raise ValueError( | ||||||
|  |                     "This tool does not implement a default checkpoint, you need to pass one." | ||||||
|  |                 ) | ||||||
|             model = self.default_checkpoint |             model = self.default_checkpoint | ||||||
|         if pre_processor is None: |         if pre_processor is None: | ||||||
|             pre_processor = model |             pre_processor = model | ||||||
|  | @ -720,15 +766,21 @@ class PipelineTool(Tool): | ||||||
|         Instantiates the `pre_processor`, `model` and `post_processor` if necessary. |         Instantiates the `pre_processor`, `model` and `post_processor` if necessary. | ||||||
|         """ |         """ | ||||||
|         if isinstance(self.pre_processor, str): |         if isinstance(self.pre_processor, str): | ||||||
|             self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs) |             self.pre_processor = self.pre_processor_class.from_pretrained( | ||||||
|  |                 self.pre_processor, **self.hub_kwargs | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         if isinstance(self.model, str): |         if isinstance(self.model, str): | ||||||
|             self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs) |             self.model = self.model_class.from_pretrained( | ||||||
|  |                 self.model, **self.model_kwargs, **self.hub_kwargs | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         if self.post_processor is None: |         if self.post_processor is None: | ||||||
|             self.post_processor = self.pre_processor |             self.post_processor = self.pre_processor | ||||||
|         elif isinstance(self.post_processor, str): |         elif isinstance(self.post_processor, str): | ||||||
|             self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs) |             self.post_processor = self.post_processor_class.from_pretrained( | ||||||
|  |                 self.post_processor, **self.hub_kwargs | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         if self.device is None: |         if self.device is None: | ||||||
|             if self.device_map is not None: |             if self.device_map is not None: | ||||||
|  | @ -768,8 +820,12 @@ class PipelineTool(Tool): | ||||||
| 
 | 
 | ||||||
|         encoded_inputs = self.encode(*args, **kwargs) |         encoded_inputs = self.encode(*args, **kwargs) | ||||||
| 
 | 
 | ||||||
|         tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)} |         tensor_inputs = { | ||||||
|         non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)} |             k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) | ||||||
|  |         } | ||||||
|  |         non_tensor_inputs = { | ||||||
|  |             k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor) | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         encoded_inputs = send_to_device(tensor_inputs, self.device) |         encoded_inputs = send_to_device(tensor_inputs, self.device) | ||||||
|         outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) |         outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) | ||||||
|  | @ -790,7 +846,9 @@ def launch_gradio_demo(tool_class: Tool): | ||||||
|     try: |     try: | ||||||
|         import gradio as gr |         import gradio as gr | ||||||
|     except ImportError: |     except ImportError: | ||||||
|         raise ImportError("Gradio should be installed in order to launch a gradio demo.") |         raise ImportError( | ||||||
|  |             "Gradio should be installed in order to launch a gradio demo." | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     tool = tool_class() |     tool = tool_class() | ||||||
| 
 | 
 | ||||||
|  | @ -807,11 +865,15 @@ def launch_gradio_demo(tool_class: Tool): | ||||||
| 
 | 
 | ||||||
|     gradio_inputs = [] |     gradio_inputs = [] | ||||||
|     for input_name, input_details in tool_class.inputs.items(): |     for input_name, input_details in tool_class.inputs.items(): | ||||||
|         input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] |         input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ | ||||||
|  |             input_details["type"] | ||||||
|  |         ] | ||||||
|         new_component = input_gradio_component_class(label=input_name) |         new_component = input_gradio_component_class(label=input_name) | ||||||
|         gradio_inputs.append(new_component) |         gradio_inputs.append(new_component) | ||||||
| 
 | 
 | ||||||
|     output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type] |     output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[ | ||||||
|  |         tool_class.output_type | ||||||
|  |     ] | ||||||
|     gradio_output = output_gradio_componentclass(label=input_name) |     gradio_output = output_gradio_componentclass(label=input_name) | ||||||
| 
 | 
 | ||||||
|     gr.Interface( |     gr.Interface( | ||||||
|  | @ -875,7 +937,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs): | ||||||
|             f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the " |             f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the " | ||||||
|             f"code that you have checked." |             f"code that you have checked." | ||||||
|         ) |         ) | ||||||
|         return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs) |         return Tool.from_hub( | ||||||
|  |             task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def add_description(description): | def add_description(description): | ||||||
|  | @ -935,7 +999,9 @@ class EndpointClient: | ||||||
|             payload["parameters"] = params |             payload["parameters"] = params | ||||||
| 
 | 
 | ||||||
|         # Make API call |         # Make API call | ||||||
|         response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data) |         response = get_session().post( | ||||||
|  |             self.endpoint_url, headers=self.headers, json=payload, data=data | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # By default, parse the response for the user. |         # By default, parse the response for the user. | ||||||
|         if output_image: |         if output_image: | ||||||
|  | @ -972,7 +1038,9 @@ class ToolCollection: | ||||||
| 
 | 
 | ||||||
|     def __init__(self, collection_slug: str, token: Optional[str] = None): |     def __init__(self, collection_slug: str, token: Optional[str] = None): | ||||||
|         self._collection = get_collection(collection_slug, token=token) |         self._collection = get_collection(collection_slug, token=token) | ||||||
|         self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"} |         self._hub_repo_ids = { | ||||||
|  |             item.item_id for item in self._collection.items if item.item_type == "space" | ||||||
|  |         } | ||||||
|         self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} |         self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -986,7 +1054,9 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|     """ |     """ | ||||||
|     parameters = get_json_schema(tool_function)["function"] |     parameters = get_json_schema(tool_function)["function"] | ||||||
|     if "return" not in parameters: |     if "return" not in parameters: | ||||||
|         raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") |         raise TypeHintParsingException( | ||||||
|  |             "Tool return type not found: make sure your function has a return type hint!" | ||||||
|  |         ) | ||||||
|     class_name = f"{parameters['name'].capitalize()}Tool" |     class_name = f"{parameters['name'].capitalize()}Tool" | ||||||
| 
 | 
 | ||||||
|     class SpecificTool(Tool): |     class SpecificTool(Tool): | ||||||
|  | @ -1000,9 +1070,9 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|             return tool_function(*args, **kwargs) |             return tool_function(*args, **kwargs) | ||||||
| 
 | 
 | ||||||
|     original_signature = inspect.signature(tool_function) |     original_signature = inspect.signature(tool_function) | ||||||
|     new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list( |     new_parameters = [ | ||||||
|         original_signature.parameters.values() |         inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) | ||||||
|     ) |     ] + list(original_signature.parameters.values()) | ||||||
|     new_signature = original_signature.replace(parameters=new_parameters) |     new_signature = original_signature.replace(parameters=new_parameters) | ||||||
|     SpecificTool.forward.__signature__ = new_signature |     SpecificTool.forward.__signature__ = new_signature | ||||||
| 
 | 
 | ||||||
|  | @ -1049,7 +1119,10 @@ class Toolbox: | ||||||
|                 The template to use to describe the tools. If not provided, the default template will be used. |                 The template to use to describe the tools. If not provided, the default template will be used. | ||||||
|         """ |         """ | ||||||
|         return "\n".join( |         return "\n".join( | ||||||
|             [get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()] |             [ | ||||||
|  |                 get_tool_description_with_args(tool, tool_description_template) | ||||||
|  |                 for tool in self._tools.values() | ||||||
|  |             ] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def add_tool(self, tool: Tool): |     def add_tool(self, tool: Tool): | ||||||
|  |  | ||||||
|  | @ -16,32 +16,29 @@ | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
| from typing import Tuple, Dict | from typing import Tuple, Dict, Union | ||||||
| 
 | 
 | ||||||
| from transformers.utils.import_utils import _is_package_available | from transformers.utils.import_utils import _is_package_available | ||||||
| 
 | 
 | ||||||
| _pygments_available = _is_package_available("pygments") | _pygments_available = _is_package_available("pygments") | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def is_pygments_available(): | def is_pygments_available(): | ||||||
|     return _pygments_available |     return _pygments_available | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| from rich.console import Console | from rich.console import Console | ||||||
|  | 
 | ||||||
| console = Console() | console = Console() | ||||||
| 
 | 
 | ||||||
| LENGTH_TRUNCATE_REPORTS = 10000 |  | ||||||
| 
 |  | ||||||
| def truncate_content(content: str, max_length: int = LENGTH_TRUNCATE_REPORTS): |  | ||||||
|     if len(content) < max_length: |  | ||||||
|         return content |  | ||||||
|     else: |  | ||||||
|         return content[:max_length//2] + "\n..._(Content was truncated because too long)_...\n---" + content[-max_length//2:] |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def parse_json_blob(json_blob: str) -> Dict[str, str]: | def parse_json_blob(json_blob: str) -> Dict[str, str]: | ||||||
|     try: |     try: | ||||||
|         first_accolade_index = json_blob.find("{") |         first_accolade_index = json_blob.find("{") | ||||||
|         last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] |         last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] | ||||||
|         json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'") |         json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace( | ||||||
|  |             '\\"', "'" | ||||||
|  |         ) | ||||||
|         json_data = json.loads(json_blob, strict=False) |         json_data = json.loads(json_blob, strict=False) | ||||||
|         return json_data |         return json_data | ||||||
|     except json.JSONDecodeError as e: |     except json.JSONDecodeError as e: | ||||||
|  | @ -63,7 +60,12 @@ def parse_code_blob(code_blob: str) -> str: | ||||||
|     try: |     try: | ||||||
|         pattern = r"```(?:py|python)?\n(.*?)\n```" |         pattern = r"```(?:py|python)?\n(.*?)\n```" | ||||||
|         match = re.search(pattern, code_blob, re.DOTALL) |         match = re.search(pattern, code_blob, re.DOTALL) | ||||||
|  |         if match is None: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"No match ground for regex pattern {pattern} in {code_blob=}." | ||||||
|  |             ) | ||||||
|         return match.group(1).strip() |         return match.group(1).strip() | ||||||
|  | 
 | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         raise ValueError( |         raise ValueError( | ||||||
|             f""" |             f""" | ||||||
|  | @ -77,7 +79,7 @@ Code: | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: | def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]: | ||||||
|     json_blob = json_blob.replace("```json", "").replace("```", "") |     json_blob = json_blob.replace("```json", "").replace("```", "") | ||||||
|     tool_call = parse_json_blob(json_blob) |     tool_call = parse_json_blob(json_blob) | ||||||
|     if "action" in tool_call and "action_input" in tool_call: |     if "action" in tool_call and "action_input" in tool_call: | ||||||
|  | @ -85,7 +87,25 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: | ||||||
|     elif "action" in tool_call: |     elif "action" in tool_call: | ||||||
|         return tool_call["action"], None |         return tool_call["action"], None | ||||||
|     else: |     else: | ||||||
|         missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call] |         missing_keys = [ | ||||||
|  |             key for key in ["action", "action_input"] if key not in tool_call | ||||||
|  |         ] | ||||||
|         error_msg = f"Missing keys: {missing_keys} in blob {tool_call}" |         error_msg = f"Missing keys: {missing_keys} in blob {tool_call}" | ||||||
|         console.print(f"[bold red]{error_msg}[/bold red]") |         console.print(f"[bold red]{error_msg}[/bold red]") | ||||||
|         raise ValueError(error_msg) |         raise ValueError(error_msg) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | MAX_LENGTH_TRUNCATE_CONTENT = 20000 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def truncate_content( | ||||||
|  |     content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT | ||||||
|  | ) -> str: | ||||||
|  |     if len(content) <= max_length: | ||||||
|  |         return content | ||||||
|  |     else: | ||||||
|  |         return ( | ||||||
|  |             content[: MAX_LENGTH_TRUNCATE_CONTENT // 2] | ||||||
|  |             + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" | ||||||
|  |             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] | ||||||
|  |         ) | ||||||
|  |  | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,271 @@ | ||||||
|  | # This file was autogenerated by uv via the following command: | ||||||
|  | #    uv pip compile requirements.in -o requirements.txt | ||||||
|  | aiofiles==23.2.1 | ||||||
|  |     # via gradio | ||||||
|  | annotated-types==0.7.0 | ||||||
|  |     # via pydantic | ||||||
|  | anyio==4.7.0 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   httpx | ||||||
|  |     #   starlette | ||||||
|  | appnope==0.1.4 | ||||||
|  |     # via ipykernel | ||||||
|  | asttokens==3.0.0 | ||||||
|  |     # via stack-data | ||||||
|  | beautifulsoup4==4.12.3 | ||||||
|  |     # via markdownify | ||||||
|  | certifi==2024.8.30 | ||||||
|  |     # via | ||||||
|  |     #   httpcore | ||||||
|  |     #   httpx | ||||||
|  |     #   requests | ||||||
|  | charset-normalizer==3.4.0 | ||||||
|  |     # via requests | ||||||
|  | click==8.1.7 | ||||||
|  |     # via | ||||||
|  |     #   duckduckgo-search | ||||||
|  |     #   typer | ||||||
|  |     #   uvicorn | ||||||
|  | comm==0.2.2 | ||||||
|  |     # via ipykernel | ||||||
|  | debugpy==1.8.10 | ||||||
|  |     # via ipykernel | ||||||
|  | decorator==5.1.1 | ||||||
|  |     # via ipython | ||||||
|  | diskcache==5.6.3 | ||||||
|  |     # via llama-cpp-python | ||||||
|  | duckduckgo-search==6.4.1 | ||||||
|  |     # via -r requirements.in | ||||||
|  | executing==2.1.0 | ||||||
|  |     # via stack-data | ||||||
|  | fastapi==0.115.6 | ||||||
|  |     # via gradio | ||||||
|  | ffmpy==0.4.0 | ||||||
|  |     # via gradio | ||||||
|  | filelock==3.16.1 | ||||||
|  |     # via | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   transformers | ||||||
|  | fsspec==2024.10.0 | ||||||
|  |     # via | ||||||
|  |     #   gradio-client | ||||||
|  |     #   huggingface-hub | ||||||
|  | gradio==5.8.0 | ||||||
|  |     # via -r requirements.in | ||||||
|  | gradio-client==1.5.1 | ||||||
|  |     # via gradio | ||||||
|  | h11==0.14.0 | ||||||
|  |     # via | ||||||
|  |     #   httpcore | ||||||
|  |     #   uvicorn | ||||||
|  | httpcore==1.0.7 | ||||||
|  |     # via httpx | ||||||
|  | httpx==0.28.1 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   gradio-client | ||||||
|  |     #   safehttpx | ||||||
|  | huggingface-hub==0.26.5 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   gradio-client | ||||||
|  |     #   tokenizers | ||||||
|  |     #   transformers | ||||||
|  | idna==3.10 | ||||||
|  |     # via | ||||||
|  |     #   anyio | ||||||
|  |     #   httpx | ||||||
|  |     #   requests | ||||||
|  | iniconfig==2.0.0 | ||||||
|  |     # via pytest | ||||||
|  | ipykernel==6.29.5 | ||||||
|  |     # via -r requirements.in | ||||||
|  | ipython==8.30.0 | ||||||
|  |     # via ipykernel | ||||||
|  | jedi==0.19.2 | ||||||
|  |     # via ipython | ||||||
|  | jinja2==3.1.4 | ||||||
|  |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   gradio | ||||||
|  |     #   llama-cpp-python | ||||||
|  | jupyter-client==8.6.3 | ||||||
|  |     # via ipykernel | ||||||
|  | jupyter-core==5.7.2 | ||||||
|  |     # via | ||||||
|  |     #   ipykernel | ||||||
|  |     #   jupyter-client | ||||||
|  | llama-cpp-python==0.3.5 | ||||||
|  |     # via -r requirements.in | ||||||
|  | markdown-it-py==3.0.0 | ||||||
|  |     # via rich | ||||||
|  | markdownify==0.14.1 | ||||||
|  |     # via -r requirements.in | ||||||
|  | markupsafe==2.1.5 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   jinja2 | ||||||
|  | matplotlib-inline==0.1.7 | ||||||
|  |     # via | ||||||
|  |     #   ipykernel | ||||||
|  |     #   ipython | ||||||
|  | mdurl==0.1.2 | ||||||
|  |     # via markdown-it-py | ||||||
|  | nest-asyncio==1.6.0 | ||||||
|  |     # via ipykernel | ||||||
|  | numpy==2.2.0 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   llama-cpp-python | ||||||
|  |     #   pandas | ||||||
|  |     #   transformers | ||||||
|  | orjson==3.10.12 | ||||||
|  |     # via gradio | ||||||
|  | packaging==24.2 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   gradio-client | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   ipykernel | ||||||
|  |     #   pytest | ||||||
|  |     #   transformers | ||||||
|  | pandas==2.2.3 | ||||||
|  |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   gradio | ||||||
|  | parso==0.8.4 | ||||||
|  |     # via jedi | ||||||
|  | pexpect==4.9.0 | ||||||
|  |     # via ipython | ||||||
|  | pillow==11.0.0 | ||||||
|  |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   gradio | ||||||
|  | platformdirs==4.3.6 | ||||||
|  |     # via jupyter-core | ||||||
|  | pluggy==1.5.0 | ||||||
|  |     # via pytest | ||||||
|  | primp==0.8.3 | ||||||
|  |     # via duckduckgo-search | ||||||
|  | prompt-toolkit==3.0.48 | ||||||
|  |     # via ipython | ||||||
|  | psutil==6.1.0 | ||||||
|  |     # via ipykernel | ||||||
|  | ptyprocess==0.7.0 | ||||||
|  |     # via pexpect | ||||||
|  | pure-eval==0.2.3 | ||||||
|  |     # via stack-data | ||||||
|  | pydantic==2.10.3 | ||||||
|  |     # via | ||||||
|  |     #   fastapi | ||||||
|  |     #   gradio | ||||||
|  | pydantic-core==2.27.1 | ||||||
|  |     # via pydantic | ||||||
|  | pydub==0.25.1 | ||||||
|  |     # via gradio | ||||||
|  | pygments==2.18.0 | ||||||
|  |     # via | ||||||
|  |     #   ipython | ||||||
|  |     #   rich | ||||||
|  | pytest==8.3.4 | ||||||
|  |     # via -r requirements.in | ||||||
|  | python-dateutil==2.9.0.post0 | ||||||
|  |     # via | ||||||
|  |     #   jupyter-client | ||||||
|  |     #   pandas | ||||||
|  | python-dotenv==1.0.1 | ||||||
|  |     # via -r requirements.in | ||||||
|  | python-multipart==0.0.19 | ||||||
|  |     # via gradio | ||||||
|  | pytz==2024.2 | ||||||
|  |     # via pandas | ||||||
|  | pyyaml==6.0.2 | ||||||
|  |     # via | ||||||
|  |     #   gradio | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   transformers | ||||||
|  | pyzmq==26.2.0 | ||||||
|  |     # via | ||||||
|  |     #   ipykernel | ||||||
|  |     #   jupyter-client | ||||||
|  | regex==2024.11.6 | ||||||
|  |     # via transformers | ||||||
|  | requests==2.32.3 | ||||||
|  |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   transformers | ||||||
|  | rich==13.9.4 | ||||||
|  |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   typer | ||||||
|  | ruff==0.8.2 | ||||||
|  |     # via gradio | ||||||
|  | safehttpx==0.1.6 | ||||||
|  |     # via gradio | ||||||
|  | safetensors==0.4.5 | ||||||
|  |     # via transformers | ||||||
|  | semantic-version==2.10.0 | ||||||
|  |     # via gradio | ||||||
|  | shellingham==1.5.4 | ||||||
|  |     # via typer | ||||||
|  | six==1.17.0 | ||||||
|  |     # via | ||||||
|  |     #   markdownify | ||||||
|  |     #   python-dateutil | ||||||
|  | sniffio==1.3.1 | ||||||
|  |     # via anyio | ||||||
|  | soupsieve==2.6 | ||||||
|  |     # via beautifulsoup4 | ||||||
|  | stack-data==0.6.3 | ||||||
|  |     # via ipython | ||||||
|  | starlette==0.41.3 | ||||||
|  |     # via | ||||||
|  |     #   fastapi | ||||||
|  |     #   gradio | ||||||
|  | tokenizers==0.21.0 | ||||||
|  |     # via transformers | ||||||
|  | tomlkit==0.13.2 | ||||||
|  |     # via gradio | ||||||
|  | tornado==6.4.2 | ||||||
|  |     # via | ||||||
|  |     #   ipykernel | ||||||
|  |     #   jupyter-client | ||||||
|  | tqdm==4.67.1 | ||||||
|  |     # via | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   transformers | ||||||
|  | traitlets==5.14.3 | ||||||
|  |     # via | ||||||
|  |     #   comm | ||||||
|  |     #   ipykernel | ||||||
|  |     #   ipython | ||||||
|  |     #   jupyter-client | ||||||
|  |     #   jupyter-core | ||||||
|  |     #   matplotlib-inline | ||||||
|  | transformers==4.47.0 | ||||||
|  |     # via -r requirements.in | ||||||
|  | typer==0.15.1 | ||||||
|  |     # via gradio | ||||||
|  | typing-extensions==4.12.2 | ||||||
|  |     # via | ||||||
|  |     #   anyio | ||||||
|  |     #   fastapi | ||||||
|  |     #   gradio | ||||||
|  |     #   gradio-client | ||||||
|  |     #   huggingface-hub | ||||||
|  |     #   llama-cpp-python | ||||||
|  |     #   pydantic | ||||||
|  |     #   pydantic-core | ||||||
|  |     #   typer | ||||||
|  | tzdata==2024.2 | ||||||
|  |     # via pandas | ||||||
|  | urllib3==2.2.3 | ||||||
|  |     # via requests | ||||||
|  | uvicorn==0.32.1 | ||||||
|  |     # via gradio | ||||||
|  | wcwidth==0.2.13 | ||||||
|  |     # via prompt-toolkit | ||||||
|  | websockets==14.1 | ||||||
|  |     # via gradio-client | ||||||
							
								
								
									
										122
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										122
									
								
								setup.py
								
								
								
								
							|  | @ -1,122 +0,0 @@ | ||||||
| # Copyright 2021 The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| 
 |  | ||||||
| from setuptools import find_packages, setup |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| extras = {} |  | ||||||
| extras["quality"] = [ |  | ||||||
|     "black ~= 23.1",  # hf-doc-builder has a hidden dependency on `black` |  | ||||||
|     "hf-doc-builder >= 0.3.0", |  | ||||||
|     "ruff ~= 0.6.4", |  | ||||||
| ] |  | ||||||
| extras["docs"] = [] |  | ||||||
| extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"] |  | ||||||
| extras["test_dev"] = [ |  | ||||||
|     "datasets", |  | ||||||
|     "diffusers", |  | ||||||
|     "evaluate", |  | ||||||
|     "torchdata>=0.8.0", |  | ||||||
|     "torchpippy>=0.2.0", |  | ||||||
|     "transformers", |  | ||||||
|     "scipy", |  | ||||||
|     "scikit-learn", |  | ||||||
|     "tqdm", |  | ||||||
|     "bitsandbytes", |  | ||||||
|     "timm", |  | ||||||
| ] |  | ||||||
| extras["testing"] = extras["test_prod"] + extras["test_dev"] |  | ||||||
| extras["deepspeed"] = ["deepspeed"] |  | ||||||
| extras["rich"] = ["rich"] |  | ||||||
| 
 |  | ||||||
| extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive"] |  | ||||||
| extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"] |  | ||||||
| 
 |  | ||||||
| extras["sagemaker"] = [ |  | ||||||
|     "sagemaker",  # boto3 is a required package in sagemaker |  | ||||||
| ] |  | ||||||
| 
 |  | ||||||
| setup( |  | ||||||
|     name="accelerate", |  | ||||||
|     version="1.2.0.dev0", |  | ||||||
|     description="Accelerate", |  | ||||||
|     long_description=open("README.md", encoding="utf-8").read(), |  | ||||||
|     long_description_content_type="text/markdown", |  | ||||||
|     keywords="deep learning", |  | ||||||
|     license="Apache", |  | ||||||
|     author="The HuggingFace team", |  | ||||||
|     author_email="zach.mueller@huggingface.co", |  | ||||||
|     url="https://github.com/huggingface/accelerate", |  | ||||||
|     package_dir={"": "src"}, |  | ||||||
|     packages=find_packages("src"), |  | ||||||
|     entry_points={ |  | ||||||
|         "console_scripts": [ |  | ||||||
|             "accelerate=accelerate.commands.accelerate_cli:main", |  | ||||||
|             "accelerate-config=accelerate.commands.config:main", |  | ||||||
|             "accelerate-estimate-memory=accelerate.commands.estimate:main", |  | ||||||
|             "accelerate-launch=accelerate.commands.launch:main", |  | ||||||
|             "accelerate-merge-weights=accelerate.commands.merge:main", |  | ||||||
|         ] |  | ||||||
|     }, |  | ||||||
|     python_requires=">=3.9.0", |  | ||||||
|     install_requires=[ |  | ||||||
|         "numpy>=1.17,<3.0.0", |  | ||||||
|         "packaging>=20.0", |  | ||||||
|         "psutil", |  | ||||||
|         "pyyaml", |  | ||||||
|         "torch>=1.10.0", |  | ||||||
|         "huggingface_hub>=0.21.0", |  | ||||||
|         "safetensors>=0.4.3", |  | ||||||
|     ], |  | ||||||
|     extras_require=extras, |  | ||||||
|     classifiers=[ |  | ||||||
|         "Development Status :: 5 - Production/Stable", |  | ||||||
|         "Intended Audience :: Developers", |  | ||||||
|         "Intended Audience :: Education", |  | ||||||
|         "Intended Audience :: Science/Research", |  | ||||||
|         "License :: OSI Approved :: Apache Software License", |  | ||||||
|         "Operating System :: OS Independent", |  | ||||||
|         "Programming Language :: Python :: 3", |  | ||||||
|         "Programming Language :: Python :: 3.8", |  | ||||||
|         "Topic :: Scientific/Engineering :: Artificial Intelligence", |  | ||||||
|     ], |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| # Release checklist |  | ||||||
| # 1. Checkout the release branch (for a patch the current release branch, for a new minor version, create one): |  | ||||||
| #      git checkout -b vXX.xx-release |  | ||||||
| #    The -b is only necessary for creation (so remove it when doing a patch) |  | ||||||
| # 2. Change the version in __init__.py and setup.py to the proper value. |  | ||||||
| # 3. Commit these changes with the message: "Release: v<VERSION>" |  | ||||||
| # 4. Add a tag in git to mark the release: |  | ||||||
| #      git tag v<VERSION> -m 'Adds tag v<VERSION> for pypi' |  | ||||||
| #    Push the tag and release commit to git: git push --tags origin vXX.xx-release |  | ||||||
| # 5. Run the following commands in the top-level directory: |  | ||||||
| #      rm -rf dist |  | ||||||
| #      rm -rf build |  | ||||||
| #      python setup.py bdist_wheel |  | ||||||
| #      python setup.py sdist |  | ||||||
| # 6. Upload the package to the pypi test server first: |  | ||||||
| #      twine upload dist/* -r testpypi |  | ||||||
| # 7. Check that you can install it in a virtualenv by running: |  | ||||||
| #      pip install accelerate |  | ||||||
| #      pip uninstall accelerate |  | ||||||
| #      pip install -i https://testpypi.python.org/pypi accelerate |  | ||||||
| #      accelerate env |  | ||||||
| #      accelerate test |  | ||||||
| # 8. Upload the final version to actual pypi: |  | ||||||
| #      twine upload dist/* -r pypi |  | ||||||
| # 9. Add release notes to the tag in github once everything is looking hunky-dory. |  | ||||||
| # 10. Go back to the main branch and update the version in __init__.py, setup.py to the new version ".dev" and push to |  | ||||||
| #     main. |  | ||||||
|  | @ -19,19 +19,25 @@ import uuid | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText | from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText | ||||||
| from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision | from transformers.testing_utils import ( | ||||||
| from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available |     get_tests_dir, | ||||||
|  |     require_soundfile, | ||||||
|  |     require_torch, | ||||||
|  |     require_vision, | ||||||
|  | ) | ||||||
|  | from transformers.utils import ( | ||||||
|  |     is_soundfile_availble, | ||||||
|  |     is_torch_available, | ||||||
|  |     is_vision_available, | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
|  | import torch | ||||||
|  | from PIL import Image | ||||||
| 
 | 
 | ||||||
| if is_torch_available(): |  | ||||||
|     import torch |  | ||||||
| 
 | 
 | ||||||
| if is_soundfile_availble(): | if is_soundfile_availble(): | ||||||
|     import soundfile as sf |     import soundfile as sf | ||||||
| 
 | 
 | ||||||
| if is_vision_available(): |  | ||||||
|     from PIL import Image |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|     directory = tempfile.mkdtemp() |     directory = tempfile.mkdtemp() | ||||||
|  |  | ||||||
|  | @ -19,17 +19,16 @@ import uuid | ||||||
| 
 | 
 | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from transformers.agents.agent_types import AgentText | from agents.agent_types import AgentText | ||||||
| from transformers.agents.agents import ( | from agents.agents import ( | ||||||
|     AgentMaxIterationsError, |     AgentMaxIterationsError, | ||||||
|     CodeAgent, |     CodeAgent, | ||||||
|     ManagedAgent, |     ManagedAgent, | ||||||
|     ReactCodeAgent, |     CodeAgent, | ||||||
|     ReactJsonAgent, |     JsonAgent, | ||||||
|     Toolbox, |     Toolbox, | ||||||
| ) | ) | ||||||
| from transformers.agents.default_tools import PythonInterpreterTool | from agents.default_tools import PythonInterpreterTool | ||||||
| from transformers.testing_utils import require_torch |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_new_path(suffix="") -> str: | def get_new_path(suffix="") -> str: | ||||||
|  | @ -149,19 +148,26 @@ print(result) | ||||||
| 
 | 
 | ||||||
| class AgentTests(unittest.TestCase): | class AgentTests(unittest.TestCase): | ||||||
|     def test_fake_code_agent(self): |     def test_fake_code_agent(self): | ||||||
|         agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot) |         agent = CodeAgent( | ||||||
|  |             tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot | ||||||
|  |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, str) |         assert isinstance(output, str) | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
| 
 | 
 | ||||||
|     def test_fake_react_json_agent(self): |     def test_fake_react_json_agent(self): | ||||||
|         agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm) |         agent = JsonAgent( | ||||||
|  |             tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm | ||||||
|  |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, str) |         assert isinstance(output, str) | ||||||
|         assert output == "7.2904" |         assert output == "7.2904" | ||||||
|         assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" |         assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" | ||||||
|         assert agent.logs[1]["observation"] == "7.2904" |         assert agent.logs[1]["observation"] == "7.2904" | ||||||
|         assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker" |         assert ( | ||||||
|  |             agent.logs[1]["rationale"].strip() | ||||||
|  |             == "Thought: I should multiply 2 by 3.6452. special_marker" | ||||||
|  |         ) | ||||||
|         assert ( |         assert ( | ||||||
|             agent.logs[2]["llm_output"] |             agent.logs[2]["llm_output"] | ||||||
|             == """ |             == """ | ||||||
|  | @ -175,7 +181,9 @@ Action: | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def test_fake_react_code_agent(self): |     def test_fake_react_code_agent(self): | ||||||
|         agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm) |         agent = CodeAgent( | ||||||
|  |             tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm | ||||||
|  |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, float) |         assert isinstance(output, float) | ||||||
|         assert output == 7.2904 |         assert output == 7.2904 | ||||||
|  | @ -186,17 +194,19 @@ Action: | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     def test_react_code_agent_code_errors_show_offending_lines(self): |     def test_react_code_agent_code_errors_show_offending_lines(self): | ||||||
|         agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error) |         agent = CodeAgent( | ||||||
|  |             tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error | ||||||
|  |         ) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?") |         output = agent.run("What is 2 multiplied by 3.6452?") | ||||||
|         assert isinstance(output, AgentText) |         assert isinstance(output, AgentText) | ||||||
|         assert output == "got an error" |         assert output == "got an error" | ||||||
|         assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) |         assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) | ||||||
| 
 | 
 | ||||||
|     def test_setup_agent_with_empty_toolbox(self): |     def test_setup_agent_with_empty_toolbox(self): | ||||||
|         ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[]) |         JsonAgent(llm_engine=fake_react_json_llm, tools=[]) | ||||||
| 
 | 
 | ||||||
|     def test_react_fails_max_iterations(self): |     def test_react_fails_max_iterations(self): | ||||||
|         agent = ReactCodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[PythonInterpreterTool()], |             tools=[PythonInterpreterTool()], | ||||||
|             llm_engine=fake_code_llm_no_return,  # use this callable because it never ends |             llm_engine=fake_code_llm_no_return,  # use this callable because it never ends | ||||||
|             max_iterations=5, |             max_iterations=5, | ||||||
|  | @ -208,51 +218,62 @@ Action: | ||||||
|     @require_torch |     @require_torch | ||||||
|     def test_init_agent_with_different_toolsets(self): |     def test_init_agent_with_different_toolsets(self): | ||||||
|         toolset_1 = [] |         toolset_1 = [] | ||||||
|         agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 1 |             len(agent.toolbox.tools) == 1 | ||||||
|         )  # when no tools are provided, only the final_answer tool is added by default |         )  # when no tools are provided, only the final_answer tool is added by default | ||||||
| 
 | 
 | ||||||
|         toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] |         toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] | ||||||
|         agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 2 |             len(agent.toolbox.tools) == 2 | ||||||
|         )  # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer |         )  # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer | ||||||
| 
 | 
 | ||||||
|         toolset_3 = Toolbox(toolset_2) |         toolset_3 = Toolbox(toolset_2) | ||||||
|         agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) |         agent = CodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) | ||||||
|         assert ( |         assert ( | ||||||
|             len(agent.toolbox.tools) == 2 |             len(agent.toolbox.tools) == 2 | ||||||
|         )  # same as previous one, where toolset_3 is an instantiation of previous one |         )  # same as previous one, where toolset_3 is an instantiation of previous one | ||||||
| 
 | 
 | ||||||
|         # check that add_base_tools will not interfere with existing tools |         # check that add_base_tools will not interfere with existing tools | ||||||
|         with pytest.raises(KeyError) as e: |         with pytest.raises(KeyError) as e: | ||||||
|             agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True) |             agent = JsonAgent( | ||||||
|  |                 tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True | ||||||
|  |             ) | ||||||
|         assert "already exists in the toolbox" in str(e) |         assert "already exists in the toolbox" in str(e) | ||||||
| 
 | 
 | ||||||
|         # check that python_interpreter base tool does not get added to code agents |         # check that python_interpreter base tool does not get added to code agents | ||||||
|         agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) |         agent = CodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) | ||||||
|         assert len(agent.toolbox.tools) == 7  # added final_answer tool + 6 base tools (excluding interpreter) |         assert ( | ||||||
|  |             len(agent.toolbox.tools) == 7 | ||||||
|  |         )  # added final_answer tool + 6 base tools (excluding interpreter) | ||||||
| 
 | 
 | ||||||
|     def test_function_persistence_across_steps(self): |     def test_function_persistence_across_steps(self): | ||||||
|         agent = ReactCodeAgent( |         agent = CodeAgent( | ||||||
|             tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"] |             tools=[], | ||||||
|  |             llm_engine=fake_react_code_functiondef, | ||||||
|  |             max_iterations=2, | ||||||
|  |             additional_authorized_imports=["numpy"], | ||||||
|         ) |         ) | ||||||
|         res = agent.run("ok") |         res = agent.run("ok") | ||||||
|         assert res[0] == 0.5 |         assert res[0] == 0.5 | ||||||
| 
 | 
 | ||||||
|     def test_init_managed_agent(self): |     def test_init_managed_agent(self): | ||||||
|         agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) |         agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) | ||||||
|         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") |         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") | ||||||
|         assert managed_agent.name == "managed_agent" |         assert managed_agent.name == "managed_agent" | ||||||
|         assert managed_agent.description == "Empty" |         assert managed_agent.description == "Empty" | ||||||
| 
 | 
 | ||||||
|     def test_agent_description_gets_correctly_inserted_in_system_prompt(self): |     def test_agent_description_gets_correctly_inserted_in_system_prompt(self): | ||||||
|         agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef) |         agent = CodeAgent(tools=[], llm_engine=fake_react_code_functiondef) | ||||||
|         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") |         managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty") | ||||||
|         manager_agent = ReactCodeAgent( |         manager_agent = CodeAgent( | ||||||
|             tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent] |             tools=[], | ||||||
|  |             llm_engine=fake_react_code_functiondef, | ||||||
|  |             managed_agents=[managed_agent], | ||||||
|         ) |         ) | ||||||
|         assert "You can also give requests to team members." not in agent.system_prompt |         assert "You can also give requests to team members." not in agent.system_prompt | ||||||
|         assert "<<managed_agents_descriptions>>" not in agent.system_prompt |         assert "<<managed_agents_descriptions>>" not in agent.system_prompt | ||||||
|         assert "You can also give requests to team members." in manager_agent.system_prompt |         assert ( | ||||||
|  |             "You can also give requests to team members." in manager_agent.system_prompt | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  | @ -1,41 +0,0 @@ | ||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 HuggingFace Inc. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| 
 |  | ||||||
| import unittest |  | ||||||
| 
 |  | ||||||
| from datasets import load_dataset |  | ||||||
| 
 |  | ||||||
| from transformers import load_tool |  | ||||||
| 
 |  | ||||||
| from .test_tools_common import ToolTesterMixin |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): |  | ||||||
|     def setUp(self): |  | ||||||
|         self.tool = load_tool("document_question_answering") |  | ||||||
|         self.tool.setup() |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_arg(self): |  | ||||||
|         dataset = load_dataset("hf-internal-testing/example-documents", split="test") |  | ||||||
|         document = dataset[0]["image"] |  | ||||||
| 
 |  | ||||||
|         result = self.tool(document, "When is the coffee break?") |  | ||||||
|         self.assertEqual(result, "11-14 to 11:39 a.m.") |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_kwarg(self): |  | ||||||
|         dataset = load_dataset("hf-internal-testing/example-documents", split="test") |  | ||||||
|         document = dataset[0]["image"] |  | ||||||
| 
 |  | ||||||
|         self.tool(document=document, question="When is the coffee break?") |  | ||||||
|  | @ -87,7 +87,11 @@ class ExampleDifferenceTests(unittest.TestCase): | ||||||
|     examples_path = Path("examples").resolve() |     examples_path = Path("examples").resolve() | ||||||
| 
 | 
 | ||||||
|     def one_complete_example( |     def one_complete_example( | ||||||
|         self, complete_file_name: str, parser_only: bool, secondary_filename: str = None, special_strings: list = None |         self, | ||||||
|  |         complete_file_name: str, | ||||||
|  |         parser_only: bool, | ||||||
|  |         secondary_filename: str = None, | ||||||
|  |         special_strings: list = None, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Tests a single `complete` example against all of the implemented `by_feature` scripts |         Tests a single `complete` example against all of the implemented `by_feature` scripts | ||||||
|  | @ -112,10 +116,15 @@ class ExampleDifferenceTests(unittest.TestCase): | ||||||
|                     with self.subTest( |                     with self.subTest( | ||||||
|                         tested_script=complete_file_name, |                         tested_script=complete_file_name, | ||||||
|                         feature_script=item, |                         feature_script=item, | ||||||
|                         tested_section="main()" if parser_only else "training_function()", |                         tested_section="main()" | ||||||
|  |                         if parser_only | ||||||
|  |                         else "training_function()", | ||||||
|                     ): |                     ): | ||||||
|                         diff = compare_against_test( |                         diff = compare_against_test( | ||||||
|                             self.examples_path / complete_file_name, item_path, parser_only, secondary_filename |                             self.examples_path / complete_file_name, | ||||||
|  |                             item_path, | ||||||
|  |                             parser_only, | ||||||
|  |                             secondary_filename, | ||||||
|                         ) |                         ) | ||||||
|                         diff = "\n".join(diff) |                         diff = "\n".join(diff) | ||||||
|                         if special_strings is not None: |                         if special_strings is not None: | ||||||
|  | @ -140,8 +149,12 @@ class ExampleDifferenceTests(unittest.TestCase): | ||||||
|             " " * 12, |             " " * 12, | ||||||
|             " " * 8 + "for step, batch in enumerate(active_dataloader):\n", |             " " * 8 + "for step, batch in enumerate(active_dataloader):\n", | ||||||
|         ] |         ] | ||||||
|         self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings) |         self.one_complete_example( | ||||||
|         self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings) |             "complete_cv_example.py", True, cv_path, special_strings | ||||||
|  |         ) | ||||||
|  |         self.one_complete_example( | ||||||
|  |             "complete_cv_example.py", False, cv_path, special_strings | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"}) | @mock.patch.dict(os.environ, {"TESTING_MOCKED_DATALOADERS": "1"}) | ||||||
|  |  | ||||||
|  | @ -47,9 +47,9 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|     def create_inputs(self): |     def create_inputs(self): | ||||||
|         inputs_text = {"answer": "Text input"} |         inputs_text = {"answer": "Text input"} | ||||||
|         inputs_image = { |         inputs_image = { | ||||||
|             "answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize( |             "answer": Image.open( | ||||||
|                 (512, 512) |                 Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" | ||||||
|             ) |             ).resize((512, 512)) | ||||||
|         } |         } | ||||||
|         inputs_audio = {"answer": torch.Tensor(np.ones(3000))} |         inputs_audio = {"answer": torch.Tensor(np.ones(3000))} | ||||||
|         return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} |         return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} | ||||||
|  |  | ||||||
|  | @ -1,42 +0,0 @@ | ||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 HuggingFace Inc. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| 
 |  | ||||||
| import unittest |  | ||||||
| from pathlib import Path |  | ||||||
| 
 |  | ||||||
| from transformers import is_vision_available, load_tool |  | ||||||
| from transformers.testing_utils import get_tests_dir |  | ||||||
| 
 |  | ||||||
| from .test_tools_common import ToolTesterMixin |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if is_vision_available(): |  | ||||||
|     from PIL import Image |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): |  | ||||||
|     def setUp(self): |  | ||||||
|         self.tool = load_tool("image_question_answering") |  | ||||||
|         self.tool.setup() |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_arg(self): |  | ||||||
|         image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") |  | ||||||
|         result = self.tool(image, "How many cats are sleeping on the couch?") |  | ||||||
|         self.assertEqual(result, "2") |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_kwarg(self): |  | ||||||
|         image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") |  | ||||||
|         result = self.tool(image=image, question="How many cats are sleeping on the couch?") |  | ||||||
|         self.assertEqual(result, "2") |  | ||||||
|  | @ -129,7 +129,9 @@ final_answer('This is the final answer.') | ||||||
| 
 | 
 | ||||||
|     def test_streaming_agent_image_output(self): |     def test_streaming_agent_image_output(self): | ||||||
|         def dummy_llm_engine(prompt, **kwargs): |         def dummy_llm_engine(prompt, **kwargs): | ||||||
|             return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' |             return ( | ||||||
|  |                 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         agent = ReactJsonAgent( |         agent = ReactJsonAgent( | ||||||
|             tools=[], |             tools=[], | ||||||
|  | @ -138,7 +140,14 @@ final_answer('This is the final answer.') | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Use stream_to_gradio to capture the output |         # Use stream_to_gradio to capture the output | ||||||
|         outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True)) |         outputs = list( | ||||||
|  |             stream_to_gradio( | ||||||
|  |                 agent, | ||||||
|  |                 task="Test task", | ||||||
|  |                 image=AgentImage(value="path.png"), | ||||||
|  |                 test_mode=True, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         self.assertEqual(len(outputs), 2) |         self.assertEqual(len(outputs), 2) | ||||||
|         final_message = outputs[-1] |         final_message = outputs[-1] | ||||||
|  |  | ||||||
|  | @ -21,7 +21,10 @@ import pytest | ||||||
| from transformers import load_tool | from transformers import load_tool | ||||||
| from transformers.agents.agent_types import AGENT_TYPE_MAPPING | from transformers.agents.agent_types import AGENT_TYPE_MAPPING | ||||||
| from transformers.agents.default_tools import BASE_PYTHON_TOOLS | from transformers.agents.default_tools import BASE_PYTHON_TOOLS | ||||||
| from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code | from transformers.agents.python_interpreter import ( | ||||||
|  |     InterpreterError, | ||||||
|  |     evaluate_python_code, | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| from .test_tools_common import ToolTesterMixin | from .test_tools_common import ToolTesterMixin | ||||||
| 
 | 
 | ||||||
|  | @ -57,7 +60,12 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|         for _input, expected_input in zip(inputs, self.tool.inputs.values()): |         for _input, expected_input in zip(inputs, self.tool.inputs.values()): | ||||||
|             input_type = expected_input["type"] |             input_type = expected_input["type"] | ||||||
|             if isinstance(input_type, list): |             if isinstance(input_type, list): | ||||||
|                 _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type]) |                 _inputs.append( | ||||||
|  |                     [ | ||||||
|  |                         AGENT_TYPE_MAPPING[_input_type](_input) | ||||||
|  |                         for _input_type in input_type | ||||||
|  |                     ] | ||||||
|  |                 ) | ||||||
|             else: |             else: | ||||||
|                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) |                 _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) | ||||||
| 
 | 
 | ||||||
|  | @ -91,7 +99,10 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|         code = "print = '3'" |         code = "print = '3'" | ||||||
|         with pytest.raises(InterpreterError) as e: |         with pytest.raises(InterpreterError) as e: | ||||||
|             evaluate_python_code(code, {"print": print}, state={}) |             evaluate_python_code(code, {"print": print}, state={}) | ||||||
|         assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e) |         assert ( | ||||||
|  |             "Cannot assign to name 'print': doing this would erase the existing tool!" | ||||||
|  |             in str(e) | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_call(self): |     def test_evaluate_call(self): | ||||||
|         code = "y = add_two(x)" |         code = "y = add_two(x)" | ||||||
|  | @ -117,7 +128,9 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         self.assertDictEqual(result, {"x": 3, "y": 5}) |         self.assertDictEqual(result, {"x": 3, "y": 5}) | ||||||
|         self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) |         self.assertDictEqual( | ||||||
|  |             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_expression(self): |     def test_evaluate_expression(self): | ||||||
|         code = "x = 3\ny = 5" |         code = "x = 3\ny = 5" | ||||||
|  | @ -133,7 +146,9 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|         result = evaluate_python_code(code, {}, state=state) |         result = evaluate_python_code(code, {}, state=state) | ||||||
|         # evaluate returns the value of the last assignment. |         # evaluate returns the value of the last assignment. | ||||||
|         assert result == "This is x: 3." |         assert result == "This is x: 3." | ||||||
|         self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}) |         self.assertDictEqual( | ||||||
|  |             state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""} | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def test_evaluate_if(self): |     def test_evaluate_if(self): | ||||||
|         code = "if x <= 3:\n    y = 2\nelse:\n    y = 5" |         code = "if x <= 3:\n    y = 2\nelse:\n    y = 5" | ||||||
|  | @ -174,11 +189,15 @@ class PythonInterpreterTester(unittest.TestCase): | ||||||
|         state = {"x": 3} |         state = {"x": 3} | ||||||
|         result = evaluate_python_code(code, {"add_two": add_two}, state=state) |         result = evaluate_python_code(code, {"add_two": add_two}, state=state) | ||||||
|         assert result == 5 |         assert result == 5 | ||||||
|         self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) |         self.assertDictEqual( | ||||||
|  |             state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""} | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" |         code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" | ||||||
|         state = {} |         state = {} | ||||||
|         evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) |         evaluate_python_code( | ||||||
|  |             code, {"min": min, "print": print, "round": round}, state=state | ||||||
|  |         ) | ||||||
|         assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} |         assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} | ||||||
| 
 | 
 | ||||||
|     def test_subscript_string_with_string_index_raises_appropriate_error(self): |     def test_subscript_string_with_string_index_raises_appropriate_error(self): | ||||||
|  | @ -292,7 +311,16 @@ print(check_digits) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         evaluate_python_code( |         evaluate_python_code( | ||||||
|             code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state |             code, | ||||||
|  |             { | ||||||
|  |                 "range": range, | ||||||
|  |                 "print": print, | ||||||
|  |                 "sum": sum, | ||||||
|  |                 "enumerate": enumerate, | ||||||
|  |                 "int": int, | ||||||
|  |                 "str": str, | ||||||
|  |             }, | ||||||
|  |             state, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def test_listcomp(self): |     def test_listcomp(self): | ||||||
|  | @ -325,7 +353,9 @@ print(check_digits) | ||||||
|         assert result == {0: 0, 1: 1, 2: 4} |         assert result == {0: 0, 1: 1, 2: 4} | ||||||
| 
 | 
 | ||||||
|         code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" |         code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" | ||||||
|         result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print}, state={}, authorized_imports=["pandas"] | ||||||
|  |         ) | ||||||
|         assert result == {102: "b"} |         assert result == {102: "b"} | ||||||
| 
 | 
 | ||||||
|         code = """ |         code = """ | ||||||
|  | @ -373,7 +403,9 @@ else: | ||||||
|     best_city = "Manhattan" |     best_city = "Manhattan" | ||||||
|     best_city |     best_city | ||||||
|     """ |     """ | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) |         result = evaluate_python_code( | ||||||
|  |             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} | ||||||
|  |         ) | ||||||
|         assert result == "Brooklyn" |         assert result == "Brooklyn" | ||||||
| 
 | 
 | ||||||
|         code = """if d > e and a < b: |         code = """if d > e and a < b: | ||||||
|  | @ -384,7 +416,9 @@ else: | ||||||
|     best_city = "Manhattan" |     best_city = "Manhattan" | ||||||
|     best_city |     best_city | ||||||
|     """ |     """ | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) |         result = evaluate_python_code( | ||||||
|  |             code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} | ||||||
|  |         ) | ||||||
|         assert result == "Sacramento" |         assert result == "Sacramento" | ||||||
| 
 | 
 | ||||||
|     def test_if_conditions(self): |     def test_if_conditions(self): | ||||||
|  | @ -400,7 +434,9 @@ if char.isalpha(): | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == 2.0 |         assert result == 2.0 | ||||||
| 
 | 
 | ||||||
|         code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" |         code = ( | ||||||
|  |             "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])" | ||||||
|  |         ) | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) |         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) | ||||||
|         assert result == "lose" |         assert result == "lose" | ||||||
| 
 | 
 | ||||||
|  | @ -434,10 +470,14 @@ if char.isalpha(): | ||||||
| 
 | 
 | ||||||
|         # Test submodules are handled properly, thus not raising error |         # Test submodules are handled properly, thus not raising error | ||||||
|         code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" |         code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) |         result = evaluate_python_code( | ||||||
|  |             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" |         code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" | ||||||
|         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]) |         result = evaluate_python_code( | ||||||
|  |             code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"] | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def test_additional_imports(self): |     def test_additional_imports(self): | ||||||
|         code = "import numpy as np" |         code = "import numpy as np" | ||||||
|  | @ -554,7 +594,11 @@ cat_sound = cat.sound() | ||||||
| cat_str = str(cat) | cat_str = str(cat) | ||||||
|     """ |     """ | ||||||
|         state = {} |         state = {} | ||||||
|         evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) |         evaluate_python_code( | ||||||
|  |             code, | ||||||
|  |             {"print": print, "len": len, "super": super, "str": str, "sum": sum}, | ||||||
|  |             state=state, | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # Assert results |         # Assert results | ||||||
|         assert state["dog1_sound"] == "The dog barks." |         assert state["dog1_sound"] == "The dog barks." | ||||||
|  | @ -588,7 +632,11 @@ except ValueError as e: | ||||||
|     exception_message = str(e) |     exception_message = str(e) | ||||||
|     """ |     """ | ||||||
|         state = {} |         state = {} | ||||||
|         evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) |         evaluate_python_code( | ||||||
|  |             code, | ||||||
|  |             {"print": print, "len": len, "super": super, "str": str, "sum": sum}, | ||||||
|  |             state=state, | ||||||
|  |         ) | ||||||
|         assert state["exception_message"] == "An error occurred" |         assert state["exception_message"] == "An error occurred" | ||||||
| 
 | 
 | ||||||
|     def test_print(self): |     def test_print(self): | ||||||
|  | @ -600,7 +648,9 @@ except ValueError as e: | ||||||
|     def test_types_as_objects(self): |     def test_types_as_objects(self): | ||||||
|         code = "type_a = float(2); type_b = str; type_c = int" |         code = "type_a = float(2); type_b = str; type_c = int" | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) |         result = evaluate_python_code( | ||||||
|  |             code, {"float": float, "str": str, "int": int}, state=state | ||||||
|  |         ) | ||||||
|         assert result is int |         assert result is int | ||||||
| 
 | 
 | ||||||
|     def test_tuple_id(self): |     def test_tuple_id(self): | ||||||
|  | @ -731,7 +781,9 @@ def add_one(n, shift): | ||||||
| add_one(1, 1) | add_one(1, 1) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state | ||||||
|  |         ) | ||||||
|         assert result == 2 |         assert result == 2 | ||||||
| 
 | 
 | ||||||
|         # test returning None |         # test returning None | ||||||
|  | @ -742,7 +794,9 @@ def returns_none(a): | ||||||
| returns_none(1) | returns_none(1) | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state | ||||||
|  |         ) | ||||||
|         assert result is None |         assert result is None | ||||||
| 
 | 
 | ||||||
|     def test_nested_for_loop(self): |     def test_nested_for_loop(self): | ||||||
|  | @ -758,7 +812,9 @@ out = [i for sublist in all_res for i in sublist] | ||||||
| out[:10] | out[:10] | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {"print": print, "range": range}, state=state) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print, "range": range}, state=state | ||||||
|  |         ) | ||||||
|         assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] |         assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] | ||||||
| 
 | 
 | ||||||
|     def test_pandas(self): |     def test_pandas(self): | ||||||
|  | @ -773,7 +829,9 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0] | ||||||
| parts_with_5_set_count[['Quantity', 'SetCount']].values[1] | parts_with_5_set_count[['Quantity', 'SetCount']].values[1] | ||||||
| """ | """ | ||||||
|         state = {} |         state = {} | ||||||
|         result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"]) |         result = evaluate_python_code( | ||||||
|  |             code, {}, state=state, authorized_imports=["pandas"] | ||||||
|  |         ) | ||||||
|         assert np.array_equal(result, [-1, 5]) |         assert np.array_equal(result, [-1, 5]) | ||||||
| 
 | 
 | ||||||
|         code = """ |         code = """ | ||||||
|  | @ -785,7 +843,9 @@ print("HH0") | ||||||
| # Filter the DataFrame to get only the rows with outdated atomic numbers | # Filter the DataFrame to get only the rows with outdated atomic numbers | ||||||
| filtered_df = df.loc[df['AtomicNumber'].isin([104])] | filtered_df = df.loc[df['AtomicNumber'].isin([104])] | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"]) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print}, state={}, authorized_imports=["pandas"] | ||||||
|  |         ) | ||||||
|         assert np.array_equal(result.values[0], [104, 1]) |         assert np.array_equal(result.values[0], [104, 1]) | ||||||
| 
 | 
 | ||||||
|         code = """import pandas as pd |         code = """import pandas as pd | ||||||
|  | @ -818,7 +878,9 @@ coords_barcelona = (41.3869, 2.1660) | ||||||
| 
 | 
 | ||||||
| distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) | distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) | ||||||
| """ | """ | ||||||
|         result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"]) |         result = evaluate_python_code( | ||||||
|  |             code, {"print": print, "map": map}, state={}, authorized_imports=["math"] | ||||||
|  |         ) | ||||||
|         assert round(result, 1) == 622395.4 |         assert round(result, 1) == 622395.4 | ||||||
| 
 | 
 | ||||||
|     def test_for(self): |     def test_for(self): | ||||||
|  |  | ||||||
|  | @ -1,36 +0,0 @@ | ||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 HuggingFace Inc. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| 
 |  | ||||||
| import unittest |  | ||||||
| 
 |  | ||||||
| import numpy as np |  | ||||||
| 
 |  | ||||||
| from transformers import load_tool |  | ||||||
| 
 |  | ||||||
| from .test_tools_common import ToolTesterMixin |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin): |  | ||||||
|     def setUp(self): |  | ||||||
|         self.tool = load_tool("speech_to_text") |  | ||||||
|         self.tool.setup() |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_arg(self): |  | ||||||
|         result = self.tool(np.ones(3000)) |  | ||||||
|         self.assertEqual(result, " Thank you.") |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_kwarg(self): |  | ||||||
|         result = self.tool(audio=np.ones(3000)) |  | ||||||
|         self.assertEqual(result, " Thank you.") |  | ||||||
|  | @ -1,50 +0,0 @@ | ||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 HuggingFace Inc. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| 
 |  | ||||||
| import unittest |  | ||||||
| 
 |  | ||||||
| from transformers import load_tool |  | ||||||
| from transformers.utils import is_torch_available |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if is_torch_available(): |  | ||||||
|     import torch |  | ||||||
| 
 |  | ||||||
| from transformers.testing_utils import require_torch |  | ||||||
| 
 |  | ||||||
| from .test_tools_common import ToolTesterMixin |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @require_torch |  | ||||||
| class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin): |  | ||||||
|     def setUp(self): |  | ||||||
|         self.tool = load_tool("text_to_speech") |  | ||||||
|         self.tool.setup() |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_arg(self): |  | ||||||
|         # SpeechT5 isn't deterministic |  | ||||||
|         torch.manual_seed(0) |  | ||||||
|         result = self.tool("hey") |  | ||||||
|         resulting_tensor = result.to_raw() |  | ||||||
|         self.assertTrue(len(resulting_tensor.detach().shape) == 1) |  | ||||||
|         self.assertTrue(resulting_tensor.detach().shape[0] > 1000) |  | ||||||
| 
 |  | ||||||
|     def test_exact_match_kwarg(self): |  | ||||||
|         # SpeechT5 isn't deterministic |  | ||||||
|         torch.manual_seed(0) |  | ||||||
|         result = self.tool("hey") |  | ||||||
|         resulting_tensor = result.to_raw() |  | ||||||
|         self.assertTrue(len(resulting_tensor.detach().shape) == 1) |  | ||||||
|         self.assertTrue(resulting_tensor.detach().shape[0] > 1000) |  | ||||||
|  | @ -20,7 +20,12 @@ import numpy as np | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from transformers import is_torch_available, is_vision_available | from transformers import is_torch_available, is_vision_available | ||||||
| from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText | from transformers.agents.agent_types import ( | ||||||
|  |     AGENT_TYPE_MAPPING, | ||||||
|  |     AgentAudio, | ||||||
|  |     AgentImage, | ||||||
|  |     AgentText, | ||||||
|  | ) | ||||||
| from transformers.agents.tools import Tool, tool | from transformers.agents.tools import Tool, tool | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -32,7 +32,9 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin): | ||||||
|         self.assertEqual(result, "- Hé, comment ça va?") |         self.assertEqual(result, "- Hé, comment ça va?") | ||||||
| 
 | 
 | ||||||
|     def test_exact_match_kwarg(self): |     def test_exact_match_kwarg(self): | ||||||
|         result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French") |         result = self.tool( | ||||||
|  |             text="Hey, what's up?", src_lang="English", tgt_lang="French" | ||||||
|  |         ) | ||||||
|         self.assertEqual(result, "- Hé, comment ça va?") |         self.assertEqual(result, "- Hé, comment ça va?") | ||||||
| 
 | 
 | ||||||
|     def test_call(self): |     def test_call(self): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue