feat(settings): Configurable context_window and tokenizer (#1437)
parent 6eeb95ec7f
commit 4780540870
private_gpt/components/llm/llm_component.py

@@ -1,11 +1,13 @@
 import logging
 
 from injector import inject, singleton
+from llama_index import set_global_tokenizer
 from llama_index.llms import MockLLM
 from llama_index.llms.base import LLM
+from transformers import AutoTokenizer  # type: ignore
 
 from private_gpt.components.llm.prompt_helper import get_prompt_style
-from private_gpt.paths import models_path
+from private_gpt.paths import models_cache_path, models_path
 from private_gpt.settings.settings import Settings
 
 logger = logging.getLogger(__name__)
@@ -18,6 +20,14 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
+        if settings.llm.tokenizer:
+            set_global_tokenizer(
+                AutoTokenizer.from_pretrained(
+                    pretrained_model_name_or_path=settings.llm.tokenizer,
+                    cache_dir=str(models_cache_path),
+                )
+            )
+
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
             case "local":
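Setting a global tokenizer matters because llama_index uses it to count tokens when packing prompts against the context window, and counts from its default gpt-3.5-turbo tokenizer can diverge from the model's own. A minimal sketch of that divergence (not part of the commit; the sample text is illustrative, the model ids are the ones this commit mentions):

    from transformers import AutoTokenizer

    text = "PrivateGPT lets you query your documents locally."
    # The tokenizer that matches the configured local model...
    mistral = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    # ...can yield a different token count than an unrelated tokenizer.
    bert = AutoTokenizer.from_pretrained("bert-base-uncased")
    print(len(mistral.encode(text)), len(bert.encode(text)))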
@@ -29,9 +39,7 @@
                     model_path=str(models_path / settings.local.llm_hf_model_file),
                     temperature=0.1,
                     max_new_tokens=settings.llm.max_new_tokens,
-                    # llama2 has a context window of 4096 tokens,
-                    # but we set it lower to allow for some wiggle room
-                    context_window=3900,
+                    context_window=settings.llm.context_window,
                     generate_kwargs={},
                     # All to GPU
                     model_kwargs={"n_gpu_layers": -1},
@@ -46,6 +54,8 @@
 
                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.llm_endpoint_name,
+                    max_new_tokens=settings.llm.max_new_tokens,
+                    context_window=settings.llm.context_window,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
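Passing both values to SagemakerLLM mirrors the local path: an LLM wrapper typically reserves max_new_tokens out of the context window and leaves the rest for the prompt. A rough sketch of that budgeting, using the defaults introduced below (illustrative arithmetic, not SagemakerLLM internals):

    context_window = 3900   # default from LLMSettings below
    max_new_tokens = 256    # default completion budget
    prompt_budget = context_window - max_new_tokens
    print(prompt_budget)    # 3644 tokens left for prompt + retrieved context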
private_gpt/settings/settings.py

@@ -86,6 +86,18 @@ class LLMSettings(BaseModel):
         256,
         description="The maximum number of tokens that the LLM is authorized to generate in one completion.",
     )
+    context_window: int = Field(
+        3900,
+        description="The maximum number of context tokens for the model.",
+    )
+    tokenizer: str = Field(
+        None,
+        description="The model id of a predefined tokenizer hosted inside a model repo on "
+        "huggingface.co. Valid model ids can be located at the root-level, like "
+        "`bert-base-uncased`, or namespaced under a user or organization name, "
+        "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer "
+        "matching the gpt-3.5-turbo LLM.",
+    )
 
 
 class VectorstoreSettings(BaseModel):
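With these defaults, a bare LLMSettings keeps the previous behaviour: 3900 sits just under llama2's 4096-token window to leave wiggle room, and an unset tokenizer leaves llama_index on its gpt-3.5-turbo default. A self-contained sketch of how the fields resolve (Optional[str] is used here, rather than the commit's bare str, so the example validates on its own):

    from typing import Optional
    from pydantic import BaseModel, Field

    class LLMSettings(BaseModel):
        max_new_tokens: int = Field(256)
        context_window: int = Field(3900)
        tokenizer: Optional[str] = Field(None)

    s = LLMSettings()
    print(s.context_window)  # 3900 -- just under llama2's 4096
    print(s.tokenizer)       # None -> keep the gpt-3.5-turbo tokenizer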
scripts/setup

@@ -3,6 +3,7 @@ import os
 import argparse
 
 from huggingface_hub import hf_hub_download, snapshot_download
+from transformers import AutoTokenizer
 
 from private_gpt.paths import models_path, models_cache_path
 from private_gpt.settings.settings import settings
@@ -15,8 +16,9 @@ if __name__ == '__main__':
     resume_download = args.resume
 
 os.makedirs(models_path, exist_ok=True)
-embedding_path = models_path / "embedding"
 
+# Download Embedding model
+embedding_path = models_path / "embedding"
 print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
 snapshot_download(
     repo_id=settings().local.embedding_hf_model_name,
@@ -24,9 +26,9 @@ snapshot_download(
     local_dir=embedding_path,
 )
 print("Embedding model downloaded!")
-print("Downloading models for local execution...")
 
 # Download LLM and create a symlink to the model file
+print(f"Downloading LLM {settings().local.llm_hf_model_file}")
 hf_hub_download(
     repo_id=settings().local.llm_hf_repo_id,
     filename=settings().local.llm_hf_model_file,
@@ -34,6 +36,14 @@ hf_hub_download(
     local_dir=models_path,
     resume_download=resume_download,
 )
-
 print("LLM model downloaded!")
+
+# Download Tokenizer
+print(f"Downloading tokenizer {settings().llm.tokenizer}")
+AutoTokenizer.from_pretrained(
+    pretrained_model_name_or_path=settings().llm.tokenizer,
+    cache_dir=models_cache_path,
+)
+print("Tokenizer downloaded!")
 
 print("Setup done")
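Fetching the tokenizer at setup time means later runs resolve it from the shared cache instead of the network. A sketch of that offline load (local_files_only is not used by the commit; it is shown only to make the cache dependency explicit):

    from transformers import AutoTokenizer
    from private_gpt.paths import models_cache_path
    from private_gpt.settings.settings import settings

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=settings().llm.tokenizer,
        cache_dir=str(models_cache_path),
        local_files_only=True,  # fail fast if the setup script has not run yet
    )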
settings.yaml

@@ -34,6 +34,10 @@ ui:
 
 llm:
   mode: local
+  # Should be matching the selected model
+  max_new_tokens: 512
+  context_window: 32768
+  tokenizer: mistralai/Mistral-7B-Instruct-v0.2
 
 embedding:
   # Should be matching the value above in most cases
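All three values are model-specific (Mistral-7B-Instruct-v0.2 advertises a 32768-token window), so swapping models means updating them together. A quick sanity check one could run against the loaded settings (illustrative, not part of the commit):

    from private_gpt.settings.settings import settings

    llm = settings().llm
    # The completion budget must fit inside the window, with room left for the prompt.
    assert llm.max_new_tokens < llm.context_window
    print(llm.tokenizer, llm.context_window, llm.max_new_tokens)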