add device parameter to TransformersModel
This commit is contained in:
		
							parent
							
								
									81388b14f7
								
							
						
					
					
						commit
						12ee33a878
					
				|  | @ -29,6 +29,7 @@ import litellm | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import random | import random | ||||||
|  | import torch | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import InferenceClient | from huggingface_hub import InferenceClient | ||||||
| 
 | 
 | ||||||
|  | @ -304,7 +305,7 @@ class TransformersModel(Model): | ||||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." |                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." | ||||||
|             ) |             ) | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device) |             self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device) | ||||||
| 
 | 
 | ||||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: |     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||||
|         class StopOnStrings(StoppingCriteria): |         class StopOnStrings(StoppingCriteria): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue