Get rid of necessity to declare tools in their own .py file
This commit is contained in:
		
							parent
							
								
									aef0510e68
								
							
						
					
					
						commit
						0eb582bdba
					
				|  | @ -38,10 +38,10 @@ The custom tool needs: | |||
| 
 | ||||
| The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`]. | ||||
| 
 | ||||
| Also, all imports should be put within the tool's forward function, else you will get an error. | ||||
| 
 | ||||
| ```python | ||||
| from transformers import Tool | ||||
| from huggingface_hub import list_models | ||||
| 
 | ||||
| class HFModelDownloadsTool(Tool): | ||||
|     name = "model_download_counter" | ||||
|  | @ -58,26 +58,27 @@ class HFModelDownloadsTool(Tool): | |||
|     output_type = "string" | ||||
| 
 | ||||
|     def forward(self, task: str): | ||||
|         from huggingface_hub import list_models | ||||
| 
 | ||||
|         model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) | ||||
|         return model.id | ||||
| ``` | ||||
| 
 | ||||
| Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. | ||||
| 
 | ||||
| 
 | ||||
| ```python | ||||
| from model_downloads import HFModelDownloadsTool | ||||
| 
 | ||||
| tool = HFModelDownloadsTool() | ||||
| ``` | ||||
| 
 | ||||
| Now the custom `HfModelDownloadsTool` class is ready. | ||||
| 
 | ||||
| You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. | ||||
| 
 | ||||
| ```python | ||||
| tool.push_to_hub("{your_username}/hf-model-downloads") | ||||
| from dotenv import load_dotenv | ||||
| 
 | ||||
| load_dotenv() | ||||
| 
 | ||||
| tool.push_to_hub("m-ric/hf-model-downloads", token=os.getenv("HF_TOKEN")) | ||||
| ``` | ||||
| 
 | ||||
| Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. | ||||
| Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`. | ||||
| 
 | ||||
