diff --git a/chroma_preference.py b/chroma_preference.py
new file mode 100644
index 0000000..d5e4383
--- /dev/null
+++ b/chroma_preference.py
@@ -0,0 +1,11 @@
+from chromadb.config import Settings
+
+# Define the folder for storing database
+PERSIST_DIRECTORY = 'db'
+
+# Define the Chroma settings
+CHROMA_SETTINGS = Settings(
+    chroma_db_impl='duckdb+parquet',
+    persist_directory=PERSIST_DIRECTORY,
+    anonymized_telemetry=False
+)
\ No newline at end of file
diff --git a/ingest.py b/ingest.py
index e8b08e6..ee900c1 100644
--- a/ingest.py
+++ b/ingest.py
@@ -3,6 +3,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.embeddings import LlamaCppEmbeddings
 from sys import argv
+from chroma_preference import PERSIST_DIRECTORY
+from chroma_preference import CHROMA_SETTINGS
 
 def main():
     # Load document and split in chunks
@@ -13,8 +15,7 @@ def main():
     # Create embeddings
     llama = LlamaCppEmbeddings(model_path="./models/ggml-model-q4_0.bin")
     # Create and store locally vectorstore
-    persist_directory = 'db'
-    db = Chroma.from_documents(texts, llama, persist_directory=persist_directory)
+    db = Chroma.from_documents(texts, llama, persist_directory=PERSIST_DIRECTORY, client_settings=CHROMA_SETTINGS)
     db.persist()
     db = None
 
diff --git a/privateGPT.py b/privateGPT.py
index 817a5e3..7ebaec1 100644
--- a/privateGPT.py
+++ b/privateGPT.py
@@ -3,12 +3,13 @@ from langchain.embeddings import LlamaCppEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
 from langchain.llms import GPT4All
+from chroma_preference import PERSIST_DIRECTORY
+from chroma_preference import CHROMA_SETTINGS
 
 def main():
     # Load stored vectorstore
     llama = LlamaCppEmbeddings(model_path="./models/ggml-model-q4_0.bin")
-    persist_directory = 'db'
-    db = Chroma(persist_directory=persist_directory, embedding_function=llama)
+    db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=llama, client_settings=CHROMA_SETTINGS)
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]