Update code to use sentence-transformers through HuggingFaceEmbeddings
parent 8a5b2f453b
commit 23d24c88e9
@@ -1,5 +1,5 @@
 PERSIST_DIRECTORY=db
-LLAMA_EMBEDDINGS_MODEL=models/ggml-model-q4_0.bin
 MODEL_TYPE=GPT4All
 MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
+EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
 MODEL_N_CTX=1000
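The new EMBEDDINGS_MODEL_NAME variable names a sentence-transformers model, downloaded from the Hugging Face Hub on first use, instead of a local GGML file, so MODEL_N_CTX no longer applies to embeddings. A minimal sanity-check sketch, not part of the commit, assuming python-dotenv, langchain, and sentence-transformers are installed and a .env like the one above exists:

# Sanity check for the new embeddings setting (not part of the commit).
# Assumes a .env file like the one above; the fallback model name is the
# commit's default.
import os
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings

load_dotenv()
model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
embeddings = HuggingFaceEmbeddings(model_name=model_name)
vector = embeddings.embed_query("hello world")
print(len(vector))  # all-MiniLM-L6-v2 produces 384-dimensional vectors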
ingest.py (15 changed lines)
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 from langchain.document_loaders import TextLoader, PDFMinerLoader, CSVLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
-from langchain.embeddings import LlamaCppEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.docstore.document import Document
 from constants import CHROMA_SETTINGS
 
@@ -38,22 +38,23 @@ def main():
     # Load environment variables
     persist_directory = os.environ.get('PERSIST_DIRECTORY')
     source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
-    llama_embeddings_model = os.environ.get('LLAMA_EMBEDDINGS_MODEL')
-    model_n_ctx = os.environ.get('MODEL_N_CTX')
+    embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
 
     # Load documents and split in chunks
     print(f"Loading documents from {source_directory}")
+    chunk_size = 500
+    chunk_overlap = 50
     documents = load_documents(source_directory)
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     texts = text_splitter.split_documents(documents)
     print(f"Loaded {len(documents)} documents from {source_directory}")
-    print(f"Split into {len(texts)} chunks of text (max. 500 tokens each)")
+    print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")
 
     # Create embeddings
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
+    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
 
     # Create and store locally vectorstore
-    db = Chroma.from_documents(texts, llama, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
+    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
     db.persist()
     db = None
 
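The reworded log message in this hunk fixes a real inaccuracy: RecursiveCharacterTextSplitter measures chunk_size with its default length function, len, so the limit is in characters, not tokens. A small illustration using the commit's parameters (the sample text is made up):

# Illustration of the chunking parameters used in ingest.py above.
# chunk_size counts characters (default length_function=len), not tokens.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
sample = "Local document ingestion for retrieval. " * 50  # ~2000 characters
chunks = splitter.split_text(sample)
print(f"{len(chunks)} chunks, longest is {max(len(c) for c in chunks)} characters")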
@@ -1,6 +1,6 @@
 from dotenv import load_dotenv
 from langchain.chains import RetrievalQA
-from langchain.embeddings import LlamaCppEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
 from langchain.llms import GPT4All, LlamaCpp
@@ -8,7 +8,7 @@ import os
 
 load_dotenv()
 
-llama_embeddings_model = os.environ.get("LLAMA_EMBEDDINGS_MODEL")
+embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
 persist_directory = os.environ.get('PERSIST_DIRECTORY')
 
 model_type = os.environ.get('MODEL_TYPE')
@@ -18,8 +18,8 @@ model_n_ctx = os.environ.get('MODEL_N_CTX')
 from constants import CHROMA_SETTINGS
 
 def main():
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
-    db = Chroma(persist_directory=persist_directory, embedding_function=llama, client_settings=CHROMA_SETTINGS)
+    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
+    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]
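The query side must build HuggingFaceEmbeddings with the same model name used at ingest time; vectors from a different model would not be comparable to those stored in the index. A hedged sketch of querying the persisted store directly (CHROMA_SETTINGS from the repo's constants.py is omitted to keep the snippet self-contained; the query string is illustrative):

# Query the store built by ingest.py, using the same embedding model.
# CHROMA_SETTINGS is omitted for self-containment; pass it as
# client_settings= if the repo's exact Chroma configuration is needed.
import os
from dotenv import load_dotenv
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

load_dotenv()
embeddings = HuggingFaceEmbeddings(model_name=os.environ.get("EMBEDDINGS_MODEL_NAME"))
db = Chroma(persist_directory=os.environ.get("PERSIST_DIRECTORY"),
            embedding_function=embeddings)
for doc in db.similarity_search("example question about the corpus", k=4):
    print(doc.metadata.get("source"), doc.page_content[:80])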