Allow passing an HF access token to download the tokenizer. Fall back to the default tokenizer if the download fails.

imartinez 2024-04-19 15:38:25 +02:00
parent 347be643f7
commit 49ef729abc
3 changed files with 22 additions and 6 deletions


@@ -22,13 +22,24 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
-        if settings.llm.tokenizer:
-            set_global_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    pretrained_model_name_or_path=settings.llm.tokenizer,
-                    cache_dir=str(models_cache_path),
-                )
-            )
+        if settings.llm.tokenizer and settings.llm.mode != "mock":
+            # Try to download the tokenizer. If it fails, the LLM will still work
+            # using the default one, which is less accurate.
+            try:
+                set_global_tokenizer(
+                    AutoTokenizer.from_pretrained(
+                        pretrained_model_name_or_path=settings.llm.tokenizer,
+                        cache_dir=str(models_cache_path),
+                        token=settings.huggingface.access_token,
+                    )
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to download tokenizer %s: %s. Falling back to "
+                    "the default tokenizer.",
+                    settings.llm.tokenizer,
+                    e,
+                )
 
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
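For reference, the pattern added in the hunk above can be exercised in isolation with a sketch along these lines. It assumes only the transformers package; the helper name load_tokenizer, the cache path, and the logger are placeholders for illustration, not the project's actual objects:

import logging
from pathlib import Path

from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


def load_tokenizer(model_name: str, access_token: str | None = None):
    """Try to download a tokenizer; return None so callers can keep a default."""
    try:
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_name,
            cache_dir=str(Path("models_cache")),  # placeholder cache directory
            token=access_token,  # None is accepted for public models
        )
    except Exception as e:
        # Mirror the commit: warn and let the caller fall back to the default.
        logger.warning(
            "Failed to download tokenizer %s: %s. Falling back to the default tokenizer.",
            model_name,
            e,
        )
        return None

Passing token=None keeps the previous behaviour for public models, while gated models pick up the configured access token.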


@@ -151,6 +151,10 @@ class HuggingFaceSettings(BaseModel):
     embedding_hf_model_name: str = Field(
         description="Name of the HuggingFace model to use for embeddings"
     )
+    access_token: str = Field(
+        None,
+        description="Huggingface access token, required to download some models",
+    )
 
 
 class EmbeddingSettings(BaseModel):
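A minimal sketch of how the new setting behaves, assuming pydantic v2 and a trimmed stand-in class (named HFSettingsSketch here, and annotated str | None to keep type checkers quiet; the commit itself keeps the plain str annotation with a None default):

from pydantic import BaseModel, Field


class HFSettingsSketch(BaseModel):
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
    access_token: str | None = Field(
        None,
        description="Huggingface access token, required to download some models",
    )


# The token stays optional: omitting it yields None, which is then passed
# straight through as AutoTokenizer.from_pretrained(token=None).
settings = HFSettingsSketch(embedding_hf_model_name="BAAI/bge-small-en-v1.5")
assert settings.access_token is None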


@@ -69,6 +69,7 @@ embedding:
 
 huggingface:
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
+  access_token: ${HUGGINGFACE_TOKEN:}
 
 vectorstore:
   database: qdrant
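The ${HUGGINGFACE_TOKEN:} entry uses the project's ${ENV_VAR:default} placeholder style, so the token can be supplied through the environment without editing settings.yaml. A rough sketch of that expansion rule (the real settings loader may differ in details; expand_env is illustrative only):

import os
import re

# Matches ${NAME:default}; the text after the colon is the fallback value.
_ENV_PLACEHOLDER = re.compile(r"\$\{(\w+):([^}]*)\}")


def expand_env(value: str) -> str:
    """Replace ${NAME:default} with os.environ[NAME], or the default if unset."""
    return _ENV_PLACEHOLDER.sub(
        lambda m: os.environ.get(m.group(1), m.group(2)), value
    )


# With HUGGINGFACE_TOKEN unset this resolves to an empty value, i.e. no token.
print(expand_env("access_token: ${HUGGINGFACE_TOKEN:}"))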