feat: Add multi-GPU support for TransformersModel (#139)
parent c04e8de825
commit 12a2e6f4b4
@@ -287,16 +287,14 @@ class TransformersModel(Model):
         logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
-                self.device
-            )
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
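The substance of the change is that model placement moves from a post-hoc .to(self.device) call to the device_map argument of from_pretrained, which transformers forwards to accelerate. A plain device string keeps the old single-device behavior, while "auto" shards the checkpoint across all visible GPUs. A minimal sketch of the difference, outside the diff above: the checkpoint name is a placeholder, and device_map="auto" assumes accelerate is installed.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Old behavior: load on CPU, then move the whole model to one device.
# model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")

# New behavior: device_map is passed straight to from_pretrained. A single
# device string still pins the whole model to one device ...
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# ... while "auto" lets accelerate split the layers across every visible
# GPU (spilling to CPU if they do not fit), i.e. multi-GPU support.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

With a sharded model, input tensors are typically moved to the device of the first shard (commonly model.device) before calling generate.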