Add trust_remote_code arg to TransformersModel (#240)
This commit is contained in:
parent c255c1ff84
commit 11a738e53a
@@ -335,6 +335,8 @@ class TransformersModel(Model):
             The device_map to initialize your model with.
         torch_dtype (`str`, *optional*):
             The torch_dtype to initialize your model with.
+        trust_remote_code (bool):
+            Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
         kwargs (dict, *optional*):
             Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
     Raises:
@@ -360,6 +362,7 @@ class TransformersModel(Model):
         model_id: Optional[str] = None,
         device_map: Optional[str] = None,
         torch_dtype: Optional[str] = None,
+        trust_remote_code: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -381,7 +384,10 @@ class TransformersModel(Model):
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map=device_map, torch_dtype=torch_dtype
+                model_id,
+                device_map=device_map,
+                torch_dtype=torch_dtype,
+                trust_remote_code=trust_remote_code,
             )
         except Exception as e:
             logger.warning(
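For context, a minimal usage sketch of the new flag. The import path and the model repo name are assumptions for illustration, not part of this commit:

from smolagents import TransformersModel  # assumed import path

# Hypothetical Hub repo that ships its own modeling code; loading it
# requires explicitly opting in to remote code execution.
model = TransformersModel(
    model_id="some-org/custom-model",
    device_map="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,  # default is False; enable only for repos you trust
)

Keeping the default at False preserves the previous behaviour and matches the transformers from_pretrained default, so existing callers are unaffected.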