diff --git a/README.md b/README.md
index 5d9bcd3..d67c579 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ MODEL_TYPE: supports LlamaCpp or GPT4All
 PERSIST_DIRECTORY: is the folder you want your vectorstore in
 MODEL_PATH: Path to your GPT4All or LlamaCpp supported LLM
 MODEL_N_CTX: Maximum token limit for the LLM model
+MODEL_N_BATCH: Number of tokens in the prompt that are fed into the model at a time. Optimal value differs a lot depending on the model (8 works well for GPT4All, and 1024 is better for LlamaCpp)
 EMBEDDINGS_MODEL_NAME: SentenceTransformers embeddings model name (see https://www.sbert.net/docs/pretrained_models.html)
 TARGET_SOURCE_CHUNKS: The amount of chunks (sources) that will be used to answer a question
 ```
diff --git a/example.env b/example.env
index 0afe836..143db76 100644
--- a/example.env
+++ b/example.env
@@ -3,4 +3,5 @@ MODEL_TYPE=GPT4All
 MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
 EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
 MODEL_N_CTX=1000
+MODEL_N_BATCH=8
 TARGET_SOURCE_CHUNKS=4
diff --git a/privateGPT.py b/privateGPT.py
index bd03d07..a47a745 100755
--- a/privateGPT.py
+++ b/privateGPT.py
@@ -17,6 +17,7 @@ persist_directory = os.environ.get('PERSIST_DIRECTORY')
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
 model_n_ctx = os.environ.get('MODEL_N_CTX')
+model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
 target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
 
 from constants import CHROMA_SETTINGS
@@ -32,9 +33,9 @@ def main():
     # Prepare the LLM
     match model_type:
         case "LlamaCpp":
-            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
+            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case "GPT4All":
-            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
+            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case _default:
             print(f"Model {model_type} not supported!")
             exit;
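
For context, here is a minimal sketch (not part of the diff) of how the new MODEL_N_BATCH setting flows from the environment into the LLM constructor. It assumes the same langchain GPT4All/LlamaCpp wrappers that privateGPT.py already imports, a .env file based on example.env loaded via python-dotenv, and omits the callback/chain setup for brevity.

```python
# Minimal sketch, assuming the langchain GPT4All/LlamaCpp wrappers used by
# privateGPT.py and a .env file based on example.env. Not part of the diff.
import os
from dotenv import load_dotenv
from langchain.llms import GPT4All, LlamaCpp

load_dotenv()

model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))
# New setting: number of prompt tokens fed to the model per batch.
# 8 is a reasonable default for GPT4All; llama.cpp models usually benefit
# from a much larger value such as 1024.
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))

if model_type == "LlamaCpp":
    llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx,
                   n_batch=model_n_batch, verbose=False)
elif model_type == "GPT4All":
    llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj',
                  n_batch=model_n_batch, verbose=False)
else:
    raise ValueError(f"Model {model_type} not supported!")
```

Because the value is read with a default of 8, existing setups that do not define MODEL_N_BATCH in their .env keep working unchanged; only users who want faster prompt ingestion (typically with LlamaCpp) need to set it explicitly.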