Get rid of necessity to declare tools in their own .py file
This commit is contained in:
		
							parent
							
								
									aef0510e68
								
							
						
					
					
						commit
						0eb582bdba
					
				|  | @ -38,10 +38,10 @@ The custom tool needs: | ||||||
| 
 | 
 | ||||||
| The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`]. | The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`]. | ||||||
| 
 | 
 | ||||||
|  | Also, all imports should be put within the tool's forward function, else you will get an error. | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| from transformers import Tool | from transformers import Tool | ||||||
| from huggingface_hub import list_models |  | ||||||
| 
 | 
 | ||||||
| class HFModelDownloadsTool(Tool): | class HFModelDownloadsTool(Tool): | ||||||
|     name = "model_download_counter" |     name = "model_download_counter" | ||||||
|  | @ -58,26 +58,27 @@ class HFModelDownloadsTool(Tool): | ||||||
|     output_type = "string" |     output_type = "string" | ||||||
| 
 | 
 | ||||||
|     def forward(self, task: str): |     def forward(self, task: str): | ||||||
|  |         from huggingface_hub import list_models | ||||||
|  | 
 | ||||||
|         model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) |         model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) | ||||||
|         return model.id |         return model.id | ||||||
| ``` |  | ||||||
| 
 |  | ||||||
| Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| ```python |  | ||||||
| from model_downloads import HFModelDownloadsTool |  | ||||||
| 
 |  | ||||||
| tool = HFModelDownloadsTool() | tool = HFModelDownloadsTool() | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  | Now the custom `HfModelDownloadsTool` class is ready. | ||||||
|  | 
 | ||||||
| You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. | You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| tool.push_to_hub("{your_username}/hf-model-downloads") | from dotenv import load_dotenv | ||||||
|  | 
 | ||||||
|  | load_dotenv() | ||||||
|  | 
 | ||||||