| ```python | ||||
| from transformers import load_tool, CodeAgent | ||||
|  | @ -159,7 +160,7 @@ We love Langchain and think it has a very compelling suite of tools. | |||
| To import a tool from LangChain, use the `from_langchain()` method. | ||||
| 
 | ||||
| Here is how you can use it to recreate the intro's search result using a LangChain web search tool. | ||||
| This tool will need `pip install google-search-results` to work properly. | ||||
| This tool will need `pip install langchain google-search-results -q` to work properly. | ||||
| ```python | ||||
| from langchain.agents import load_tools | ||||
| from agents import Tool, CodeAgent | ||||
|  | @ -191,7 +192,6 @@ agent.run( | |||
| ) | ||||
| ``` | ||||
| 
 | ||||
| 
 | ||||
| | **Audio**                                                                                                                                            | | ||||
| |------------------------------------------------------------------------------------------------------------------------------------------------------| | ||||
| | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> | | ||||
|  |  | |||
|  | @ -44,7 +44,6 @@ from transformers.dynamic_module_utils import ( | |||
| ) | ||||
| from transformers import AutoProcessor | ||||
| from transformers.utils import ( | ||||
|     CONFIG_NAME, | ||||
|     TypeHintParsingException, | ||||
|     cached_file, | ||||
|     get_json_schema, | ||||
|  | @ -53,6 +52,8 @@ from transformers.utils import ( | |||
|     is_vision_available, | ||||
| ) | ||||
| from .types import ImageType, handle_agent_inputs, handle_agent_outputs | ||||
| from .utils import ImportFinder | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -104,7 +105,7 @@ def setup_default_tools(): | |||
| 
 | ||||
| # docstyle-ignore | ||||
| APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo | ||||
| from {module_name} import {class_name} | ||||
| from tool import {class_name} | ||||
| 
 | ||||
| launch_gradio_demo({class_name}) | ||||
| """ | ||||
|  | @ -304,28 +305,44 @@ class Tool: | |||
|             output_dir (`str`): The folder in which you want to save your tool. | ||||
|         """ | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|         # Save module file | ||||
|         if self.__module__ == "__main__": | ||||
|             raise ValueError( | ||||
|                 f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You " | ||||
|                 "have to put this code in a separate module so we can include it in the saved folder." | ||||
|             ) | ||||
|         module_files = custom_object_save(self, output_dir) | ||||
|         class_name = self.__class__.__name__ | ||||
| 
 | ||||
|         module_name = self.__class__.__module__ | ||||
|         last_module = module_name.split(".")[-1] | ||||
|         full_name = f"{last_module}.{self.__class__.__name__}" | ||||
|         # Save tool file | ||||
|         forward_source_code = inspect.getsource(self.forward) | ||||
|         tool_code = textwrap.dedent(f""" | ||||
|         from agents import Tool | ||||
| 
 | ||||
|         class {class_name}(Tool): | ||||
|             name = "{self.name}" | ||||
|             description = "{self.description}" | ||||
|             inputs = {json.dumps(self.inputs, separators=(',', ':'))} | ||||
|             output_type = "{self.output_type}" | ||||
|         """).strip() | ||||
| 
 | ||||
|         import re | ||||
|         def add_self_argument(source_code: str) -> str: | ||||
|             """Add 'self' as first argument to a function definition if not present.""" | ||||
|             pattern = r'def forward\(((?!self)[^)]*)\)' | ||||
|              | ||||
|             def replacement(match): | ||||
|                 args = match.group(1).strip() | ||||
|                 if args:  # If there are other arguments | ||||
|                     return f'def forward(self, {args})' | ||||
|                 return 'def forward(self)' | ||||
|                  | ||||
|             return re.sub(pattern, replacement, source_code) | ||||
| 
 | ||||
|         forward_source_code = forward_source_code.replace(self.name, "forward") | ||||
|         forward_source_code = add_self_argument(forward_source_code) | ||||
|         forward_source_code = forward_source_code.replace("@tool", "").strip() | ||||
|         tool_code += "\n\n" + textwrap.indent(forward_source_code, "    ") | ||||
|         with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f: | ||||
|             f.write(tool_code) | ||||
| 
 | ||||
|         # Save config file | ||||
|         config_file = os.path.join(output_dir, "tool_config.json") | ||||
|         if os.path.isfile(config_file): | ||||
|             with open(config_file, "r", encoding="utf-8") as f: | ||||
|                 tool_config = json.load(f) | ||||
|         else: | ||||
|             tool_config = {} | ||||
| 
 | ||||
|         tool_config = { | ||||
|             "tool_class": full_name, | ||||
|             "tool_class": self.__class__.__name__, | ||||
|             "description": self.description, | ||||
|             "name": self.name, | ||||
|             "inputs": self.inputs, | ||||
|  | @ -339,131 +356,20 @@ class Tool: | |||
|         with open(app_file, "w", encoding="utf-8") as f: | ||||
|             f.write( | ||||
|                 APP_FILE_TEMPLATE.format( | ||||
|                     module_name=last_module, class_name=self.__class__.__name__ | ||||
|                     class_name=class_name | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         # Save requirements file | ||||
|         requirements_file = os.path.join(output_dir, "requirements.txt") | ||||
|         imports = [] | ||||
|         for module in module_files: | ||||
|             imports.extend(get_imports(module)) | ||||
|         imports = list(set(imports)) | ||||
| 
 | ||||
|         tree = ast.parse(forward_source_code) | ||||
|         import_finder = ImportFinder() | ||||
|         import_finder.visit(tree) | ||||
| 
 | ||||
|         imports = list(set(import_finder.packages)) | ||||
|         with open(requirements_file, "w", encoding="utf-8") as f: | ||||
|             f.write("\n".join(imports) + "\n") | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_hub( | ||||
|         cls, | ||||
|         repo_id: str, | ||||
|         token: Optional[str] = None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """ | ||||
|         Loads a tool defined on the Hub. | ||||
| 
 | ||||
|         <Tip warning={true}> | ||||
| 
 | ||||
|         Loading a tool from the Hub means that you'll download the tool and execute it locally. | ||||
|         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when | ||||
|         installing a package using pip/npm/apt. | ||||
| 
 | ||||
|         </Tip> | ||||
| 
 | ||||
|         Args: | ||||
|             repo_id (`str`): | ||||
|                 The name of the repo on the Hub where your tool is defined. | ||||
|             token (`str`, *optional*): | ||||
|                 The token to identify you on hf.co. If unset, will use the token generated when running | ||||
|                 `huggingface-cli login` (stored in `~/.huggingface`). | ||||
|             kwargs (additional keyword arguments, *optional*): | ||||
|                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as | ||||
|                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the | ||||
|                 others will be passed along to its init. | ||||
|         """ | ||||
|         hub_kwargs_names = [ | ||||
|             "cache_dir", | ||||
|             "force_download", | ||||
|             "resume_download", | ||||
|             "proxies", | ||||
|             "revision", | ||||
|             "repo_type", | ||||
|             "subfolder", | ||||
|             "local_files_only", | ||||
|         ] | ||||
|         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names} | ||||
| 
 | ||||
|         # Try to get the tool config first. | ||||
|         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) | ||||
|         resolved_config_file = cached_file( | ||||
|             repo_id, | ||||
|             TOOL_CONFIG_FILE, | ||||
|             token=token, | ||||
|             **hub_kwargs, | ||||
|             _raise_exceptions_for_gated_repo=False, | ||||
|             _raise_exceptions_for_missing_entries=False, | ||||
|             _raise_exceptions_for_connection_errors=False, | ||||
|         ) | ||||
|         is_tool_config = resolved_config_file is not None | ||||
|         if resolved_config_file is None: | ||||
|             resolved_config_file = cached_file( | ||||
|                 repo_id, | ||||
|                 CONFIG_NAME, | ||||
|                 token=token, | ||||
|                 **hub_kwargs, | ||||
|                 _raise_exceptions_for_gated_repo=False, | ||||
|                 _raise_exceptions_for_missing_entries=False, | ||||
|                 _raise_exceptions_for_connection_errors=False, | ||||
|             ) | ||||
|         if resolved_config_file is None: | ||||
|             raise EnvironmentError( | ||||
|                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." | ||||
|             ) | ||||
| 
 | ||||
|         with open(resolved_config_file, encoding="utf-8") as reader: | ||||
|             config = json.load(reader) | ||||
| 
 | ||||
|         if not is_tool_config: | ||||
|             if "custom_tool" not in config: | ||||
|                 raise EnvironmentError( | ||||
|                     f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`." | ||||
|                 ) | ||||
|             custom_tool = config["custom_tool"] | ||||
|         else: | ||||
|             custom_tool = config | ||||
| 
 | ||||
