Adding default parameter for max_new_tokens in TransformersModel (#604)
parent a427c84c1c
commit f3ee6052db
@@ -599,7 +599,16 @@ class TransformersModel(Model):
             model_id = default_model_id
             logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
         self.model_id = model_id
+
+        default_max_tokens = 5000
+        max_new_tokens = kwargs.get("max_new_tokens") or kwargs.get("max_tokens")
+        if not max_new_tokens:
+            kwargs["max_new_tokens"] = default_max_tokens
+            logger.warning(
+                f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
+            )
+
         self.kwargs = kwargs
         if device_map is None:
             device_map = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device_map}")
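For reference, a minimal sketch of the new behavior. This is not part of the commit; it assumes the smolagents import path for TransformersModel, and the model name is purely illustrative:

    from smolagents import TransformersModel

    # With neither `max_new_tokens` nor its `max_tokens` alias supplied, the
    # constructor now warns and stores kwargs["max_new_tokens"] = 5000, so
    # generation calls made through this model cap output at 5000 new tokens.
    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct")

    # Supplying either keyword skips the default:
    model = TransformersModel(
        model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=256
    )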