Support n_batch to improve inference performance
parent 52eb020256
commit ad661933cb
@@ -26,6 +26,7 @@ MODEL_TYPE: supports LlamaCpp or GPT4All
 PERSIST_DIRECTORY: is the folder you want your vectorstore in
 MODEL_PATH: Path to your GPT4All or LlamaCpp supported LLM
 MODEL_N_CTX: Maximum token limit for the LLM model
+MODEL_N_BATCH: Number of tokens in the prompt that are fed into the model at a time. Optimal value differs a lot depending on the model (8 works well for GPT4All, and 1024 is better for LlamaCpp)
 EMBEDDINGS_MODEL_NAME: SentenceTransformers embeddings model name (see https://www.sbert.net/docs/pretrained_models.html)
 TARGET_SOURCE_CHUNKS: The amount of chunks (sources) that will be used to answer a question
 ```
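The MODEL_N_BATCH description above is the user-facing knob this commit introduces. Purely as a hedged sketch, not part of the commit, a LlamaCpp-oriented .env following that advice could look like the snippet below; the model filename is hypothetical and should point at whatever model you actually use:

```
# sketch only: hypothetical LlamaCpp configuration
MODEL_TYPE=LlamaCpp
# hypothetical filename; substitute your own LlamaCpp-compatible model
MODEL_PATH=models/your-llamacpp-model.bin
MODEL_N_CTX=1000
# per the README note: larger batches tend to suit LlamaCpp better than the GPT4All default of 8
MODEL_N_BATCH=1024
EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
TARGET_SOURCE_CHUNKS=4
```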
@@ -3,4 +3,5 @@ MODEL_TYPE=GPT4All
 MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
 EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
 MODEL_N_CTX=1000
+MODEL_N_BATCH=8
 TARGET_SOURCE_CHUNKS=4
@@ -17,6 +17,7 @@ persist_directory = os.environ.get('PERSIST_DIRECTORY')
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
 model_n_ctx = os.environ.get('MODEL_N_CTX')
+model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
 target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))

 from constants import CHROMA_SETTINGS
@@ -32,9 +33,9 @@ def main():
     # Prepare the LLM
     match model_type:
         case "LlamaCpp":
-            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
+            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case "GPT4All":
-            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
+            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case _default:
             print(f"Model {model_type} not supported!")
             exit;
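For orientation, the sketch below pulls the pieces of this diff together end to end. It is not the committed file and takes small liberties: it casts MODEL_N_CTX to int, uses `case _` with SystemExit instead of the bare `exit`, and assumes the LangChain LlamaCpp and GPT4All wrappers that the diff calls both accept an `n_batch` keyword, as the change relies on. The fallback of 8 matches both the committed default and the GPT4All recommendation in the README.

```
# Minimal sketch (not the committed code): how MODEL_N_BATCH reaches the LLM wrappers.
import os

from dotenv import load_dotenv
from langchain.llms import GPT4All, LlamaCpp

load_dotenv()

model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))
# Falls back to 8 (the GPT4All-friendly value) when MODEL_N_BATCH is unset.
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))

callbacks = []  # streaming/stdout callbacks would be plugged in here

match model_type:
    case "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx,
                       n_batch=model_n_batch, callbacks=callbacks, verbose=False)
    case "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj',
                      n_batch=model_n_batch, callbacks=callbacks, verbose=False)
    case _:
        raise SystemExit(f"Model {model_type} not supported!")
```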