|  | tool.push_to_hub("m-ric/hf-model-downloads", token=os.getenv("HF_TOKEN")) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. | Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. | ||||||
|  | Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`. | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| from transformers import load_tool, CodeAgent | from transformers import load_tool, CodeAgent | ||||||
|  | @ -159,7 +160,7 @@ We love Langchain and think it has a very compelling suite of tools. | ||||||
| To import a tool from LangChain, use the `from_langchain()` method. | To import a tool from LangChain, use the `from_langchain()` method. | ||||||
| 
 | 
 | ||||||
| Here is how you can use it to recreate the intro's search result using a LangChain web search tool. | Here is how you can use it to recreate the intro's search result using a LangChain web search tool. | ||||||
| This tool will need `pip install google-search-results` to work properly. | This tool will need `pip install langchain google-search-results -q` to work properly. | ||||||
| ```python | ```python | ||||||
| from langchain.agents import load_tools | from langchain.agents import load_tools | ||||||
| from agents import Tool, CodeAgent | from agents import Tool, CodeAgent | ||||||
|  | @ -191,7 +192,6 @@ agent.run( | ||||||
| ) | ) | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| | **Audio**                                                                                                                                            | | | **Audio**                                                                                                                                            | | ||||||
| |------------------------------------------------------------------------------------------------------------------------------------------------------| | |------------------------------------------------------------------------------------------------------------------------------------------------------| | ||||||
| | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> | | | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> | | ||||||
|  |  | ||||||
|  | @ -44,7 +44,6 @@ from transformers.dynamic_module_utils import ( | ||||||
| ) | ) | ||||||
| from transformers import AutoProcessor | from transformers import AutoProcessor | ||||||
| from transformers.utils import ( | from transformers.utils import ( | ||||||
|     CONFIG_NAME, |  | ||||||
|     TypeHintParsingException, |     TypeHintParsingException, | ||||||
|     cached_file, |     cached_file, | ||||||
|     get_json_schema, |     get_json_schema, | ||||||
|  | @ -53,6 +52,8 @@ from transformers.utils import ( | ||||||
|     is_vision_available, |     is_vision_available, | ||||||
| ) | ) | ||||||
| from .types import ImageType, handle_agent_inputs, handle_agent_outputs | from .types import ImageType, handle_agent_inputs, handle_agent_outputs | ||||||
|  | from .utils import ImportFinder | ||||||
|  | 
 | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  | @ -104,7 +105,7 @@ def setup_default_tools(): | ||||||
| 
 | 
 | ||||||
| # docstyle-ignore | # docstyle-ignore | ||||||
| APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo | APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo | ||||||
| from {module_name} import {class_name} | from tool import {class_name} | ||||||
| 
 | 
 | ||||||
| launch_gradio_demo({class_name}) | launch_gradio_demo({class_name}) | ||||||
| """ | """ | ||||||
|  | @ -304,28 +305,44 @@ class Tool: | ||||||
|             output_dir (`str`): The folder in which you want to save your tool. |             output_dir (`str`): The folder in which you want to save your tool. | ||||||
|         """ |         """ | ||||||
|         os.makedirs(output_dir, exist_ok=True) |         os.makedirs(output_dir, exist_ok=True) | ||||||
|         # Save module file |         class_name = self.__class__.__name__ | ||||||
|         if self.__module__ == "__main__": |  | ||||||
|             raise ValueError( |  | ||||||
|                 f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You " |  | ||||||
|                 "have to put this code in a separate module so we can include it in the saved folder." |  | ||||||
|             ) |  | ||||||
|         module_files = custom_object_save(self, output_dir) |  | ||||||
| 
 | 
 | ||||||
|         module_name = self.__class__.__module__ |         # Save tool file | ||||||
|         last_module = module_name.split(".")[-1] |         forward_source_code = inspect.getsource(self.forward) | ||||||
|         full_name = f"{last_module}.{self.__class__.__name__}" |         tool_code = textwrap.dedent(f""" | ||||||
|  |         from agents import Tool | ||||||
|  | 
 | ||||||
|  |         class {class_name}(Tool): | ||||||
|  |             name = "{self.name}" | ||||||
|  |             description = "{self.description}" | ||||||
|  |             inputs = {json.dumps(self.inputs, separators=(',', ':'))} | ||||||
|  |             output_type = "{self.output_type}" | ||||||
|  |         """).strip() | ||||||
|  | 
 | ||||||
|  |         import re | ||||||
|  |         def add_self_argument(source_code: str) -> str: | ||||||
|  |             """Add 'self' as first argument to a function definition if not present.""" | ||||||
|  |             pattern = r'def forward\(((?!self)[^)]*)\)' | ||||||
|  |              | ||||||
|  |             def replacement(match): | ||||||
|  |                 args = match.group(1).strip() | ||||||
|  |                 if args:  # If there are other arguments | ||||||
|  |                     return f'def forward(self, {args})' | ||||||
|  |                 return 'def forward(self)' | ||||||
|  |                  | ||||||
|  |             return re.sub(pattern, replacement, source_code) | ||||||
|  | 
 | ||||||
|  |         forward_source_code = forward_source_code.replace(self.name, "forward") | ||||||
|  |         forward_source_code = add_self_argument(forward_source_code) | ||||||
|  |         forward_source_code = forward_source_code.replace("@tool", "").strip() | ||||||
|  |         tool_code += "\n\n" + textwrap.indent(forward_source_code, "    ") | ||||||
|  |         with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f: | ||||||
|  |             f.write(tool_code) | ||||||
| 
 | 
 | ||||||
|         # Save config file |         # Save config file | ||||||
|         config_file = os.path.join(output_dir, "tool_config.json") |         config_file = os.path.join(output_dir, "tool_config.json") | ||||||
|         if os.path.isfile(config_file): |  | ||||||
|             with open(config_file, "r", encoding="utf-8") as f: |  | ||||||
|                 tool_config = json.load(f) |  | ||||||
|         else: |  | ||||||
|             tool_config = {} |  | ||||||
| 
 |  | ||||||
|         tool_config = { |         tool_config = { | ||||||
|             "tool_class": full_name, |             "tool_class": self.__class__.__name__, | ||||||
|             "description": self.description, |             "description": self.description, | ||||||
|             "name": self.name, |             "name": self.name, | ||||||
|             "inputs": self.inputs, |             "inputs": self.inputs, | ||||||
|  | @ -339,131 +356,20 @@ class Tool: | ||||||
|         with open(app_file, "w", encoding="utf-8") as f: |         with open(app_file, "w", encoding="utf-8") as f: | ||||||
|             f.write( |             f.write( | ||||||
|                 APP_FILE_TEMPLATE.format( |                 APP_FILE_TEMPLATE.format( | ||||||
|                     module_name=last_module, class_name=self.__class__.__name__ |                     class_name=class_name | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # Save requirements file |         # Save requirements file | ||||||
|         requirements_file = os.path.join(output_dir, "requirements.txt") |         requirements_file = os.path.join(output_dir, "requirements.txt") | ||||||
|         imports = [] | 
 | ||||||
|         for module in module_files: |         tree = ast.parse(forward_source_code) | ||||||
|             imports.extend(get_imports(module)) |         import_finder = ImportFinder() | ||||||
|         imports = list(set(imports)) |         import_finder.visit(tree) | ||||||
|  | 
 | ||||||
|  |         imports = list(set(import_finder.packages)) | ||||||
|         with open(requirements_file, "w", encoding="utf-8") as f: |         with open(requirements_file, "w", encoding="utf-8") as f: | ||||||
|             f.write("\n".join(imports) + "\n") |             f.write("agents_package\n" + "\n".join(imports) + "\n") | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def from_hub( |  | ||||||
|         cls, |  | ||||||
|         repo_id: str, |  | ||||||
|         token: Optional[str] = None, |  | ||||||
|         **kwargs, |  | ||||||
|     ): |  | ||||||
|         """ |  | ||||||
|         Loads a tool defined on the Hub. |  | ||||||
| 
 |  | ||||||
|         <Tip warning={true}> |  | ||||||
| 
 |  | ||||||
|         Loading a tool from the Hub means that you'll download the tool and execute it locally. |  | ||||||
|         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when |  | ||||||
|         installing a package using pip/npm/apt. |  | ||||||
| 
 |  | ||||||
|         </Tip> |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             repo_id (`str`): |  | ||||||
|                 The name of the repo on the Hub where your tool is defined. |  | ||||||
|             token (`str`, *optional*): |  | ||||||
|                 The token to identify you on hf.co. If unset, will use the token generated when running |  | ||||||
|                 `huggingface-cli login` (stored in `~/.huggingface`). |  | ||||||
|             kwargs (additional keyword arguments, *optional*): |  | ||||||
|                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as |  | ||||||
|                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the |  | ||||||
|                 others will be passed along to its init. |  | ||||||
|         """ |  | ||||||
|         hub_kwargs_names = [ |  | ||||||
|             "cache_dir", |  | ||||||
|             "force_download", |  | ||||||
|             "resume_download", |  | ||||||
|             "proxies", |  | ||||||
|             "revision", |  | ||||||
|             "repo_type", |  | ||||||
|             "subfolder", |  | ||||||
|             "local_files_only", |  | ||||||
|         ] |  | ||||||
|         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names} |  | ||||||
| 
 |  | ||||||
|         # Try to get the tool config first. |  | ||||||
|         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) |  | ||||||
|         resolved_config_file = cached_file( |  | ||||||
|             repo_id, |  | ||||||
|             TOOL_CONFIG_FILE, |  | ||||||
|             token=token, |  | ||||||
|             **hub_kwargs, |  | ||||||
|             _raise_exceptions_for_gated_repo=False, |  | ||||||
|             _raise_exceptions_for_missing_entries=False, |  | ||||||
|             _raise_exceptions_for_connection_errors=False, |  | ||||||
|         ) |  | ||||||
|         is_tool_config = resolved_config_file is not None |  | ||||||
|         if resolved_config_file is None: |  | ||||||
|             resolved_config_file = cached_file( |  | ||||||
|                 repo_id, |  | ||||||
|                 CONFIG_NAME, |  | ||||||
|                 token=token, |  | ||||||
|                 **hub_kwargs, |  | ||||||
|                 _raise_exceptions_for_gated_repo=False, |  | ||||||
|                 _raise_exceptions_for_missing_entries=False, |  | ||||||
|                 _raise_exceptions_for_connection_errors=False, |  | ||||||
|             ) |  | ||||||
|         if resolved_config_file is None: |  | ||||||
|             raise EnvironmentError( |  | ||||||
|                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         with open(resolved_config_file, encoding="utf-8") as reader: |  | ||||||
|             config = json.load(reader) |  | ||||||
| 
 |  | ||||||
|         if not is_tool_config: |  | ||||||
|             if "custom_tool" not in config: |  | ||||||
|                 raise EnvironmentError( |  | ||||||
|                     f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`." |  | ||||||
|                 ) |  | ||||||
|             custom_tool = config["custom_tool"] |  | ||||||
|         else: |  | ||||||
|             custom_tool = config |  | ||||||
| 
 |  | ||||||
