add device parameter to TransformersModel
commit 12ee33a878
parent 81388b14f7
@@ -29,6 +29,7 @@ import litellm
 import logging
 import os
 import random
+import torch
 
 from huggingface_hub import InferenceClient
 
@@ -304,7 +305,7 @@ class TransformersModel(Model):
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
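
For context, a minimal usage sketch of the pattern this diff adopts: forwarding a device string to AutoModelForCausalLM.from_pretrained via device_map instead of calling .to(device) after loading. The checkpoint name and surrounding setup are illustrative assumptions, not taken from this repository; note that passing device_map generally requires the accelerate package to be installed.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # illustrative checkpoint, not from this diff
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map accepts a device string ("cpu", "cuda", "cuda:0") or "auto";
    # passing it at load time places the weights directly on the target device,
    # so no separate model.to(device) call is needed afterwards.
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)

The practical difference from the removed .to(self.device) call is that device_map handles placement while the weights are loaded (and "auto" can shard a model across multiple GPUs), whereas .to() first materializes the full model on CPU and then moves it.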