Force imports inside tool

commit aef0510e68
parent b6fc583d96
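
In short: `Tool.__init_subclass__` now runs a new `validate_forward_method_args` check that AST-parses each subclass's `forward` method and raises if it references names that are not arguments, builtins, `self`, or imported inside the method itself; the torch-based `PipelineTool` class is removed and the lazy `setup()` call moves into `Tool.__call__`. A minimal sketch of the pattern the check enforces (the `Tool` import path is an assumption; adjust it to this repo's layout):

```python
# Sketch only: the `Tool` import path below is an assumption about this repo's layout.
from agents.tools import Tool


class CurrentTimeTool(Tool):
    name = "current_time"
    description = "Gets the current time"
    inputs = {}
    output_type = "string"

    def forward(self):
        # The import lives inside `forward`, so the method stays self-contained
        # and the new class-definition-time check passes.
        from datetime import datetime

        return str(datetime.now())


# A `forward` that instead read a module-level `datetime` would now raise
# ValueError the moment the class is defined.
```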
					
@@ -22,6 +22,7 @@ import io
 import json
 import os
 import tempfile
+import textwrap
 from functools import lru_cache, wraps
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -115,12 +116,77 @@ def validate_after_init(cls, do_validate_forward: bool = True):
     @wraps(original_init)
     def new_init(self, *args, **kwargs):
         original_init(self, *args, **kwargs)
-        if not isinstance(self, PipelineTool):
         self.validate_arguments(do_validate_forward=do_validate_forward)

     cls.__init__ = new_init
     return cls

+def validate_forward_method_args(cls):
+    """Validates that all names in forward method are properly defined.
+    In particular it will check that all imports are done within the function."""
+    if 'forward' not in cls.__dict__:
+        return
+
+    forward = cls.__dict__['forward']
+    source_code = textwrap.dedent(inspect.getsource(forward))
+    tree = ast.parse(source_code)
+
+    # Get function arguments
+    func_node = tree.body[0]
+    arg_names = {arg.arg for arg in func_node.args.args}
+
+
+    import builtins
+    builtin_names = set(vars(builtins))
+
+
+    # Find all used names that aren't arguments or self attributes
+    class NameChecker(ast.NodeVisitor):
+        def __init__(self):
+            self.undefined_names = set()
+            self.imports = {}
+            self.from_imports = {}
+
+        def visit_Import(self, node):
+            """Handle simple imports like 'import datetime'."""
+            for name in node.names:
+                actual_name = name.asname or name.name
+                self.imports[actual_name] = (name.name, actual_name)
+
+        def visit_ImportFrom(self, node):
+            """Handle from imports like 'from datetime import datetime'."""
+            module = node.module or ''
+            for name in node.names:
+                actual_name = name.asname or name.name
+                self.from_imports[actual_name] = (module, name.name, actual_name)
+
+        def visit_Name(self, node):
+            if (isinstance(node.ctx, ast.Load) and not (
+                node.id == "tool" or
+                node.id in builtin_names or
+                node.id in arg_names or
+                node.id == 'self'
+            )):
+                if node.id not in self.from_imports and node.id not in self.imports:
+                    self.undefined_names.add(node.id)
+
+        def visit_Attribute(self, node):
+            # Skip self.something
+            if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
+                self.generic_visit(node)
+
+    checker = NameChecker()
+    checker.visit(tree)
+
+    if checker.undefined_names:
+        raise ValueError(
+            f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
+            Make sure all imports and variables are defined within the method.
+            For instance:
+
+            """
+        )
+
 AUTHORIZED_TYPES = [
     "string",
     "boolean",
@@ -136,7 +202,7 @@ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}

 class Tool:
     """
-    A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
+    A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
     following class attributes:

     - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
@@ -151,7 +217,7 @@ class Tool:
     - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
       or to make a nice space from your tool, and also can be used in the generated description for your tool.

-    You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
+    You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
     usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
     instantiation.
     """
@@ -166,8 +232,10 @@ class Tool:

     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
+        validate_forward_method_args(cls)
         validate_after_init(cls, do_validate_forward=False)

+
     def validate_arguments(self, do_validate_forward: bool = True):
         required_attributes = {
             "description": str,
@@ -198,7 +266,6 @@ class Tool:

         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
         if do_validate_forward:
-            if not isinstance(self, PipelineTool):
             signature = inspect.signature(self.forward)
             if not set(signature.parameters.keys()) == set(self.inputs.keys()):
                 raise Exception(
@@ -206,9 +273,11 @@ class Tool:
                 )

     def forward(self, *args, **kwargs):
-        return NotImplemented("Write this method in your subclass of `Tool`.")
+        return NotImplementedError("Write this method in your subclass of `Tool`.")

     def __call__(self, *args, **kwargs):
+        if not self.is_initialized:
+            self.setup()
         args, kwargs = handle_agent_inputs(*args, **kwargs)
         outputs = self.forward(*args, **kwargs)
         return handle_agent_outputs(outputs, self.output_type)
@@ -225,7 +294,6 @@ class Tool:
         Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
         tool in `output_dir` as well as autogenerate:

-        - a config file named `tool_config.json`
         - an `app.py` file so that your tool can be converted to a space
         - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
           code)
@@ -677,166 +745,6 @@ def compile_jinja_template(template):
     return jinja_env.from_string(template)


-class PipelineTool(Tool):
-    """
-    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
-    need to specify:
-
-    - **model_class** (`type`) -- The class to use to load the model in this tool.
-    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
-    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
-      pre-processor
-    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
-      post-processor (when different from the pre-processor).
-
-    Args:
-        model (`str` or [`PreTrainedModel`], *optional*):
-            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
-            value of the class attribute `default_checkpoint`.
-        pre_processor (`str` or `Any`, *optional*):
-            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
-            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
-            unset.
-        post_processor (`str` or `Any`, *optional*):
-            The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
-            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
-            unset.
-        device (`int`, `str` or `torch.device`, *optional*):
-            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
-            CPU otherwise.
-        device_map (`str` or `dict`, *optional*):
-            If passed along, will be used to instantiate the model.
-        model_kwargs (`dict`, *optional*):
-            Any keyword argument to send to the model instantiation.
-        token (`str`, *optional*):
-            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
-            running `huggingface-cli login` (stored in `~/.huggingface`).
-        hub_kwargs (additional keyword arguments, *optional*):
-            Any additional keyword argument to send to the methods that will load the data from the Hub.
-    """
-
-    pre_processor_class = AutoProcessor
-    model_class = None
-    post_processor_class = AutoProcessor
-    default_checkpoint = None
-    description = "This is a pipeline tool"
-    name = "pipeline"
-    inputs = {"prompt": str}
-    output_type = str
-
-    def __init__(
-        self,
-        model=None,
-        pre_processor=None,
-        post_processor=None,
-        device=None,
-        device_map=None,
-        model_kwargs=None,
-        token=None,
-        **hub_kwargs,
-    ):
-        if not is_torch_available():
-            raise ImportError("Please install torch in order to use this tool.")
-
-        if not is_accelerate_available():
-            raise ImportError("Please install accelerate in order to use this tool.")
-
-        if model is None:
-            if self.default_checkpoint is None:
-                raise ValueError(
-                    "This tool does not implement a default checkpoint, you need to pass one."
-                )
-            model = self.default_checkpoint
-        if pre_processor is None:
-            pre_processor = model
-
-        self.model = model
-        self.pre_processor = pre_processor
-        self.post_processor = post_processor
-        self.device = device
-        self.device_map = device_map
-        self.model_kwargs = {} if model_kwargs is None else model_kwargs
-        if device_map is not None:
-            self.model_kwargs["device_map"] = device_map
-        self.hub_kwargs = hub_kwargs
-        self.hub_kwargs["token"] = token
-
-        super().__init__()
-
-    def setup(self):
-        """
-        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
-        """
-        if isinstance(self.pre_processor, str):
-            self.pre_processor = self.pre_processor_class.from_pretrained(
-                self.pre_processor, **self.hub_kwargs
-            )
-
-        if isinstance(self.model, str):
-            self.model = self.model_class.from_pretrained(
-                self.model, **self.model_kwargs, **self.hub_kwargs
-            )
-
-        if self.post_processor is None:
-            self.post_processor = self.pre_processor
-        elif isinstance(self.post_processor, str):
-            self.post_processor = self.post_processor_class.from_pretrained(
-                self.post_processor, **self.hub_kwargs
-            )
-
-        if self.device is None:
-            if self.device_map is not None:
-                self.device = list(self.model.hf_device_map.values())[0]
-            else:
-                self.device = PartialState().default_device
-
-        if self.device_map is None:
-            self.model.to(self.device)
-
-        super().setup()
-
-    def encode(self, raw_inputs):
-        """
-        Uses the `pre_processor` to prepare the inputs for the `model`.
-        """
-        return self.pre_processor(raw_inputs)
-
-    def forward(self, inputs):
-        """
-        Sends the inputs through the `model`.
-        """
-        with torch.no_grad():
-            return self.model(**inputs)
-
-    def decode(self, outputs):
-        """
-        Uses the `post_processor` to decode the model output.
-        """
-        return self.post_processor(outputs)
-
-    def __call__(self, *args, **kwargs):
-        args, kwargs = handle_agent_inputs(*args, **kwargs)
-
-        if not self.is_initialized:
-            self.setup()
-
-        encoded_inputs = self.encode(*args, **kwargs)
-
-        tensor_inputs = {
-            k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
-        }
-        non_tensor_inputs = {
-            k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
-        }
-
-        encoded_inputs = send_to_device(tensor_inputs, self.device)
-        outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
-        outputs = send_to_device(outputs, "cpu")
-        decoded_outputs = self.decode(outputs)
-
-        return handle_agent_outputs(decoded_outputs, self.output_type)
-
-
 def launch_gradio_demo(tool_class: Tool):
     """
     Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
@@ -1060,6 +968,8 @@ def tool(tool_function: Callable) -> Tool:
             "Tool return type not found: make sure your function has a return type hint!"
         )
     class_name = f"{parameters['name'].capitalize()}Tool"
+    if parameters["return"]["type"] == "object":
+        parameters["return"]["type"] = "any"

     class SpecificTool(Tool):
         name = parameters["name"]
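
Aside: the heart of the new check is a small `ast.NodeVisitor` that records the imports made inside `forward` and flags any other loaded name that is neither an argument, a builtin, nor `self`. A self-contained sketch of the same idea, outside the `Tool` machinery (the helper name is illustrative):

```python
import ast
import builtins
import textwrap


def undefined_forward_names(source: str) -> set:
    """Names a function loads without importing or receiving them locally (sketch)."""
    tree = ast.parse(textwrap.dedent(source))
    func = tree.body[0]
    allowed = {arg.arg for arg in func.args.args} | set(vars(builtins)) | {"self"}

    class NameChecker(ast.NodeVisitor):
        def __init__(self):
            self.local_imports = set()
            self.flagged = set()

        def visit_Import(self, node):
            for alias in node.names:
                self.local_imports.add(alias.asname or alias.name)

        def visit_ImportFrom(self, node):
            for alias in node.names:
                self.local_imports.add(alias.asname or alias.name)

        def visit_Name(self, node):
            # Only names that are read (Load) and not covered by args,
            # builtins, `self`, or an in-function import get flagged.
            if isinstance(node.ctx, ast.Load) and node.id not in allowed | self.local_imports:
                self.flagged.add(node.id)

    checker = NameChecker()
    checker.visit(tree)
    return checker.flagged


# `datetime` is flagged: it is neither an argument nor imported inside the function.
print(undefined_forward_names("def forward(self):\n    return str(datetime.now())"))
# -> {'datetime'}

# Nothing is flagged once the import moves inside the body.
print(undefined_forward_names(
    "def forward(self):\n    from datetime import datetime\n    return str(datetime.now())"
))
# -> set()
```

The test additions below (in the tests module) exercise exactly this behaviour, both through the `@tool` decorator and through a `Tool` subclass.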
@@ -162,3 +162,44 @@ class ToolTests(unittest.TestCase):

             assert coolfunc.output_type == "number"
         assert "docstring has no description for the argument" in str(e)
+
+    def test_tool_definition_needs_imports_in_function(self):
+        with pytest.raises(Exception) as e:
+            from datetime import datetime
+            @tool
+            def get_current_time() -> str:
+                """
+                Gets the current time.
+                """
+                return str(datetime.now())
+        assert "datetime" in str(e)
+
+        # Also test with classic definition
+        with pytest.raises(Exception) as e:
+            class GetCurrentTimeTool(Tool):
+                name="get_current_time_tool"
+                description="Gets the current time"
+                inputs = {}
+                output_type = "string"
+
+                def forward(self):
+                    return str(datetime.now())
+        assert "datetime" in str(e)
+
+        @tool
+        def get_current_time() -> str:
+            """
+            Gets the current time.
+            """
+            from datetime import datetime
+            return str(datetime.now())
+
+        class GetCurrentTimeTool(Tool):
+            name="get_current_time_tool"
+            description="Gets the current time"
+            inputs = {}
+            output_type = "string"
+
+            def forward(self):
+                from datetime import datetime
+                return str(datetime.now())