Ingestion Speedup Multiple strategy (#1309)

parent 546ba33e6f
commit bafdd3baf1
@@ -1,3 +1,5 @@
+import logging
+
 from injector import inject, singleton
 from llama_index import MockEmbedding
 from llama_index.embeddings.base import BaseEmbedding
@@ -5,6 +7,8 @@ from llama_index.embeddings.base import BaseEmbedding
 from private_gpt.paths import models_cache_path
 from private_gpt.settings.settings import Settings

+logger = logging.getLogger(__name__)
+

 @singleton
 class EmbeddingComponent:
@@ -12,7 +16,9 @@ class EmbeddingComponent:

     @inject
     def __init__(self, settings: Settings) -> None:
-        match settings.llm.mode:
+        embedding_mode = settings.embedding.mode
+        logger.info("Initializing the embedding model in mode=%s", embedding_mode)
+        match embedding_mode:
             case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding


@@ -0,0 +1,297 @@
import abc
import itertools
import logging
import multiprocessing
import multiprocessing.pool
import os
import threading
from pathlib import Path
from typing import Any

from llama_index import (
    Document,
    ServiceContext,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.data_structs import IndexDict
from llama_index.indices.base import BaseIndex
from llama_index.ingestion import run_transformations

from private_gpt.components.ingest.ingest_helper import IngestionHelper
from private_gpt.paths import local_data_path

logger = logging.getLogger(__name__)


class BaseIngestComponent(abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        logger.debug("Initializing base ingest component type=%s", type(self).__name__)
        self.storage_context = storage_context
        self.service_context = service_context

    @abc.abstractmethod
    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        pass

    @abc.abstractmethod
    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        pass

    @abc.abstractmethod
    def delete(self, doc_id: str) -> None:
        pass


class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)

        self.show_progress = True
        self._index_thread_lock = (
            threading.RLock()
        )  # Thread lock! Not Multiprocessing lock
        self._index = self._initialize_index()

    def _initialize_index(self) -> BaseIndex[IndexDict]:
        """Initialize the index from the storage context."""
        try:
            # Load the index with store_nodes_override=True to be able to delete them
            index = load_index_from_storage(
                storage_context=self.storage_context,
                service_context=self.service_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
            )
        except ValueError:
            # There is no index in the storage context, so create a new one
            logger.info("Creating a new vector store index")
            index = VectorStoreIndex.from_documents(
                [],
                storage_context=self.storage_context,
                service_context=self.service_context,
                store_nodes_override=True,  # Force store nodes in index and document stores
                show_progress=self.show_progress,
            )
            index.storage_context.persist(persist_dir=local_data_path)
        return index

    def _save_index(self) -> None:
        self._index.storage_context.persist(persist_dir=local_data_path)

    def delete(self, doc_id: str) -> None:
        with self._index_thread_lock:
            # Delete the document from the index
            self._index.delete_ref_doc(doc_id, delete_from_docstore=True)

            # Save the index
            self._save_index()


class SimpleIngestComponent(BaseIngestComponentWithIndex):
    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        saved_documents = []
        for file_name, file_data in files:
            documents = IngestionHelper.transform_file_into_documents(
                file_name, file_data
            )
            saved_documents.extend(self._save_docs(documents))
        return saved_documents

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        with self._index_thread_lock:
            for document in documents:
                self._index.insert(document, show_progress=True)
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents


class MultiWorkerIngestComponent(BaseIngestComponentWithIndex):
    """Parallelize the file reading and parsing on multiple CPU cores.

    This also allows the embeddings to be computed in batches (on GPU or CPU).
    """

    BULK_INGEST_WORKER_NUM = max((os.cpu_count() or 1) - 1, 1)

    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)
        # To make efficient use of the CPU and GPU, the embedding model
        # must be part of the transformations
        assert (
            len(self.service_context.transformations) >= 2
        ), "Embeddings must be in the transformations"

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        with multiprocessing.Pool(processes=self.BULK_INGEST_WORKER_NUM) as pool:
            documents = list(
                itertools.chain.from_iterable(
                    pool.starmap(IngestionHelper.transform_file_into_documents, files)
                )
            )
        logger.info(
            "Transformed count=%s files into count=%s documents",
            len(files),
            len(documents),
        )
        return self._save_docs(documents)

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        nodes = run_transformations(
            documents,  # type: ignore[arg-type]
            self.service_context.transformations,
            show_progress=self.show_progress,
        )
        # Locking the index to avoid concurrent writes
        with self._index_thread_lock:
            logger.debug("Inserting count=%s nodes in the index", len(nodes))
            self._index.insert_nodes(nodes, show_progress=True)
            for document in documents:
                self._index.docstore.set_document_hash(
                    document.get_doc_id(), document.hash
                )
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents


class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
    """Parallelize the file ingestion (file reading, embeddings, and index insertion).

    This uses the CPU and GPU in parallel (both running at the same time), and
    reduces the memory pressure by not loading all the files in memory at the same time.

    FIXME: this is not working as well as planned because of the usage of
     the multiprocessing worker pool.
    """

    BULK_INGEST_WORKER_NUM = max((os.cpu_count() or 1) - 1, 1)

    def __init__(
        self,
        storage_context: StorageContext,
        service_context: ServiceContext,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(storage_context, service_context, *args, **kwargs)
        # To make efficient use of the CPU and GPU, the embedding model
        # must be part of the transformations
        assert (
            len(self.service_context.transformations) >= 2
        ), "Embeddings must be in the transformations"
        # We are doing our own multiprocessing, so we disable Hugging Face's
        # tokenizers parallelism to avoid colliding with it
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
        logger.info("Ingesting file_name=%s", file_name)
        # FIXME there are some cases where the process is not finished
        #  causing deadlocks. More information using trace:
        #  time PGPT_PROFILES=ingest-local python -m trace --trace \
        #    ./scripts/ingest_folder.py ... &> ingestion.traces
        with multiprocessing.Pool(processes=1) as pool:
            # Running in a single (1) process to release the current
            # thread, and take a dedicated CPU core for computation
            a_documents = pool.apply_async(
                IngestionHelper.transform_file_into_documents, (file_name, file_data)
            )
            while True:
                # FIXME ugly hack to highlight the deadlock in traces
                try:
                    documents = list(a_documents.get(timeout=2))
                except multiprocessing.TimeoutError:
                    continue
                break
            pool.close()
            pool.terminate()
        logger.info(
            "Transformed file=%s into count=%s documents", file_name, len(documents)
        )
        logger.debug("Saving the documents in the index and doc store")
        return self._save_docs(documents)

    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
        # Lightweight threads, used to parallelize the
        # underlying IO calls made in the ingestion
        with multiprocessing.pool.ThreadPool(
            processes=self.BULK_INGEST_WORKER_NUM
        ) as pool:
            documents = list(
                itertools.chain.from_iterable(pool.starmap(self.ingest, files))
            )
        return documents

    def _save_docs(self, documents: list[Document]) -> list[Document]:
        logger.debug("Transforming count=%s documents into nodes", len(documents))
        nodes = run_transformations(
            documents,  # type: ignore[arg-type]
            self.service_context.transformations,
            show_progress=self.show_progress,
        )
        # Locking the index to avoid concurrent writes
        with self._index_thread_lock:
            logger.debug("Inserting count=%s nodes in the index", len(nodes))
            self._index.insert_nodes(nodes, show_progress=True)
            for document in documents:
                self._index.docstore.set_document_hash(
                    document.get_doc_id(), document.hash
                )
            logger.debug("Persisting the index and nodes")
            # persist the index and nodes
            self._save_index()
            logger.debug("Persisted the index and nodes")
        return documents
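
A minimal usage sketch, not part of this commit's diff: the three strategies above all implement the same BaseIngestComponent interface, so a caller only chooses which class to instantiate. The helper function and its name below are illustrative; the storage and service contexts are assumed to be built the way IngestService does further down, with the embedding model included in the transformations (which MultiWorkerIngestComponent asserts).

from pathlib import Path

from llama_index import ServiceContext, StorageContext

from private_gpt.components.ingest.ingest_component import (
    BaseIngestComponent,
    MultiWorkerIngestComponent,
)


def bulk_ingest_folder(
    storage_context: StorageContext,
    service_context: ServiceContext,
    folder: Path,
    component_cls: type[BaseIngestComponent] = MultiWorkerIngestComponent,
) -> int:
    # Every strategy exposes the same ingest/bulk_ingest/delete methods,
    # so swapping strategies is a one-argument change.
    component = component_cls(storage_context, service_context)
    files = [(path.name, path) for path in sorted(folder.rglob("*")) if path.is_file()]
    return len(component.bulk_ingest(files))
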

@@ -0,0 +1,61 @@
import logging
from pathlib import Path

from llama_index import Document
from llama_index.readers import JSONReader, StringIterableReader
from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS

logger = logging.getLogger(__name__)

# Patching the default file reader to support other file types
FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
FILE_READER_CLS.update(
    {
        ".json": JSONReader,
    }
)


class IngestionHelper:
    """Helper class to transform a file into a list of documents.

    The methods of this class are static, and they are
    thread-safe (and multiprocessing-safe).
    """

    @staticmethod
    def transform_file_into_documents(
        file_name: str, file_data: Path
    ) -> list[Document]:
        documents = IngestionHelper._load_file_to_documents(file_name, file_data)
        for document in documents:
            document.metadata["file_name"] = file_name
        IngestionHelper._exclude_metadata(documents)
        return documents

    @staticmethod
    def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
        logger.debug("Transforming file_name=%s into documents", file_name)
        extension = Path(file_name).suffix
        reader_cls = FILE_READER_CLS.get(extension)
        if reader_cls is None:
            logger.debug(
                "No reader found for extension=%s, using default string reader",
                extension,
            )
            # Read as plain text
            string_reader = StringIterableReader()
            return string_reader.load_data([file_data.read_text()])

        logger.debug("Specific reader found for extension=%s", extension)
        return reader_cls().load_data(file_data)

    @staticmethod
    def _exclude_metadata(documents: list[Document]) -> None:
        logger.debug("Excluding metadata from count=%s documents", len(documents))
        for document in documents:
            document.metadata["doc_id"] = document.doc_id
            # We don't want the Embeddings search to receive this metadata
            document.excluded_embed_metadata_keys = ["doc_id"]
            # We don't want the LLM to receive this metadata in the context
            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
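
Not part of the diff: because IngestionHelper only exposes static methods, the transformation can be handed straight to a multiprocessing pool, which is exactly what MultiWorkerIngestComponent does above. A minimal sketch with hypothetical file names:

from multiprocessing import Pool
from pathlib import Path

from private_gpt.components.ingest.ingest_helper import IngestionHelper

if __name__ == "__main__":
    # Hypothetical files; any readable paths on disk would do.
    files = [("notes.txt", Path("notes.txt")), ("report.json", Path("report.json"))]
    with Pool(processes=2) as pool:
        # The static method pickles cleanly, so worker processes can run it.
        per_file_docs = pool.starmap(
            IngestionHelper.transform_file_into_documents, files
        )
    for (file_name, _), docs in zip(files, per_file_docs):
        print(file_name, len(docs), "documents")
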

@@ -1,3 +1,5 @@
+import logging
+
 from injector import inject, singleton
 from llama_index.llms import MockLLM
 from llama_index.llms.base import LLM
@@ -6,6 +8,8 @@ from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings

+logger = logging.getLogger(__name__)
+

 @singleton
 class LLMComponent:
@@ -13,6 +17,8 @@ class LLMComponent:

     @inject
     def __init__(self, settings: Settings) -> None:
+        llm_mode = settings.llm.mode
+        logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
             case "local":
                 from llama_index.llms import LlamaCPP

@@ -12,7 +12,7 @@ from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
-from private_gpt.server.ingest.ingest_service import IngestedDoc
+from private_gpt.server.ingest.model import IngestedDoc

 if TYPE_CHECKING:
     from llama_index.schema import RelatedNodeInfo

@@ -3,7 +3,8 @@ from typing import Literal
 from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel

-from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
+from private_gpt.server.ingest.ingest_service import IngestService
+from private_gpt.server.ingest.model import IngestedDoc
 from private_gpt.server.utils.auth import authenticated

 ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
@@ -35,7 +36,7 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest(file.filename, file.file.read())
+    ingested_documents = service.ingest_bin_data(file.filename, file.file)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)

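
A sketch, not part of the diff, of the two service entry points the router can now reach: paths go through ingest(), while uploaded binary streams go through ingest_bin_data(). The global_injector lookup and its import path are assumed from the ingest script further down; the file names are hypothetical.

from io import BytesIO
from pathlib import Path

from private_gpt.di import global_injector
from private_gpt.server.ingest.ingest_service import IngestService

service = global_injector.get(IngestService)

# A file already on disk: hand the Path straight to the service.
docs_from_path = service.ingest("report.txt", Path("report.txt"))

# An in-memory upload, as the endpoint above receives it: any BinaryIO works.
docs_from_stream = service.ingest_bin_data("notes.txt", BytesIO(b"hello world"))
print(len(docs_from_path), len(docs_from_stream))
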

@@ -1,64 +1,27 @@
 import logging
 import tempfile
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AnyStr, Literal
+from typing import BinaryIO

 from injector import inject, singleton
 from llama_index import (
-    Document,
     ServiceContext,
     StorageContext,
-    VectorStoreIndex,
-    load_index_from_storage,
 )
 from llama_index.node_parser import SentenceWindowNodeParser
-from llama_index.readers import JSONReader, StringIterableReader
-from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
-from pydantic import BaseModel, Field

 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
+from private_gpt.components.ingest.ingest_component import SimpleIngestComponent
 from private_gpt.components.llm.llm_component import LLMComponent
 from private_gpt.components.node_store.node_store_component import NodeStoreComponent
 from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
-from private_gpt.paths import local_data_path
-
-if TYPE_CHECKING:
-    from llama_index.readers.base import BaseReader
-
-# Patching the default file reader to support other file types
-FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
-FILE_READER_CLS.update(
-    {
-        ".json": JSONReader,
-    }
-)
+from private_gpt.server.ingest.model import IngestedDoc

 logger = logging.getLogger(__name__)


-class IngestedDoc(BaseModel):
-    object: Literal["ingest.document"]
-    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
-    doc_metadata: dict[str, Any] | None = Field(
-        examples=[
-            {
-                "page_label": "2",
-                "file_name": "Sales Report Q3 2023.pdf",
-            }
-        ]
-    )
-
-    @staticmethod
-    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
-        """Remove unwanted metadata keys."""
-        metadata.pop("doc_id", None)
-        metadata.pop("window", None)
-        metadata.pop("original_text", None)
-        return metadata
-
-
 @singleton
 class IngestService:
     @inject
@@ -75,40 +38,31 @@ class IngestService:
             docstore=node_store_component.doc_store,
             index_store=node_store_component.index_store,
         )
+        node_parser = SentenceWindowNodeParser.from_defaults()
         self.ingest_service_context = ServiceContext.from_defaults(
             llm=self.llm_service.llm,
             embed_model=embedding_component.embedding_model,
-            node_parser=SentenceWindowNodeParser.from_defaults(),
+            node_parser=node_parser,
+            # Embeddings done early in the pipeline of node transformations, right
+            # after the node parsing
+            transformations=[node_parser, embedding_component.embedding_model],
         )

-    def ingest(self, file_name: str, file_data: AnyStr | Path) -> list[IngestedDoc]:
-        logger.info("Ingesting file_name=%s", file_name)
-        extension = Path(file_name).suffix
-        reader_cls = FILE_READER_CLS.get(extension)
-        documents: list[Document]
-        if reader_cls is None:
-            logger.debug(
-                "No reader found for extension=%s, using default string reader",
-                extension,
-            )
-            # Read as a plain text
-            string_reader = StringIterableReader()
-            if isinstance(file_data, Path):
-                text = file_data.read_text()
-                documents = string_reader.load_data([text])
-            elif isinstance(file_data, bytes):
-                documents = string_reader.load_data([file_data.decode("utf-8")])
-            elif isinstance(file_data, str):
-                documents = string_reader.load_data([file_data])
-            else:
-                raise ValueError(f"Unsupported data type {type(file_data)}")
-        else:
-            logger.debug("Specific reader found for extension=%s", extension)
-            reader: BaseReader = reader_cls()
-            if isinstance(file_data, Path):
-                # Already a path, nothing to do
-                documents = reader.load_data(file_data)
-            else:
+        self.ingest_component = SimpleIngestComponent(
+            self.storage_context, self.ingest_service_context
+        )
+
+    def ingest(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
+        logger.info("Ingesting file_name=%s", file_name)
+        documents = self.ingest_component.ingest(file_name, file_data)
+        return [IngestedDoc.from_document(document) for document in documents]
+
+    def ingest_bin_data(
+        self, file_name: str, raw_file_data: BinaryIO
+    ) -> list[IngestedDoc]:
+        logger.debug("Ingesting binary data with file_name=%s", file_name)
+        file_data = raw_file_data.read()
+        logger.debug("Got file data of size=%s to ingest", len(file_data))
         # llama-index mainly supports reading from files, so
         # we have to create a tmp file to read for it to work
         # delete=False to avoid a Windows 11 permission error.
@@ -119,55 +73,15 @@ class IngestService:
                 path_to_tmp.write_bytes(file_data)
             else:
                 path_to_tmp.write_text(str(file_data))
-                        documents = reader.load_data(path_to_tmp)
+            return self.ingest(file_name, path_to_tmp)
         finally:
             tmp.close()
             path_to_tmp.unlink()
-        logger.info(
-            "Transformed file=%s into count=%s documents", file_name, len(documents)
-        )
-        for document in documents:
-            document.metadata["file_name"] = file_name
-        return self._save_docs(documents)

-    def _save_docs(self, documents: list[Document]) -> list[IngestedDoc]:
-        for document in documents:
-            document.metadata["doc_id"] = document.doc_id
-            # We don't want the Embeddings search to receive this metadata
-            document.excluded_embed_metadata_keys = ["doc_id"]
-            # We don't want the LLM to receive these metadata in the context
-            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
-
-        try:
-            # Load the index from storage and insert new documents,
-            index = load_index_from_storage(
-                storage_context=self.storage_context,
-                service_context=self.ingest_service_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=True,
-            )
-            for doc in documents:
-                index.insert(doc)
-        except ValueError:
-            # Or create a new one if there is none
-            VectorStoreIndex.from_documents(
-                documents,
-                storage_context=self.storage_context,
-                service_context=self.ingest_service_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=True,
-            )
-
-        # persist the index and nodes
-        self.storage_context.persist(persist_dir=local_data_path)
-        return [
-            IngestedDoc(
-                object="ingest.document",
-                doc_id=document.doc_id,
-                doc_metadata=IngestedDoc.curate_metadata(document.metadata),
-            )
-            for document in documents
-        ]
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
+        logger.info("Ingesting file_names=%s", [f[0] for f in files])
+        documents = self.ingest_component.bulk_ingest(files)
+        return [IngestedDoc.from_document(document) for document in documents]

     def list_ingested(self) -> list[IngestedDoc]:
         ingested_docs = []
@@ -205,17 +119,4 @@ class IngestService:
         logger.info(
             "Deleting the ingested document=%s in the doc and index store", doc_id
         )
-
-        # Load the index with store_nodes_override=True to be able to delete them
-        index = load_index_from_storage(
-            storage_context=self.storage_context,
-            service_context=self.ingest_service_context,
-            store_nodes_override=True,  # Force store nodes in index and document stores
-            show_progress=True,
-        )
-
-        # Delete the document from the index
-        index.delete_ref_doc(doc_id, delete_from_docstore=True)
-
-        # Save the index
-        self.storage_context.persist(persist_dir=local_data_path)
+        self.ingest_component.delete(doc_id)

@@ -0,0 +1,32 @@
from typing import Any, Literal

from llama_index import Document
from pydantic import BaseModel, Field


class IngestedDoc(BaseModel):
    object: Literal["ingest.document"]
    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
    doc_metadata: dict[str, Any] | None = Field(
        examples=[
            {
                "page_label": "2",
                "file_name": "Sales Report Q3 2023.pdf",
            }
        ]
    )

    @staticmethod
    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
        """Remove unwanted metadata keys."""
        for key in ["doc_id", "window", "original_text"]:
            metadata.pop(key, None)
        return metadata

    @staticmethod
    def from_document(document: Document) -> "IngestedDoc":
        return IngestedDoc(
            object="ingest.document",
            doc_id=document.doc_id,
            doc_metadata=IngestedDoc.curate_metadata(document.metadata),
        )
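
Not part of the diff: the new model module gives every caller a single conversion point from a llama_index Document to the API schema. A small sketch, with a hand-built document standing in for one produced by the ingest component:

from llama_index import Document

from private_gpt.server.ingest.model import IngestedDoc

doc = Document(
    text="Q3 sales grew by 12%.",
    metadata={"file_name": "Sales Report Q3 2023.pdf", "window": "..."},
)
ingested = IngestedDoc.from_document(doc)
# curate_metadata() drops "doc_id", "window" and "original_text" before exposing it.
print(ingested.doc_id, ingested.doc_metadata)
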

@@ -115,6 +115,10 @@ class LocalSettings(BaseModel):
     )


+class EmbeddingSettings(BaseModel):
+    mode: Literal["local", "openai", "sagemaker", "mock"]
+
+
 class SagemakerSettings(BaseModel):
     llm_endpoint_name: str
     embedding_endpoint_name: str
@@ -188,6 +192,7 @@ class Settings(BaseModel):
     data: DataSettings
     ui: UISettings
     llm: LLMSettings
+    embedding: EmbeddingSettings
     local: LocalSettings
     sagemaker: SagemakerSettings
     openai: OpenAISettings

@@ -157,10 +157,8 @@ class PrivateGptUi:

     def _upload_file(self, files: list[str]) -> None:
         logger.debug("Loading count=%s files", len(files))
-        for file in files:
-            logger.info("Loading file=%s", file)
-            path = Path(file)
-            self._ingest_service.ingest(file_name=path.name, file_data=path)
+        paths = [Path(file) for file in files]
+        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

     def _build_ui_blocks(self) -> gr.Blocks:
         logger.debug("Creating the UI blocks")

@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 import argparse
 import logging
 from pathlib import Path
@@ -8,7 +10,51 @@ from private_gpt.server.ingest.ingest_watcher import IngestWatcher

 logger = logging.getLogger(__name__)

-ingest_service = global_injector.get(IngestService)
+
+class LocalIngestWorker:
+    def __init__(self, ingest_service: IngestService) -> None:
+        self.ingest_service = ingest_service
+
+        self.total_documents = 0
+        self.current_document_count = 0
+
+        self._files_under_root_folder: list[Path] = list()
+
+    def _find_all_files_in_folder(self, root_path: Path) -> None:
+        """Search all files under the root folder recursively.
+        Count them at the same time.
+        """
+        for file_path in root_path.iterdir():
+            if file_path.is_file():
+                self.total_documents += 1
+                self._files_under_root_folder.append(file_path)
+            elif file_path.is_dir():
+                self._find_all_files_in_folder(file_path)
+
+    def ingest_folder(self, folder_path: Path) -> None:
+        # Count total documents before ingestion
+        self._find_all_files_in_folder(folder_path)
+        self._ingest_all(self._files_under_root_folder)
+
+    def _ingest_all(self, files_to_ingest: list[Path]) -> None:
+        logger.info("Ingesting files=%s", [f.name for f in files_to_ingest])
+        self.ingest_service.bulk_ingest([(str(p.name), p) for p in files_to_ingest])
+
+    def ingest_on_watch(self, changed_path: Path) -> None:
+        logger.info("Detected change at path=%s, ingesting", changed_path)
+        self._do_ingest_one(changed_path)
+
+    def _do_ingest_one(self, changed_path: Path) -> None:
+        try:
+            if changed_path.exists():
+                logger.info(f"Started ingesting file={changed_path}")
+                self.ingest_service.ingest(changed_path.name, changed_path)
+                logger.info(f"Completed ingesting file={changed_path}")
+        except Exception:
+            logger.exception(
+                f"Failed to ingest document: {changed_path}, find the exception attached"
+            )
+

 parser = argparse.ArgumentParser(prog="ingest_folder.py")
 parser.add_argument("folder", help="Folder to ingest")
@@ -37,53 +83,17 @@ if args.log_file:
     )
     logger.addHandler(file_handler)

+if __name__ == "__main__":

-total_documents = 0
-current_document_count = 0
-
-
-def count_documents(folder_path: Path) -> None:
-    global total_documents
-    for file_path in folder_path.iterdir():
-        if file_path.is_file():
-            total_documents += 1
-        elif file_path.is_dir():
-            count_documents(file_path)
-
-
-def _recursive_ingest_folder(folder_path: Path) -> None:
-    global current_document_count, total_documents
-    for file_path in folder_path.iterdir():
-        if file_path.is_file():
-            current_document_count += 1
-            progress_msg = f"Document {current_document_count} of {total_documents} ({(current_document_count / total_documents) * 100:.2f}%)"
-            logger.info(progress_msg)
-            _do_ingest(file_path)
-        elif file_path.is_dir():
-            _recursive_ingest_folder(file_path)
-
-
-def _do_ingest(changed_path: Path) -> None:
-    try:
-        if changed_path.exists():
-            logger.info(f"Started ingesting {changed_path}")
-            ingest_service.ingest(changed_path.name, changed_path)
-            logger.info(f"Completed ingesting {changed_path}")
-    except Exception:
-        logger.exception(
-            f"Failed to ingest document: {changed_path}, find the exception attached"
-        )
-
-
-path = Path(args.folder)
-if not path.exists():
+    root_path = Path(args.folder)
+    if not root_path.exists():
         raise ValueError(f"Path {args.folder} does not exist")

-# Count total documents before ingestion
-count_documents(path)
+    ingest_service = global_injector.get(IngestService)
+    worker = LocalIngestWorker(ingest_service)
+    worker.ingest_folder(root_path)

-_recursive_ingest_folder(path)
-if args.watch:
+    if args.watch:
         logger.info(f"Watching {args.folder} for changes, press Ctrl+C to stop...")
-    watcher = IngestWatcher(args.folder, _do_ingest)
+        watcher = IngestWatcher(args.folder, worker.ingest_on_watch)
         watcher.start()

@@ -22,6 +22,9 @@ ui:

 llm:
   mode: local
+embedding:
+  # Should match the value above in most cases
+  mode: local

 vectorstore:
   database: qdrant