feat: Add multi-GPU support for TransformersModel (#139)

Add multi-GPU support for TransformersModel
This commit is contained in:
Deng Tongwei 2025-01-14 17:00:08 +08:00 committed by GitHub
parent c04e8de825
commit 12a2e6f4b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 4 deletions

View File

@@ -287,16 +287,14 @@ class TransformersModel(Model):
logger.info(f"Using device: {self.device}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
self.device
)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):