|         tool_class = custom_tool["tool_class"] | ||||
|         tool_class = get_class_from_dynamic_module( | ||||
|             tool_class, repo_id, token=token, **hub_kwargs | ||||
|         ) | ||||
| 
 | ||||
|         if len(tool_class.name) == 0: | ||||
|             tool_class.name = custom_tool["name"] | ||||
|         if tool_class.name != custom_tool["name"]: | ||||
|             logger.warning( | ||||
|                 f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool " | ||||
|                 "configuration name." | ||||
|             ) | ||||
|             tool_class.name = custom_tool["name"] | ||||
| 
 | ||||
|         if len(tool_class.description) == 0: | ||||
|             tool_class.description = custom_tool["description"] | ||||
|         if tool_class.description != custom_tool["description"]: | ||||
|             logger.warning( | ||||
|                 f"{tool_class.__name__} implements a different description in its configuration and class. Using the " | ||||
|                 "tool configuration description." | ||||
|             ) | ||||
|             tool_class.description = custom_tool["description"] | ||||
| 
 | ||||
|         if tool_class.inputs != custom_tool["inputs"]: | ||||
|             tool_class.inputs = custom_tool["inputs"] | ||||
|         if tool_class.output_type != custom_tool["output_type"]: | ||||
|             tool_class.output_type = custom_tool["output_type"] | ||||
| 
 | ||||
|         if not isinstance(tool_class.inputs, dict): | ||||
|             tool_class.inputs = ast.literal_eval(tool_class.inputs) | ||||
| 
 | ||||
|         return tool_class(**kwargs) | ||||
|             f.write("agents_package\n" + "\n".join(imports) + "\n") | ||||
| 
 | ||||
|     def push_to_hub( | ||||
|         self, | ||||
|  | @ -512,6 +418,9 @@ class Tool: | |||
|         with tempfile.TemporaryDirectory() as work_dir: | ||||
|             # Save all files. | ||||
|             self.save(work_dir) | ||||
|             print(work_dir) | ||||
|             with open(work_dir + "/tool.py", "r") as f: | ||||
|                 print('\n'.join(f.readlines())) | ||||
|             logger.info( | ||||
|                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" | ||||
|             ) | ||||
|  | @ -524,6 +433,110 @@ class Tool: | |||
|                 repo_type="space", | ||||
|             ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_hub( | ||||
|         cls, | ||||
|         repo_id: str, | ||||
|         token: Optional[str] = None, | ||||
|         trust_remote_code: bool = False, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """ | ||||
|         Loads a tool defined on the Hub. | ||||
| 
 | ||||
|         <Tip warning={true}> | ||||
| 
 | ||||
|         Loading a tool from the Hub means that you'll download the tool and execute it locally. | ||||
|         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when | ||||
|         installing a package using pip/npm/apt. | ||||
| 
 | ||||
|         </Tip> | ||||
| 
 | ||||
|         Args: | ||||
|             repo_id (`str`): | ||||
|                 The name of the repo on the Hub where your tool is defined. | ||||
|             token (`str`, *optional*): | ||||
|                 The token to identify you on hf.co. If unset, will use the token generated when running | ||||
|                 `huggingface-cli login` (stored in `~/.huggingface`). | ||||
|             kwargs (additional keyword arguments, *optional*): | ||||
|                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as | ||||
|                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the | ||||
|                 others will be passed along to its init. | ||||
|         """ | ||||
|         assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." | ||||
| 
 | ||||
|         hub_kwargs_names = [ | ||||
|             "cache_dir", | ||||
|             "force_download", | ||||
|             "resume_download", | ||||
|             "proxies", | ||||
|             "revision", | ||||
|             "repo_type", | ||||
|             "subfolder", | ||||
|             "local_files_only", | ||||
|         ] | ||||
|         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names} | ||||
| 
 | ||||
|         tool_file = "tool.py" | ||||
| 
 | ||||
|         # Get the tool's tool.py file. | ||||
|         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) | ||||
|         resolved_tool_file = cached_file( | ||||
|             repo_id, | ||||
|             tool_file, | ||||
|             token=token, | ||||
|             **hub_kwargs, | ||||
|             _raise_exceptions_for_gated_repo=False, | ||||
|             _raise_exceptions_for_missing_entries=False, | ||||
|             _raise_exceptions_for_connection_errors=False, | ||||
|         ) | ||||
|         tool_code = resolved_tool_file is not None | ||||
|         if resolved_tool_file is None: | ||||
|             resolved_tool_file = cached_file( | ||||
|                 repo_id, | ||||
|                 tool_file, | ||||
|                 token=token, | ||||
|                 **hub_kwargs, | ||||
|                 _raise_exceptions_for_gated_repo=False, | ||||
|                 _raise_exceptions_for_missing_entries=False, | ||||
|                 _raise_exceptions_for_connection_errors=False, | ||||
|             ) | ||||
|         if resolved_tool_file is None: | ||||
|             raise EnvironmentError( | ||||
|                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." | ||||
|             ) | ||||
| 
 | ||||
|         with open(resolved_tool_file, encoding="utf-8") as reader: | ||||
|             tool_code = "".join(reader.readlines())     | ||||
|          | ||||
|         # Find the Tool subclass in the namespace | ||||
|         with tempfile.TemporaryDirectory() as temp_dir: | ||||
|             # Save the code to a file | ||||
|             module_path = os.path.join(temp_dir, "tool.py") | ||||
|             with open(module_path, "w") as f: | ||||
|                 f.write(tool_code) | ||||
| 
 | ||||
|             # Load module from file path | ||||
|             spec = importlib.util.spec_from_file_location("custom_tool", module_path) | ||||
|             module = importlib.util.module_from_spec(spec) | ||||
|             spec.loader.exec_module(module) | ||||
| 
 | ||||
|             # Find and instantiate the Tool class | ||||
|             for item_name in dir(module): | ||||
|                 item = getattr(module, item_name) | ||||
|                 if isinstance(item, type) and issubclass(item, Tool) and item != Tool: | ||||
|                     tool_class = item | ||||
|                     break | ||||
| 
 | ||||
|             if tool_class is None: | ||||
|                 raise ValueError("No Tool subclass found in the code") | ||||
|          | ||||
|         if not isinstance(tool_class.inputs, dict): | ||||
|             tool_class.inputs = ast.literal_eval(tool_class.inputs) | ||||
| 
 | ||||
|         return tool_class(**kwargs) | ||||
| 
 | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_space( | ||||
|         space_id: str, | ||||
|  | @ -967,7 +980,8 @@ def tool(tool_function: Callable) -> Tool: | |||
|         raise TypeHintParsingException( | ||||
|             "Tool return type not found: make sure your function has a return type hint!" | ||||
|         ) | ||||
|     class_name = f"{parameters['name'].capitalize()}Tool" | ||||
|     class_name = ''.join([el.title() for el in parameters['name'].split('_')]) | ||||
| 
 | ||||
|     if parameters["return"]["type"] == "object": | ||||
|         parameters["return"]["type"] = "any" | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ | |||
| import json | ||||
| import re | ||||
| from typing import Tuple, Dict, Union | ||||
| import ast | ||||
| 
 | ||||
| from transformers.utils.import_utils import _is_package_available | ||||
| 
 | ||||
|  | @ -110,4 +111,20 @@ def truncate_content( | |||
|             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] | ||||
|         ) | ||||
|      | ||||
| class ImportFinder(ast.NodeVisitor): | ||||
|     def __init__(self): | ||||
|         self.packages = set() | ||||
|          | ||||
|     def visit_Import(self, node): | ||||
|         for alias in node.names: | ||||
|             # Get the base package name (before any dots) | ||||
|             base_package = alias.name.split('.')[0] | ||||
|             self.packages.add(base_package) | ||||
| 
 | ||||
|     def visit_ImportFrom(self, node): | ||||
|         if node.module:  # for "from x import y" statements | ||||
|             # Get the base package name (before any dots) | ||||
|             base_package = node.module.split('.')[0] | ||||
|             self.packages.add(base_package) | ||||
| 
 | ||||
| __all__ = [] | ||||
		Loading…
	
		Reference in New Issue