feat: Add multi-GPU support for TransformersModel (#139)
Add multi-GPU support for TransformersModel
This commit is contained in:
		
							parent
							
								
									c04e8de825
								
							
						
					
					
						commit
						12a2e6f4b4
					
				|  | @ -287,16 +287,14 @@ class TransformersModel(Model): | ||||||
|         logger.info(f"Using device: {self.device}") |         logger.info(f"Using device: {self.device}") | ||||||
|         try: |         try: | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) |             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." |                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." | ||||||
|             ) |             ) | ||||||
|             self.model_id = default_model_id |             self.model_id = default_model_id | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to( |             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) | ||||||
|                 self.device |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: |     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||||
|         class StopOnStrings(StoppingCriteria): |         class StopOnStrings(StoppingCriteria): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue