Merge pull request #75 from lelayf/fix_space_tool_fwd_sig_validation
Fix forward signature validation in `SpaceToolWrapper`
This commit is contained in:
		
						commit
						c3a11e09da
					
				|  | @ -206,10 +206,10 @@ class Tool: | |||
| 
 | ||||
|         assert getattr(self, "output_type", None) in AUTHORIZED_TYPES | ||||
| 
 | ||||
|         # Validate forward function signature, except for PipelineTool | ||||
|         # Validate forward function signature, except for Tools that use a "generic" signature (PipelineTool, SpaceToolWrapper) | ||||
|         if not ( | ||||
|             hasattr(self, "is_pipeline_tool") | ||||
|             and getattr(self, "is_pipeline_tool") is True | ||||
|             hasattr(self, "skip_forward_signature_validation") | ||||
|             and getattr(self, "skip_forward_signature_validation") is True | ||||
|         ): | ||||
|             signature = inspect.signature(self.forward) | ||||
| 
 | ||||
|  | @ -575,6 +575,9 @@ class Tool: | |||
|         from gradio_client import Client, handle_file | ||||
| 
 | ||||
|         class SpaceToolWrapper(Tool): | ||||
|              | ||||
|             skip_forward_signature_validation = True | ||||
| 
 | ||||
|             def __init__( | ||||
|                 self, | ||||
|                 space_id: str, | ||||
|  | @ -1098,7 +1101,7 @@ class PipelineTool(Tool): | |||
|     name = "pipeline" | ||||
|     inputs = {"prompt": str} | ||||
|     output_type = str | ||||
|     is_pipeline_tool = True | ||||
|     skip_forward_signature_validation = True | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue