Merge pull request #49 from ScientistIzaak/add-device-parameter
Add device parameter for TransformersModel in models.py
commit 19143af576
@@ -29,6 +29,7 @@ import litellm
 import logging
 import os
 import random
+import torch
 
 from huggingface_hub import InferenceClient
 
@@ -279,9 +280,16 @@ class HfApiModel(Model):
 
 
 class TransformersModel(Model):
-    """This engine initializes a model and tokenizer from the given `model_id`."""
-
-    def __init__(self, model_id: Optional[str] = None):
+    """This engine initializes a model and tokenizer from the given `model_id`.
+
+    Parameters:
+        model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
+            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+        device (`str`, *optional*, defaults to `"cuda"` if available, else `"cpu"`):
+            The device to load the model on (`"cpu"` or `"cuda"`).
+    """
+
+    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
         super().__init__()
         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
         if model_id is None:
@@ -290,15 +298,19 @@ class TransformersModel(Model):
                 f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'"
             )
         self.model_id = model_id
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+        logger.info(f"Using device: {self.device}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         except Exception as e:
             logger.warning(
                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {model_id=}."
             )
             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
-            self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(self.device)
 
     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
         class StopOnStrings(StoppingCriteria):
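For reference, a minimal usage sketch of the new parameter follows. The import path and example values are illustrative assumptions (this diff only shows models.py); the behavior mirrors the hunks above: an explicit `device` string is used as given, while an omitted one falls back to `"cuda" if torch.cuda.is_available() else "cpu"` before the model is moved with `.to(self.device)`.

    import torch

    # Assumed import path for illustration; adjust to wherever this repo exposes the class.
    from models import TransformersModel

    # Explicit device: pin the model to CPU even on a machine with a GPU.
    cpu_model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device="cpu")

    # Omitted device: the constructor applies the same fallback shown in the diff.
    auto_model = TransformersModel()
    assert auto_model.device == ("cuda" if torch.cuda.is_available() else "cpu")

Moving the weights with `.to(self.device)` after `from_pretrained` keeps the loading path unchanged; passing a `device_map` to `from_pretrained` would be an alternative, but the explicit move is the simpler choice this commit makes.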