Merge pull request #49 from ScientistIzaak/add-device-parameter
Add device parameter for TransformersModel in models.py
commit 19143af576
@@ -29,6 +29,7 @@ import litellm
 import logging
 import os
 import random
+import torch
 
 from huggingface_hub import InferenceClient
 
@@ -279,9 +280,16 @@ class HfApiModel(Model):
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`."""
+    """This engine initializes a model and tokenizer from the given `model_id`.
+
+    Parameters:
+        model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
+            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+        device (`str`, *optional*, defaults to `"cuda"` if available, else `"cpu"`):
+            The device to load the model on (`"cpu"` or `"cuda"`).
+    """
 
-    def __init__(self, model_id: Optional[str] = None):
+    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
         super().__init__()
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -290,15 +298,19 @@ class TransformersModel(Model):
                 f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+        logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
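With this change, callers can either pass a device explicitly or let the constructor pick one. A minimal usage sketch of the new constructor follows; only the constructor signature and the `model_id`/`device` attributes shown in the diff are relied on, and the import path is assumed and may differ in your checkout:

    # import path assumed; adjust to wherever models.py lives in this repo
    from models import TransformersModel

    # Explicit placement: load the model weights onto the CPU.
    cpu_model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device="cpu")

    # Omitting `device` falls back to "cuda" when torch.cuda.is_available() is True, else "cpu".
    auto_model = TransformersModel()
    print(auto_model.model_id, auto_model.device)  # e.g. "HuggingFaceTB/SmolLM2-1.7B-Instruct", "cuda" or "cpu"

Resolving the default inside `__init__` (rather than in the signature) keeps `device=None` meaning "auto-detect at construction time", so the CUDA check runs on the machine that actually instantiates the model.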