Adding default parameter for max_new_tokens in TransformersModel (#604)

This commit is contained in:
Matthias Freiberger 2025-02-13 12:03:44 +01:00 committed by GitHub
parent a427c84c1c
commit f3ee6052db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 0 deletions

View File

@ -599,7 +599,16 @@ class TransformersModel(Model):
model_id = default_model_id
logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
self.model_id = model_id
default_max_tokens = 5000
max_new_tokens = kwargs.get("max_new_tokens") or kwargs.get("max_tokens")
if not max_new_tokens:
kwargs["max_new_tokens"] = default_max_tokens
logger.warning(
f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
)
self.kwargs = kwargs
if device_map is None:
    device_map = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device_map}")