From 49ef729abc818f983770f0f3d16c18dfa661a5ef Mon Sep 17 00:00:00 2001
From: imartinez
Date: Fri, 19 Apr 2024 15:38:25 +0200
Subject: [PATCH] Allow passing HF access token to download tokenizer. Fallback to default tokenizer.

---
Review notes (free-form text in this section is not part of the diff):
  * What the patch does: the tokenizer download is skipped in "mock" mode,
    the HuggingFace access token from settings is passed to
    AutoTokenizer.from_pretrained, and a failed download is logged as a
    warning instead of aborting LLMComponent.__init__.
  * Fixed in review: logger.warning() was given two arguments
    (tokenizer name, exception) but the message had a single %s, so the
    record could not be formatted and the failure reason was never shown.
    The message now carries a second %s for the exception.
  * Fixed in review: access_token defaulted to None while being annotated
    as plain `str`; it is now `str | None`.
  * Both fixes edit existing added lines only, so hunk sizes and the
    diffstat below are unchanged. The blob ids on the `index` lines are
    those of the original commit.

 private_gpt/components/llm/llm_component.py | 23 +++++++++++++++------
 private_gpt/settings/settings.py            |  4 ++++
 settings.yaml                               |  1 +
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index dae997c..baffa4e 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -22,13 +22,24 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
-        if settings.llm.tokenizer:
-            set_global_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    pretrained_model_name_or_path=settings.llm.tokenizer,
-                    cache_dir=str(models_cache_path),
+        if settings.llm.tokenizer and settings.llm.mode != "mock":
+            # Try to download the tokenizer. If it fails, the LLM will still work
+            # using the default one, which is less accurate.
+            try:
+                set_global_tokenizer(
+                    AutoTokenizer.from_pretrained(
+                        pretrained_model_name_or_path=settings.llm.tokenizer,
+                        cache_dir=str(models_cache_path),
+                        token=settings.huggingface.access_token,
+                    )
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to download tokenizer %s. Falling back to "
+                    "default tokenizer. Error: %s",
+                    settings.llm.tokenizer,
+                    e,
                 )
-            )
 
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 5df6811..051cfca 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -151,6 +151,10 @@ class HuggingFaceSettings(BaseModel):
     embedding_hf_model_name: str = Field(
         description="Name of the HuggingFace model to use for embeddings"
     )
+    access_token: str | None = Field(
+        None,
+        description="Huggingface access token, required to download some models",
+    )
 
 
 class EmbeddingSettings(BaseModel):
diff --git a/settings.yaml b/settings.yaml
index dfd719b..e881a55 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -69,6 +69,7 @@ embedding:
 
 huggingface:
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
+  access_token: ${HUGGINGFACE_TOKEN:}
 
 vectorstore:
   database: qdrant