|         tool_class = custom_tool["tool_class"] |  | ||||||
|         tool_class = get_class_from_dynamic_module( |  | ||||||
|             tool_class, repo_id, token=token, **hub_kwargs |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         if len(tool_class.name) == 0: |  | ||||||
|             tool_class.name = custom_tool["name"] |  | ||||||
|         if tool_class.name != custom_tool["name"]: |  | ||||||
|             logger.warning( |  | ||||||
|                 f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool " |  | ||||||
|                 "configuration name." |  | ||||||
|             ) |  | ||||||
|             tool_class.name = custom_tool["name"] |  | ||||||
| 
 |  | ||||||
|         if len(tool_class.description) == 0: |  | ||||||
|             tool_class.description = custom_tool["description"] |  | ||||||
|         if tool_class.description != custom_tool["description"]: |  | ||||||
|             logger.warning( |  | ||||||
|                 f"{tool_class.__name__} implements a different description in its configuration and class. Using the " |  | ||||||
|                 "tool configuration description." |  | ||||||
|             ) |  | ||||||
|             tool_class.description = custom_tool["description"] |  | ||||||
| 
 |  | ||||||
|         if tool_class.inputs != custom_tool["inputs"]: |  | ||||||
|             tool_class.inputs = custom_tool["inputs"] |  | ||||||
|         if tool_class.output_type != custom_tool["output_type"]: |  | ||||||
|             tool_class.output_type = custom_tool["output_type"] |  | ||||||
| 
 |  | ||||||
|         if not isinstance(tool_class.inputs, dict): |  | ||||||
|             tool_class.inputs = ast.literal_eval(tool_class.inputs) |  | ||||||
| 
 |  | ||||||
|         return tool_class(**kwargs) |  | ||||||
| 
 | 
 | ||||||
|     def push_to_hub( |     def push_to_hub( | ||||||
|         self, |         self, | ||||||
|  | @ -512,6 +418,9 @@ class Tool: | ||||||
|         with tempfile.TemporaryDirectory() as work_dir: |         with tempfile.TemporaryDirectory() as work_dir: | ||||||
|             # Save all files. |             # Save all files. | ||||||
|             self.save(work_dir) |             self.save(work_dir) | ||||||
|  |             print(work_dir) | ||||||
|  |             with open(work_dir + "/tool.py", "r") as f: | ||||||
|  |                 print('\n'.join(f.readlines())) | ||||||
|             logger.info( |             logger.info( | ||||||
|                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" |                 f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" | ||||||
|             ) |             ) | ||||||
|  | @ -524,6 +433,110 @@ class Tool: | ||||||
|                 repo_type="space", |                 repo_type="space", | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_hub( | ||||||
|  |         cls, | ||||||
|  |         repo_id: str, | ||||||
|  |         token: Optional[str] = None, | ||||||
|  |         trust_remote_code: bool = False, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Loads a tool defined on the Hub. | ||||||
|  | 
 | ||||||
|  |         <Tip warning={true}> | ||||||
|  | 
 | ||||||
|  |         Loading a tool from the Hub means that you'll download the tool and execute it locally. | ||||||
|  |         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when | ||||||
|  |         installing a package using pip/npm/apt. | ||||||
|  | 
 | ||||||
|  |         </Tip> | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             repo_id (`str`): | ||||||
|  |                 The name of the repo on the Hub where your tool is defined. | ||||||
|  |             token (`str`, *optional*): | ||||||
|  |                 The token to identify you on hf.co. If unset, will use the token generated when running | ||||||
|  |                 `huggingface-cli login` (stored in `~/.huggingface`). | ||||||
|  |             kwargs (additional keyword arguments, *optional*): | ||||||
|  |                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as | ||||||
|  |                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the | ||||||
|  |                 others will be passed along to its init. | ||||||
|  |         """ | ||||||
|  |         assert trust_remote_code, "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." | ||||||
|  | 
 | ||||||
|  |         hub_kwargs_names = [ | ||||||
|  |             "cache_dir", | ||||||
|  |             "force_download", | ||||||
|  |             "resume_download", | ||||||
|  |             "proxies", | ||||||
|  |             "revision", | ||||||
|  |             "repo_type", | ||||||
|  |             "subfolder", | ||||||
|  |             "local_files_only", | ||||||
|  |         ] | ||||||
|  |         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names} | ||||||
|  | 
 | ||||||
|  |         tool_file = "tool.py" | ||||||
|  | 
 | ||||||
|  |         # Get the tool's tool.py file. | ||||||
|  |         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs) | ||||||
|  |         resolved_tool_file = cached_file( | ||||||
|  |             repo_id, | ||||||
|  |             tool_file, | ||||||
|  |             token=token, | ||||||
|  |             **hub_kwargs, | ||||||
|  |             _raise_exceptions_for_gated_repo=False, | ||||||
|  |             _raise_exceptions_for_missing_entries=False, | ||||||
|  |             _raise_exceptions_for_connection_errors=False, | ||||||
|  |         ) | ||||||
|  |         tool_code = resolved_tool_file is not None | ||||||
|  |         if resolved_tool_file is None: | ||||||
|  |             resolved_tool_file = cached_file( | ||||||
|  |                 repo_id, | ||||||
|  |                 tool_file, | ||||||
|  |                 token=token, | ||||||
|  |                 **hub_kwargs, | ||||||
|  |                 _raise_exceptions_for_gated_repo=False, | ||||||
|  |                 _raise_exceptions_for_missing_entries=False, | ||||||
|  |                 _raise_exceptions_for_connection_errors=False, | ||||||
|  |             ) | ||||||
|  |         if resolved_tool_file is None: | ||||||
|  |             raise EnvironmentError( | ||||||
|  |                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         with open(resolved_tool_file, encoding="utf-8") as reader: | ||||||
|  |             tool_code = "".join(reader.readlines())     | ||||||
|  |          | ||||||
|  |         # Find the Tool subclass in the namespace | ||||||
|  |         with tempfile.TemporaryDirectory() as temp_dir: | ||||||
|  |             # Save the code to a file | ||||||
|  |             module_path = os.path.join(temp_dir, "tool.py") | ||||||
|  |             with open(module_path, "w") as f: | ||||||
|  |                 f.write(tool_code) | ||||||
|  | 
 | ||||||
|  |             # Load module from file path | ||||||
|  |             spec = importlib.util.spec_from_file_location("custom_tool", module_path) | ||||||
|  |             module = importlib.util.module_from_spec(spec) | ||||||
|  |             spec.loader.exec_module(module) | ||||||
|  | 
 | ||||||
|  |             # Find and instantiate the Tool class | ||||||
|  |             for item_name in dir(module): | ||||||
|  |                 item = getattr(module, item_name) | ||||||
|  |                 if isinstance(item, type) and issubclass(item, Tool) and item != Tool: | ||||||
|  |                     tool_class = item | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |             if tool_class is None: | ||||||
|  |                 raise ValueError("No Tool subclass found in the code") | ||||||
|  |          | ||||||
|  |         if not isinstance(tool_class.inputs, dict): | ||||||
|  |             tool_class.inputs = ast.literal_eval(tool_class.inputs) | ||||||
|  | 
 | ||||||
|  |         return tool_class(**kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_space( |     def from_space( | ||||||
|         space_id: str, |         space_id: str, | ||||||
|  | @ -967,7 +980,8 @@ def tool(tool_function: Callable) -> Tool: | ||||||
|         raise TypeHintParsingException( |         raise TypeHintParsingException( | ||||||
|             "Tool return type not found: make sure your function has a return type hint!" |             "Tool return type not found: make sure your function has a return type hint!" | ||||||
|         ) |         ) | ||||||
|     class_name = f"{parameters['name'].capitalize()}Tool" |     class_name = ''.join([el.title() for el in parameters['name'].split('_')]) | ||||||
|  | 
 | ||||||
|     if parameters["return"]["type"] == "object": |     if parameters["return"]["type"] == "object": | ||||||
|         parameters["return"]["type"] = "any" |         parameters["return"]["type"] = "any" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ | ||||||
| import json | import json | ||||||
| import re | import re | ||||||
| from typing import Tuple, Dict, Union | from typing import Tuple, Dict, Union | ||||||
|  | import ast | ||||||
| 
 | 
 | ||||||
| from transformers.utils.import_utils import _is_package_available | from transformers.utils.import_utils import _is_package_available | ||||||
| 
 | 
 | ||||||
|  | @ -109,5 +110,21 @@ def truncate_content( | ||||||
|             + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" |             + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" | ||||||
|             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] |             + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] | ||||||
|         ) |         ) | ||||||
|  |      | ||||||
|  | class ImportFinder(ast.NodeVisitor): | ||||||
|  |     def __init__(self): | ||||||
|  |         self.packages = set() | ||||||
|  |          | ||||||
|  |     def visit_Import(self, node): | ||||||
|  |         for alias in node.names: | ||||||
|  |             # Get the base package name (before any dots) | ||||||
|  |             base_package = alias.name.split('.')[0] | ||||||
|  |             self.packages.add(base_package) | ||||||
|  | 
 | ||||||
|  |     def visit_ImportFrom(self, node): | ||||||
|  |         if node.module:  # for "from x import y" statements | ||||||
|  |             # Get the base package name (before any dots) | ||||||
|  |             base_package = node.module.split('.')[0] | ||||||
|  |             self.packages.add(base_package) | ||||||
| 
 | 
 | ||||||
| __all__ = [] | __all__ = [] | ||||||
		Loading…
	
		Reference in New Issue