Add trust_remote_code arg to TransformersModel (#240)
This commit is contained in:
		
							parent
							
								
									c255c1ff84
								
							
						
					
					
						commit
						11a738e53a
					
				|  | @ -335,6 +335,8 @@ class TransformersModel(Model): | ||||||
|             The device_map to initialize your model with. |             The device_map to initialize your model with. | ||||||
|         torch_dtype (`str`, *optional*): |         torch_dtype (`str`, *optional*): | ||||||
|             The torch_dtype to initialize your model with. |             The torch_dtype to initialize your model with. | ||||||
|  |         trust_remote_code (bool): | ||||||
|  |             Some models on the Hub require running remote code: for this model, you would have to set this flag to True. | ||||||
|         kwargs (dict, *optional*): |         kwargs (dict, *optional*): | ||||||
|             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`. |             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`. | ||||||
|     Raises: |     Raises: | ||||||
|  | @ -360,6 +362,7 @@ class TransformersModel(Model): | ||||||
|         model_id: Optional[str] = None, |         model_id: Optional[str] = None, | ||||||
|         device_map: Optional[str] = None, |         device_map: Optional[str] = None, | ||||||
|         torch_dtype: Optional[str] = None, |         torch_dtype: Optional[str] = None, | ||||||
|  |         trust_remote_code: bool = False, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  | @ -381,7 +384,10 @@ class TransformersModel(Model): | ||||||
|         try: |         try: | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained( |             self.model = AutoModelForCausalLM.from_pretrained( | ||||||
|                 model_id, device_map=device_map, torch_dtype=torch_dtype |                 model_id, | ||||||
|  |                 device_map=device_map, | ||||||
|  |                 torch_dtype=torch_dtype, | ||||||
|  |                 trust_remote_code=trust_remote_code, | ||||||
|             ) |             ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue