Merge pull request #1 from ScientistIzaak/transformer-model-device
Add device parameter for TransformersModel in models.py
This commit is contained in:
		
						commit
						7dbddb9d7e
					
				|  | @ -280,7 +280,14 @@ class HfApiModel(Model): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TransformersModel(Model): | class TransformersModel(Model): | ||||||
|     """This engine initializes a model and tokenizer from the given `model_id`.""" |     """This engine initializes a model and tokenizer from the given `model_id`. | ||||||
|  |      | ||||||
|  |         Parameters: | ||||||
|  |             model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`): | ||||||
|  |                 The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub. | ||||||
|  |             device (`str`, *optional*, defaults to `"cuda"` if available, else `"cpu"`):  | ||||||
|  |                 The device to load the model on (`"cpu"` or `"cuda"`).  | ||||||
|  |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None): |     def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  | @ -294,6 +301,7 @@ class TransformersModel(Model): | ||||||
|         if device is None: |         if device is None: | ||||||
|             device = "cuda" if torch.cuda.is_available() else "cpu" |             device = "cuda" if torch.cuda.is_available() else "cpu" | ||||||
|         self.device = device |         self.device = device | ||||||
|  |         logger.info(f"Using device: {self.device}") | ||||||
|         try: |         try: | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) |             self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) | ||||||
|  | @ -302,7 +310,7 @@ class TransformersModel(Model): | ||||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." |                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." | ||||||
|             ) |             ) | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device) |             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device) | ||||||
| 
 | 
 | ||||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: |     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||||
|         class StopOnStrings(StoppingCriteria): |         class StopOnStrings(StoppingCriteria): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue