diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index a57550a..70ef5d1 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -287,16 +287,15 @@ class TransformersModel(Model):
         logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
             )
             self.model_id = default_model_id
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
-                self.device
-            )
+            # Load the fallback checkpoint: retrying model_id would repeat the failure handled here.
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):