From 12a2e6f4b4eadf94a57034b073672e8b115ac89c Mon Sep 17 00:00:00 2001
From: Deng Tongwei <74892366+6643789wsx@users.noreply.github.com>
Date: Tue, 14 Jan 2025 17:00:08 +0800
Subject: [PATCH] feat: Add multi-GPU support for TransformersModel (#139)

Add multi-GPU support for TransformersModel
---
 src/smolagents/models.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index a57550a..70ef5d1 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -287,16 +287,14 @@ class TransformersModel(Model):
         logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
-                self.device
-            )
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
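
The substantive change is replacing `.to(self.device)` with `device_map=self.device` in `AutoModelForCausalLM.from_pretrained`, which routes model placement through accelerate and therefore allows device values such as "auto" that shard the weights across several GPUs instead of forcing the whole model onto one device. The following is a minimal usage sketch, not part of the commit above; it assumes `transformers` and `accelerate` are installed and at least one CUDA device is visible, and the model id is only an example.

# Hedged sketch of what device_map-based loading enables (not part of the patch).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # example id; any causal LM works

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Before the patch: `.to(device)` moves the entire model onto a single device,
# so it has to fit in that one GPU's memory.
# After the patch: `device_map` is forwarded to from_pretrained, so callers can
# pass "cuda:0" for the old single-device behaviour, or "auto" to let accelerate
# spread the layers across all visible GPUs (spilling to CPU if needed).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))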