From 12ee33a8788305131471778a3cc8b1981b9bf887 Mon Sep 17 00:00:00 2001 From: Izaak Curry Date: Thu, 2 Jan 2025 20:54:32 -0800 Subject: [PATCH] add device parameter to TransformersModel --- src/smolagents/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 6ad0ce9..9c9a749 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -29,6 +29,7 @@ import litellm import logging import os import random +import torch from huggingface_hub import InferenceClient @@ -304,7 +305,7 @@ class TransformersModel(Model): f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}." ) self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) - self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device) + self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=device) def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: class StopOnStrings(StoppingCriteria):