Ingestion Speedup Multiple strategy (#1309)
commit bafdd3baf1
parent 546ba33e6f
@@ -1,3 +1,5 @@
+import logging
+
 from injector import inject, singleton
 from llama_index import MockEmbedding
 from llama_index.embeddings.base import BaseEmbedding
@@ -5,6 +7,8 @@ from llama_index.embeddings.base import BaseEmbedding
 from private_gpt.paths import models_cache_path
 from private_gpt.settings.settings import Settings
 
+logger = logging.getLogger(__name__)
+
 
 @singleton
 class EmbeddingComponent:
@@ -12,7 +16,9 @@ class EmbeddingComponent:
 
     @inject
     def __init__(self, settings: Settings) -> None:
-        match settings.llm.mode:
+        embedding_mode = settings.embedding.mode
+        logger.info("Initializing the embedding model in mode=%s", embedding_mode)
+        match embedding_mode:
             case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding
@@ -0,0 +1,297 @@
+import abc
+import itertools
+import logging
+import multiprocessing
+import multiprocessing.pool
+import os
+import threading
+from pathlib import Path
+from typing import Any
+
+from llama_index import (
+    Document,
+    ServiceContext,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+from llama_index.data_structs import IndexDict
+from llama_index.indices.base import BaseIndex
+from llama_index.ingestion import run_transformations
+
+from private_gpt.components.ingest.ingest_helper import IngestionHelper
+from private_gpt.paths import local_data_path
+
+logger = logging.getLogger(__name__)
+
+
+class BaseIngestComponent(abc.ABC):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        logger.debug("Initializing base ingest component type=%s", type(self).__name__)
+        self.storage_context = storage_context
+        self.service_context = service_context
+
+    @abc.abstractmethod
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        pass
+
+    @abc.abstractmethod
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        pass
+
+    @abc.abstractmethod
+    def delete(self, doc_id: str) -> None:
+        pass
+
+
+class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+
+        self.show_progress = True
+        self._index_thread_lock = (
+            threading.RLock()
+        )  # Thread lock! Not a multiprocessing lock
+        self._index = self._initialize_index()
+
+    def _initialize_index(self) -> BaseIndex[IndexDict]:
+        """Initialize the index from the storage context."""
+        try:
+            # Load the index with store_nodes_override=True to be able to delete them
+            index = load_index_from_storage(
+                storage_context=self.storage_context,
+                service_context=self.service_context,
+                store_nodes_override=True,  # Force store nodes in index and document stores
+                show_progress=self.show_progress,
+            )
+        except ValueError:
+            # There is no index in the storage context, create a new one
+            logger.info("Creating a new vector store index")
+            index = VectorStoreIndex.from_documents(
+                [],
+                storage_context=self.storage_context,
+                service_context=self.service_context,
+                store_nodes_override=True,  # Force store nodes in index and document stores
+                show_progress=self.show_progress,
+            )
+            index.storage_context.persist(persist_dir=local_data_path)
+        return index
+
+    def _save_index(self) -> None:
+        self._index.storage_context.persist(persist_dir=local_data_path)
+
+    def delete(self, doc_id: str) -> None:
+        with self._index_thread_lock:
+            # Delete the document from the index
+            self._index.delete_ref_doc(doc_id, delete_from_docstore=True)
+
+            # Save the index
+            self._save_index()
+
+
+class SimpleIngestComponent(BaseIngestComponentWithIndex):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        logger.info("Ingesting file_name=%s", file_name)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        logger.info(
+            "Transformed file=%s into count=%s documents", file_name, len(documents)
+        )
+        logger.debug("Saving the documents in the index and doc store")
+        return self._save_docs(documents)
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        saved_documents = []
+        for file_name, file_data in files:
+            documents = IngestionHelper.transform_file_into_documents(
+                file_name, file_data
+            )
+            saved_documents.extend(self._save_docs(documents))
+        return saved_documents
+
+    def _save_docs(self, documents: list[Document]) -> list[Document]:
+        logger.debug("Transforming count=%s documents into nodes", len(documents))
+        with self._index_thread_lock:
+            for document in documents:
+                self._index.insert(document, show_progress=True)
+            logger.debug("Persisting the index and nodes")
+            # persist the index and nodes
+            self._save_index()
+            logger.debug("Persisted the index and nodes")
+        return documents
+
+
+class MultiWorkerIngestComponent(BaseIngestComponentWithIndex):
+    """Parallelize the file reading and parsing on multiple CPU cores.
+
+    This also makes the embeddings be computed in batches (on GPU or CPU).
+    """
+
+    BULK_INGEST_WORKER_NUM = max((os.cpu_count() or 1) - 1, 1)
+
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+        # To make efficient use of the CPU and GPU, the embedding
+        # must be in the transformations
+        assert (
+            len(self.service_context.transformations) >= 2
+        ), "Embeddings must be in the transformations"
+
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        logger.info("Ingesting file_name=%s", file_name)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        logger.info(
+            "Transformed file=%s into count=%s documents", file_name, len(documents)
+        )
+        logger.debug("Saving the documents in the index and doc store")
+        return self._save_docs(documents)
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        with multiprocessing.Pool(processes=self.BULK_INGEST_WORKER_NUM) as pool:
+            documents = list(
+                itertools.chain.from_iterable(
+                    pool.starmap(IngestionHelper.transform_file_into_documents, files)
+                )
+            )
+        logger.info(
+            "Transformed count=%s files into count=%s documents",
+            len(files),
+            len(documents),
+        )
+        return self._save_docs(documents)
+
+    def _save_docs(self, documents: list[Document]) -> list[Document]:
+        logger.debug("Transforming count=%s documents into nodes", len(documents))
+        nodes = run_transformations(
+            documents,  # type: ignore[arg-type]
+            self.service_context.transformations,
+            show_progress=self.show_progress,
+        )
+        # Locking the index to avoid concurrent writes
+        with self._index_thread_lock:
+            logger.debug("Inserting count=%s nodes in the index", len(nodes))
+            self._index.insert_nodes(nodes, show_progress=True)
+            for document in documents:
+                self._index.docstore.set_document_hash(
+                    document.get_doc_id(), document.hash
+                )
+            logger.debug("Persisting the index and nodes")
+            # persist the index and nodes
+            self._save_index()
+            logger.debug("Persisted the index and nodes")
+        return documents
+
+
+class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
+    """Parallelize the file ingestion (file reading, embeddings, and index insertion).
+
+    This uses the CPU and GPU in parallel (both running at the same time), and
+    reduces the memory pressure by not loading all the files in memory at the same time.
+
+    FIXME: this is not working as well as planned because of the usage of
+    the multiprocessing worker pool.
+    """
+
+    BULK_INGEST_WORKER_NUM = max((os.cpu_count() or 1) - 1, 1)
+
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+        # To make efficient use of the CPU and GPU, the embedding
+        # must be in the transformations
+        assert (
+            len(self.service_context.transformations) >= 2
+        ), "Embeddings must be in the transformations"
+        # We are doing our own multiprocessing;
+        # to avoid colliding with huggingface's multiprocessing, we disable it
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        logger.info("Ingesting file_name=%s", file_name)
+        # FIXME there are some cases where the process is not finished,
+        # causing deadlocks. More information using trace:
+        # time PGPT_PROFILES=ingest-local python -m trace --trace \
+        #   ./scripts/ingest_folder.py ... &> ingestion.traces
+        with multiprocessing.Pool(processes=1) as pool:
+            # Running in a single (1) process to release the current
+            # thread, and take a dedicated CPU core for computation
+            a_documents = pool.apply_async(
+                IngestionHelper.transform_file_into_documents, (file_name, file_data)
+            )
+            while True:
+                # FIXME ugly hack to highlight the deadlock in traces
+                try:
+                    documents = list(a_documents.get(timeout=2))
+                except multiprocessing.TimeoutError:
+                    continue
+                break
+            pool.close()
+            pool.terminate()
+        logger.info(
+            "Transformed file=%s into count=%s documents", file_name, len(documents)
+        )
+        logger.debug("Saving the documents in the index and doc store")
+        return self._save_docs(documents)
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        # Lightweight threads, used to parallelize the
+        # underlying IO calls made in the ingestion
+        with multiprocessing.pool.ThreadPool(
+            processes=self.BULK_INGEST_WORKER_NUM
+        ) as pool:
+            documents = list(
+                itertools.chain.from_iterable(pool.starmap(self.ingest, files))
+            )
+        return documents
+
+    def _save_docs(self, documents: list[Document]) -> list[Document]:
+        logger.debug("Transforming count=%s documents into nodes", len(documents))
+        nodes = run_transformations(
+            documents,  # type: ignore[arg-type]
+            self.service_context.transformations,
+            show_progress=self.show_progress,
+        )
+        # Locking the index to avoid concurrent writes
+        with self._index_thread_lock:
+            logger.debug("Inserting count=%s nodes in the index", len(nodes))
+            self._index.insert_nodes(nodes, show_progress=True)
+            for document in documents:
+                self._index.docstore.set_document_hash(
+                    document.get_doc_id(), document.hash
+                )
+            logger.debug("Persisting the index and nodes")
+            # persist the index and nodes
+            self._save_index()
+            logger.debug("Persisted the index and nodes")
+        return documents
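As a usage sketch (not part of the commit): all three components share the same constructor and the same ingest/bulk_ingest/delete interface, so callers can swap strategies. The storage_context and service_context below are assumed to already exist, built the way IngestService builds them; the folder path is illustrative.

from pathlib import Path

from private_gpt.components.ingest.ingest_component import SimpleIngestComponent

# Assumption: storage_context and service_context are already constructed
# (in the application they come from IngestService / dependency injection).
ingest_component = SimpleIngestComponent(storage_context, service_context)

# bulk_ingest takes (file_name, path) tuples and returns the ingested Documents.
files = [(path.name, path) for path in Path("my_docs").iterdir() if path.is_file()]
documents = ingest_component.bulk_ingest(files)
print(f"Ingested {len(documents)} documents")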
@@ -0,0 +1,61 @@
+import logging
+from pathlib import Path
+
+from llama_index import Document
+from llama_index.readers import JSONReader, StringIterableReader
+from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
+
+logger = logging.getLogger(__name__)
+
+# Patching the default file reader to support other file types
+FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
+FILE_READER_CLS.update(
+    {
+        ".json": JSONReader,
+    }
+)
+
+
+class IngestionHelper:
+    """Helper class to transform a file into a list of documents.
+
+    This class should be used to transform a file into a list of documents.
+    These methods are thread-safe (and multiprocessing-safe).
+    """
+
+    @staticmethod
+    def transform_file_into_documents(
+        file_name: str, file_data: Path
+    ) -> list[Document]:
+        documents = IngestionHelper._load_file_to_documents(file_name, file_data)
+        for document in documents:
+            document.metadata["file_name"] = file_name
+        IngestionHelper._exclude_metadata(documents)
+        return documents
+
+    @staticmethod
+    def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
+        logger.debug("Transforming file_name=%s into documents", file_name)
+        extension = Path(file_name).suffix
+        reader_cls = FILE_READER_CLS.get(extension)
+        if reader_cls is None:
+            logger.debug(
+                "No reader found for extension=%s, using default string reader",
+                extension,
+            )
+            # Read as a plain text
+            string_reader = StringIterableReader()
+            return string_reader.load_data([file_data.read_text()])
+
+        logger.debug("Specific reader found for extension=%s", extension)
+        return reader_cls().load_data(file_data)
+
+    @staticmethod
+    def _exclude_metadata(documents: list[Document]) -> None:
+        logger.debug("Excluding metadata from count=%s documents", len(documents))
+        for document in documents:
+            document.metadata["doc_id"] = document.doc_id
+            # We don't want the Embeddings search to receive this metadata
+            document.excluded_embed_metadata_keys = ["doc_id"]
+            # We don't want the LLM to receive these metadata in the context
+            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
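A small illustrative sketch of the helper on its own (not part of the commit; notes.txt is a hypothetical file, and the fallback assumes .txt has no entry in FILE_READER_CLS):

from pathlib import Path

from private_gpt.components.ingest.ingest_helper import IngestionHelper

# A plain-text file with no dedicated reader falls back to StringIterableReader.
docs = IngestionHelper.transform_file_into_documents("notes.txt", Path("notes.txt"))
for doc in docs:
    # file_name was attached, and doc_id is excluded from embeddings and LLM context.
    print(doc.metadata["file_name"], doc.excluded_llm_metadata_keys)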
@@ -1,3 +1,5 @@
+import logging
+
 from injector import inject, singleton
 from llama_index.llms import MockLLM
 from llama_index.llms.base import LLM
@@ -6,6 +8,8 @@ from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings
 
+logger = logging.getLogger(__name__)
+
 
 @singleton
 class LLMComponent:
@@ -13,6 +17,8 @@ class LLMComponent:
 
     @inject
     def __init__(self, settings: Settings) -> None:
+        llm_mode = settings.llm.mode
+        logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
             case "local":
                 from llama_index.llms import LlamaCPP
@@ -12,7 +12,7 @@ from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
-from private_gpt.server.ingest.ingest_service import IngestedDoc
+from private_gpt.server.ingest.model import IngestedDoc
 
 if TYPE_CHECKING:
     from llama_index.schema import RelatedNodeInfo
@@ -3,7 +3,8 @@ from typing import Literal
 from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel
 
-from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
+from private_gpt.server.ingest.ingest_service import IngestService
+from private_gpt.server.ingest.model import IngestedDoc
 from private_gpt.server.utils.auth import authenticated
 
 ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
@@ -35,7 +36,7 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest(file.filename, file.file.read())
+    ingested_documents = service.ingest_bin_data(file.filename, file.file)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@@ -1,64 +1,27 @@
 import logging
 import tempfile
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AnyStr, Literal
+from typing import BinaryIO
 
 from injector import inject, singleton
 from llama_index import (
-    Document,
     ServiceContext,
     StorageContext,
-    VectorStoreIndex,
-    load_index_from_storage,
 )
 from llama_index.node_parser import SentenceWindowNodeParser
-from llama_index.readers import JSONReader, StringIterableReader
-from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
-from pydantic import BaseModel, Field
 
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
+from private_gpt.components.ingest.ingest_component import SimpleIngestComponent
 from private_gpt.components.llm.llm_component import LLMComponent
 from private_gpt.components.node_store.node_store_component import NodeStoreComponent
 from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
-from private_gpt.paths import local_data_path
+from private_gpt.server.ingest.model import IngestedDoc
 
-if TYPE_CHECKING:
-    from llama_index.readers.base import BaseReader
-
-
-# Patching the default file reader to support other file types
-FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
-FILE_READER_CLS.update(
-    {
-        ".json": JSONReader,
-    }
-)
-
-
 logger = logging.getLogger(__name__)
 
 
-class IngestedDoc(BaseModel):
-    object: Literal["ingest.document"]
-    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
-    doc_metadata: dict[str, Any] | None = Field(
-        examples=[
-            {
-                "page_label": "2",
-                "file_name": "Sales Report Q3 2023.pdf",
-            }
-        ]
-    )
-
-    @staticmethod
-    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
-        """Remove unwanted metadata keys."""
-        metadata.pop("doc_id", None)
-        metadata.pop("window", None)
-        metadata.pop("original_text", None)
-        return metadata
-
-
 @singleton
 class IngestService:
     @inject
@@ -75,99 +38,50 @@ class IngestService:
             docstore=node_store_component.doc_store,
             index_store=node_store_component.index_store,
         )
+        node_parser = SentenceWindowNodeParser.from_defaults()
         self.ingest_service_context = ServiceContext.from_defaults(
             llm=self.llm_service.llm,
             embed_model=embedding_component.embedding_model,
-            node_parser=SentenceWindowNodeParser.from_defaults(),
+            node_parser=node_parser,
+            # Embeddings done early in the pipeline of node transformations, right
+            # after the node parsing
+            transformations=[node_parser, embedding_component.embedding_model],
         )
+        self.ingest_component = SimpleIngestComponent(
+            self.storage_context, self.ingest_service_context
+        )
 
-    def ingest(self, file_name: str, file_data: AnyStr | Path) -> list[IngestedDoc]:
+    def ingest(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
         logger.info("Ingesting file_name=%s", file_name)
-        extension = Path(file_name).suffix
-        reader_cls = FILE_READER_CLS.get(extension)
-        documents: list[Document]
-        if reader_cls is None:
-            logger.debug(
-                "No reader found for extension=%s, using default string reader",
-                extension,
-            )
-            # Read as a plain text
-            string_reader = StringIterableReader()
-            if isinstance(file_data, Path):
-                text = file_data.read_text()
-                documents = string_reader.load_data([text])
-            elif isinstance(file_data, bytes):
-                documents = string_reader.load_data([file_data.decode("utf-8")])
-            elif isinstance(file_data, str):
-                documents = string_reader.load_data([file_data])
-            else:
-                raise ValueError(f"Unsupported data type {type(file_data)}")
-        else:
-            logger.debug("Specific reader found for extension=%s", extension)
-            reader: BaseReader = reader_cls()
-            if isinstance(file_data, Path):
-                # Already a path, nothing to do
-                documents = reader.load_data(file_data)
-            else:
-                # llama-index mainly supports reading from files, so
-                # we have to create a tmp file to read for it to work
-                # delete=False to avoid a Windows 11 permission error.
-                with tempfile.NamedTemporaryFile(delete=False) as tmp:
-                    try:
-                        path_to_tmp = Path(tmp.name)
-                        if isinstance(file_data, bytes):
-                            path_to_tmp.write_bytes(file_data)
-                        else:
-                            path_to_tmp.write_text(str(file_data))
-                        documents = reader.load_data(path_to_tmp)
-                    finally:
-                        tmp.close()
-                        path_to_tmp.unlink()
-        logger.info(
-            "Transformed file=%s into count=%s documents", file_name, len(documents)
-        )
-        for document in documents:
-            document.metadata["file_name"] = file_name
-        return self._save_docs(documents)
+        documents = self.ingest_component.ingest(file_name, file_data)
+        return [IngestedDoc.from_document(document) for document in documents]
 
-    def _save_docs(self, documents: list[Document]) -> list[IngestedDoc]:
-        for document in documents:
-            document.metadata["doc_id"] = document.doc_id
-            # We don't want the Embeddings search to receive this metadata
-            document.excluded_embed_metadata_keys = ["doc_id"]
-            # We don't want the LLM to receive these metadata in the context
-            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
-
-        try:
-            # Load the index from storage and insert new documents,
-            index = load_index_from_storage(
-                storage_context=self.storage_context,
-                service_context=self.ingest_service_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=True,
-            )
-            for doc in documents:
-                index.insert(doc)
-        except ValueError:
-            # Or create a new one if there is none
-            VectorStoreIndex.from_documents(
-                documents,
-                storage_context=self.storage_context,
-                service_context=self.ingest_service_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=True,
-            )
-
-        # persist the index and nodes
-        self.storage_context.persist(persist_dir=local_data_path)
-        return [
-            IngestedDoc(
-                object="ingest.document",
-                doc_id=document.doc_id,
-                doc_metadata=IngestedDoc.curate_metadata(document.metadata),
-            )
-            for document in documents
-        ]
+    def ingest_bin_data(
+        self, file_name: str, raw_file_data: BinaryIO
+    ) -> list[IngestedDoc]:
+        logger.debug("Ingesting binary data with file_name=%s", file_name)
+        file_data = raw_file_data.read()
+        logger.debug("Got file data of size=%s to ingest", len(file_data))
+        # llama-index mainly supports reading from files, so
+        # we have to create a tmp file to read for it to work
+        # delete=False to avoid a Windows 11 permission error.
+        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+            try:
+                path_to_tmp = Path(tmp.name)
+                if isinstance(file_data, bytes):
+                    path_to_tmp.write_bytes(file_data)
+                else:
+                    path_to_tmp.write_text(str(file_data))
+                return self.ingest(file_name, path_to_tmp)
+            finally:
+                tmp.close()
+                path_to_tmp.unlink()
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
+        logger.info("Ingesting file_names=%s", [f[0] for f in files])
+        documents = self.ingest_component.bulk_ingest(files)
+        return [IngestedDoc.from_document(document) for document in documents]
 
     def list_ingested(self) -> list[IngestedDoc]:
         ingested_docs = []
@@ -205,17 +119,4 @@ class IngestService:
         logger.info(
             "Deleting the ingested document=%s in the doc and index store", doc_id
         )
-
-        # Load the index with store_nodes_override=True to be able to delete them
-        index = load_index_from_storage(
-            storage_context=self.storage_context,
-            service_context=self.ingest_service_context,
-            store_nodes_override=True,  # Force store nodes in index and document stores
-            show_progress=True,
-        )
-
-        # Delete the document from the index
-        index.delete_ref_doc(doc_id, delete_from_docstore=True)
-
-        # Save the index
-        self.storage_context.persist(persist_dir=local_data_path)
+        self.ingest_component.delete(doc_id)
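For reference, a sketch of the reshaped service API (not part of the commit; it assumes a configured private-gpt installation, the injector import path is assumed, and the file names are illustrative):

from pathlib import Path

from private_gpt.di import global_injector  # assumed import path for the injector
from private_gpt.server.ingest.ingest_service import IngestService

ingest_service = global_injector.get(IngestService)

# Single file: ingest() now takes a Path and delegates to the ingest component.
docs = ingest_service.ingest("report.pdf", Path("report.pdf"))

# Many files at once: bulk_ingest() is what the UI and ingest_folder.py now call.
paths = [p for p in Path("my_docs").iterdir() if p.is_file()]
docs = ingest_service.bulk_ingest([(p.name, p) for p in paths])

# Raw uploads (e.g. FastAPI's UploadFile.file) go through ingest_bin_data().
with open("report.pdf", "rb") as f:
    docs = ingest_service.ingest_bin_data("report.pdf", f)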
@@ -0,0 +1,32 @@
+from typing import Any, Literal
+
+from llama_index import Document
+from pydantic import BaseModel, Field
+
+
+class IngestedDoc(BaseModel):
+    object: Literal["ingest.document"]
+    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
+    doc_metadata: dict[str, Any] | None = Field(
+        examples=[
+            {
+                "page_label": "2",
+                "file_name": "Sales Report Q3 2023.pdf",
+            }
+        ]
+    )
+
+    @staticmethod
+    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+        """Remove unwanted metadata keys."""
+        for key in ["doc_id", "window", "original_text"]:
+            metadata.pop(key, None)
+        return metadata
+
+    @staticmethod
+    def from_document(document: Document) -> "IngestedDoc":
+        return IngestedDoc(
+            object="ingest.document",
+            doc_id=document.doc_id,
+            doc_metadata=IngestedDoc.curate_metadata(document.metadata),
+        )
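A quick sketch of the new model in isolation (not part of the commit; the Document below is fabricated for illustration):

from llama_index import Document

from private_gpt.server.ingest.model import IngestedDoc

# Fabricated document, standing in for what IngestionHelper produces.
doc = Document(text="hello", metadata={"file_name": "hello.txt", "window": "hello"})

ingested = IngestedDoc.from_document(doc)
# curate_metadata() dropped "window"; "file_name" survives for the API response.
print(ingested.doc_id, ingested.doc_metadata)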
@@ -115,6 +115,10 @@ class LocalSettings(BaseModel):
     )
 
 
+class EmbeddingSettings(BaseModel):
+    mode: Literal["local", "openai", "sagemaker", "mock"]
+
+
 class SagemakerSettings(BaseModel):
     llm_endpoint_name: str
     embedding_endpoint_name: str
@@ -188,6 +192,7 @@ class Settings(BaseModel):
     data: DataSettings
     ui: UISettings
     llm: LLMSettings
+    embedding: EmbeddingSettings
     local: LocalSettings
     sagemaker: SagemakerSettings
     openai: OpenAISettings
@@ -157,10 +157,8 @@ class PrivateGptUi:
 
     def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
-        for file in files:
-            logger.info("Loading file=%s", file)
-            path = Path(file)
-            self._ingest_service.ingest(file_name=path.name, file_data=path)
+        paths = [Path(file) for file in files]
+        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
 
     def _build_ui_blocks(self) -> gr.Blocks:
         logger.debug("Creating the UI blocks")
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 import argparse
 import logging
 from pathlib import Path
@@ -8,7 +10,51 @@ from private_gpt.server.ingest.ingest_watcher import IngestWatcher
 
 logger = logging.getLogger(__name__)
 
-ingest_service = global_injector.get(IngestService)
+
+class LocalIngestWorker:
+    def __init__(self, ingest_service: IngestService) -> None:
+        self.ingest_service = ingest_service
+
+        self.total_documents = 0
+        self.current_document_count = 0
+
+        self._files_under_root_folder: list[Path] = list()
+
+    def _find_all_files_in_folder(self, root_path: Path) -> None:
+        """Search all files under the root folder recursively.
+        Count them at the same time.
+        """
+        for file_path in root_path.iterdir():
+            if file_path.is_file():
+                self.total_documents += 1
+                self._files_under_root_folder.append(file_path)
+            elif file_path.is_dir():
+                self._find_all_files_in_folder(file_path)
+
+    def ingest_folder(self, folder_path: Path) -> None:
+        # Count total documents before ingestion
+        self._find_all_files_in_folder(folder_path)
+        self._ingest_all(self._files_under_root_folder)
+
+    def _ingest_all(self, files_to_ingest: list[Path]) -> None:
+        logger.info("Ingesting files=%s", [f.name for f in files_to_ingest])
+        self.ingest_service.bulk_ingest([(str(p.name), p) for p in files_to_ingest])
+
+    def ingest_on_watch(self, changed_path: Path) -> None:
+        logger.info("Detected change at path=%s, ingesting", changed_path)
+        self._do_ingest_one(changed_path)
+
+    def _do_ingest_one(self, changed_path: Path) -> None:
+        try:
+            if changed_path.exists():
+                logger.info(f"Started ingesting file={changed_path}")
+                self.ingest_service.ingest(changed_path.name, changed_path)
+                logger.info(f"Completed ingesting file={changed_path}")
+        except Exception:
+            logger.exception(
+                f"Failed to ingest document: {changed_path}, find the exception attached"
+            )
+
+
 parser = argparse.ArgumentParser(prog="ingest_folder.py")
 parser.add_argument("folder", help="Folder to ingest")
@@ -37,53 +83,17 @@ if args.log_file:
     )
     logger.addHandler(file_handler)
 
-total_documents = 0
-current_document_count = 0
-
-
-def count_documents(folder_path: Path) -> None:
-    global total_documents
-    for file_path in folder_path.iterdir():
-        if file_path.is_file():
-            total_documents += 1
-        elif file_path.is_dir():
-            count_documents(file_path)
-
-
-def _recursive_ingest_folder(folder_path: Path) -> None:
-    global current_document_count, total_documents
-    for file_path in folder_path.iterdir():
-        if file_path.is_file():
-            current_document_count += 1
-            progress_msg = f"Document {current_document_count} of {total_documents} ({(current_document_count / total_documents) * 100:.2f}%)"
-            logger.info(progress_msg)
-            _do_ingest(file_path)
-        elif file_path.is_dir():
-            _recursive_ingest_folder(file_path)
-
-
-def _do_ingest(changed_path: Path) -> None:
-    try:
-        if changed_path.exists():
-            logger.info(f"Started ingesting {changed_path}")
-            ingest_service.ingest(changed_path.name, changed_path)
-            logger.info(f"Completed ingesting {changed_path}")
-    except Exception:
-        logger.exception(
-            f"Failed to ingest document: {changed_path}, find the exception attached"
-        )
-
-
-path = Path(args.folder)
-if not path.exists():
-    raise ValueError(f"Path {args.folder} does not exist")
-
-# Count total documents before ingestion
-count_documents(path)
-
-_recursive_ingest_folder(path)
-if args.watch:
-    logger.info(f"Watching {args.folder} for changes, press Ctrl+C to stop...")
-    watcher = IngestWatcher(args.folder, _do_ingest)
-    watcher.start()
+if __name__ == "__main__":
+    root_path = Path(args.folder)
+    if not root_path.exists():
+        raise ValueError(f"Path {args.folder} does not exist")
+
+    ingest_service = global_injector.get(IngestService)
+    worker = LocalIngestWorker(ingest_service)
+    worker.ingest_folder(root_path)
+
+    if args.watch:
+        logger.info(f"Watching {args.folder} for changes, press Ctrl+C to stop...")
+        watcher = IngestWatcher(args.folder, worker.ingest_on_watch)
+        watcher.start()
@@ -22,6 +22,9 @@ ui:
 
 llm:
   mode: local
+embedding:
+  # Should be matching the value above in most cases
+  mode: local
 
 vectorstore:
   database: qdrant