fix: chromadb max batch size (#1087)
This commit is contained in:
parent b46c1087e2
commit f5a9bf4e37

				|  | @ -261,24 +261,6 @@ files = [ | |||
| tests = ["pytest (>=3.2.1,!=3.3.0)"] | ||||
| typecheck = ["mypy"] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "beautifulsoup4" | ||||
| version = "4.12.2" | ||||
| description = "Screen-scraping library" | ||||
| optional = false | ||||
| python-versions = ">=3.6.0" | ||||
| files = [ | ||||
|     {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, | ||||
|     {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, | ||||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
| soupsieve = ">1.2" | ||||
| 
 | ||||
| [package.extras] | ||||
| html5lib = ["html5lib"] | ||||
| lxml = ["lxml"] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "black" | ||||
| version = "22.12.0" | ||||
|  | @ -643,10 +625,7 @@ files = [ | |||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
| numpy = [ | ||||
|     {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}, | ||||
|     {version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""}, | ||||
| ] | ||||
| numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} | ||||
| 
 | ||||
| [package.extras] | ||||
| bokeh = ["bokeh", "selenium"] | ||||
|  | @ -736,13 +715,13 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] | |||
| 
 | ||||
| [[package]] | ||||
| name = "dataclasses-json" | ||||
| version = "0.6.1" | ||||
| version = "0.5.14" | ||||
| description = "Easily serialize dataclasses to and from JSON." | ||||
| optional = false | ||||
| python-versions = ">=3.7,<4.0" | ||||
| python-versions = ">=3.7,<3.13" | ||||
| files = [ | ||||
|     {file = "dataclasses_json-0.6.1-py3-none-any.whl", hash = "sha256:1bd8418a61fe3d588bb0079214d7fb71d44937da40742b787256fd53b26b6c80"}, | ||||
|     {file = "dataclasses_json-0.6.1.tar.gz", hash = "sha256:a53c220c35134ce08211a1057fd0e5bf76dc5331627c6b241cacbc570a89faae"}, | ||||
|     {file = "dataclasses_json-0.5.14-py3-none-any.whl", hash = "sha256:5ec6fed642adb1dbdb4182badb01e0861badfd8fda82e3b67f44b2d1e9d10d21"}, | ||||
|     {file = "dataclasses_json-0.5.14.tar.gz", hash = "sha256:d82896a94c992ffaf689cd1fafc180164e2abdd415b8f94a7f78586af5886236"}, | ||||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
|  | @ -1130,7 +1109,7 @@ files = [ | |||
|     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, | ||||
|     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, | ||||
|     {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, | ||||
|     {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, | ||||
|  | @ -1140,7 +1119,6 @@ files = [ | |||
|     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, | ||||
|     {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, | ||||
|     {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"}, | ||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, | ||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, | ||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, | ||||
|  | @ -1760,28 +1738,27 @@ test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"] | |||
| 
 | ||||
| [[package]] | ||||
| name = "llama-index" | ||||
| version = "0.8.35" | ||||
| version = "0.8.47" | ||||
| description = "Interface between LLMs and your data" | ||||
| optional = false | ||||
| python-versions = "*" | ||||
| python-versions = ">=3.8.1,<3.12" | ||||
| files = [ | ||||
|     {file = "llama_index-0.8.35-py3-none-any.whl", hash = "sha256:f2f1670320e75a9643b6dc96662038f777866ed543994d18f71ab54329e295ae"}, | ||||
|     {file = "llama_index-0.8.35.tar.gz", hash = "sha256:a8767be9d36ebd538a37e18b0c7f46bb19d9d7ec490ef7582640f75d5d9b5259"}, | ||||
|     {file = "llama_index-0.8.47-py3-none-any.whl", hash = "sha256:7a0e5154637524fb59b30bd3a349fba2ec6092cf2972276da9dfa38bbe82d721"}, | ||||
|     {file = "llama_index-0.8.47.tar.gz", hash = "sha256:f824e7bcf9b6cf3fb98de59d722695a8db327c83b6b7d30071d931b56c14904f"}, | ||||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
| beautifulsoup4 = "*" | ||||
| dataclasses-json = "*" | ||||
| dataclasses-json = ">=0.5.7,<0.6.0" | ||||
| fsspec = ">=2023.5.0" | ||||
| langchain = ">=0.0.293" | ||||
| nest-asyncio = "*" | ||||
| nltk = "*" | ||||
| langchain = ">=0.0.303" | ||||
| nest-asyncio = ">=1.5.8,<2.0.0" | ||||
| nltk = ">=3.8.1,<4.0.0" | ||||
| numpy = "*" | ||||
| openai = ">=0.26.4" | ||||
| pandas = "*" | ||||
| sqlalchemy = ">=2.0.15" | ||||
| SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} | ||||
| tenacity = ">=8.2.0,<9.0.0" | ||||
| tiktoken = "*" | ||||
| tiktoken = ">=0.3.3" | ||||
| typing-extensions = ">=4.5.0" | ||||
| typing-inspect = ">=0.8.0" | ||||
| urllib3 = "<2" | ||||
|  | @ -2397,10 +2374,7 @@ files = [ | |||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
| numpy = [ | ||||
|     {version = ">=1.23.2", markers = "python_version == \"3.11\""}, | ||||
|     {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, | ||||
| ] | ||||
| numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""} | ||||
| python-dateutil = ">=2.8.2" | ||||
| pytz = ">=2020.1" | ||||
| tzdata = ">=2022.1" | ||||
|  | @ -3469,11 +3443,6 @@ files = [ | |||
|     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, | ||||
|     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, | ||||
|     {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, | ||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, | ||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, | ||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, | ||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, | ||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, | ||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, | ||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, | ||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, | ||||
|  | @ -3690,17 +3659,6 @@ files = [ | |||
|     {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "soupsieve" | ||||
| version = "2.5" | ||||
| description = "A modern CSS selector implementation for Beautiful Soup." | ||||
| optional = false | ||||
| python-versions = ">=3.8" | ||||
| files = [ | ||||
|     {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, | ||||
|     {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "sqlalchemy" | ||||
| version = "2.0.22" | ||||
|  | @ -3760,7 +3718,7 @@ files = [ | |||
| ] | ||||
| 
 | ||||
| [package.dependencies] | ||||
| greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} | ||||
| greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} | ||||
| typing-extensions = ">=4.2.0" | ||||
| 
 | ||||
| [package.extras] | ||||
|  | @ -4738,5 +4696,5 @@ multidict = ">=4.0" | |||
| 
 | ||||
| [metadata] | ||||
| lock-version = "2.0" | ||||
| python-versions = ">=3.11,<3.13" | ||||
| content-hash = "c1fa5accdcd9cd81430839398e16d2596b43b1c314b5d2a8a76aa05bbb83a39c" | ||||
| python-versions = ">=3.11,<3.12" | ||||
| content-hash = "56b78ce6a8a6dfbe42b490bcf4ffbf2820f10d5ce70a28c61ee4e357172dab33" | ||||
|  |  | |||
|  | @ -0,0 +1,87 @@ | |||
| from typing import Any | ||||
| 
 | ||||
| from llama_index.schema import BaseNode, MetadataMode | ||||
| from llama_index.vector_stores import ChromaVectorStore | ||||
| from llama_index.vector_stores.chroma import chunk_list | ||||
| from llama_index.vector_stores.utils import node_to_metadata_dict | ||||
| 
 | ||||
| 
 | ||||
| class BatchedChromaVectorStore(ChromaVectorStore): | ||||
|     """Chroma vector store, batching additions to avoid reaching the max batch limit. | ||||
| 
 | ||||
|     In this vector store, embeddings are stored within a ChromaDB collection. | ||||
| 
 | ||||
|     During query time, the index uses ChromaDB to query for the top | ||||
|     k most similar nodes. | ||||
| 
 | ||||
|     Args: | ||||
|         chroma_client (from chromadb.api.API): | ||||
|             API instance | ||||
|         chroma_collection (chromadb.api.models.Collection.Collection): | ||||
|             ChromaDB collection instance | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     chroma_client: Any | None | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         chroma_client: Any, | ||||
|         chroma_collection: Any, | ||||
|         host: str | None = None, | ||||
|         port: str | None = None, | ||||
|         ssl: bool = False, | ||||
|         headers: dict[str, str] | None = None, | ||||
|         collection_kwargs: dict[Any, Any] | None = None, | ||||
|     ) -> None: | ||||
|         super().__init__( | ||||
|             chroma_collection=chroma_collection, | ||||
|             host=host, | ||||
|             port=port, | ||||
|             ssl=ssl, | ||||
|             headers=headers, | ||||
|             collection_kwargs=collection_kwargs or {}, | ||||
|         ) | ||||
|         self.chroma_client = chroma_client | ||||
| 
 | ||||
|     def add(self, nodes: list[BaseNode]) -> list[str]: | ||||
|         """Add nodes to index, batching the insertion so each call stays under the max batch size. | ||||
| 
 | ||||
|         Args: | ||||
|             nodes: List[BaseNode]: list of nodes with embeddings | ||||
| 
 | ||||
|         """ | ||||
|         if not self.chroma_client: | ||||
|             raise ValueError("Client not initialized") | ||||
| 
 | ||||
|         if not self._collection: | ||||
|             raise ValueError("Collection not initialized") | ||||
| 
 | ||||
|         max_chunk_size = self.chroma_client.max_batch_size | ||||
|         node_chunks = chunk_list(nodes, max_chunk_size) | ||||
| 
 | ||||
|         all_ids = [] | ||||
|         for node_chunk in node_chunks: | ||||
|             embeddings = [] | ||||
|             metadatas = [] | ||||
|             ids = [] | ||||
|             documents = [] | ||||
|             for node in node_chunk: | ||||
|                 embeddings.append(node.get_embedding()) | ||||
|                 metadatas.append( | ||||
|                     node_to_metadata_dict( | ||||
|                         node, remove_text=True, flat_metadata=self.flat_metadata | ||||
|                     ) | ||||
|                 ) | ||||
|                 ids.append(node.node_id) | ||||
|                 documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) | ||||
| 
 | ||||
|             self._collection.add( | ||||
|                 embeddings=embeddings, | ||||
|                 ids=ids, | ||||
|                 metadatas=metadatas, | ||||
|                 documents=documents, | ||||
|             ) | ||||
|             all_ids.extend(ids) | ||||
| 
 | ||||
|         return all_ids | ||||
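
The batching above hinges on splitting the node list into chunks no larger than the client's max_batch_size before each collection.add call; the store imports chunk_list from llama_index.vector_stores.chroma for this. As a rough illustration of the idea only, the chunked helper below is a hypothetical stand-in, not the library's implementation:

from collections.abc import Iterator
from typing import TypeVar

T = TypeVar("T")


def chunked(items: list[T], max_chunk_size: int) -> Iterator[list[T]]:
    # Yield successive slices of at most max_chunk_size items each.
    for start in range(0, len(items), max_chunk_size):
        yield items[start : start + max_chunk_size]


# With a limit of 10, 25 items become batches of 10, 10 and 5.
assert [len(batch) for batch in chunked(list(range(25)), 10)] == [10, 10, 5]
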
|  | @ -4,9 +4,9 @@ import chromadb | |||
| from injector import inject, singleton | ||||
| from llama_index import VectorStoreIndex | ||||
| from llama_index.indices.vector_store import VectorIndexRetriever | ||||
| from llama_index.vector_stores import ChromaVectorStore | ||||
| from llama_index.vector_stores.types import VectorStore | ||||
| 
 | ||||
| from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore | ||||
| from private_gpt.open_ai.extensions.context_filter import ContextFilter | ||||
| from private_gpt.paths import local_data_path | ||||
| 
 | ||||
|  | @ -36,14 +36,16 @@ class VectorStoreComponent: | |||
| 
 | ||||
|     @inject | ||||
|     def __init__(self) -> None: | ||||
|         db = chromadb.PersistentClient( | ||||
|         chroma_client = chromadb.PersistentClient( | ||||
|             path=str((local_data_path / "chroma_db").absolute()) | ||||
|         ) | ||||
|         chroma_collection = db.get_or_create_collection( | ||||
|         chroma_collection = chroma_client.get_or_create_collection( | ||||
|             "make_this_parameterizable_per_api_call" | ||||
|         )  # TODO | ||||
| 
 | ||||
|         self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection) | ||||
|         self.vector_store = BatchedChromaVectorStore( | ||||
|             chroma_client=chroma_client, chroma_collection=chroma_collection | ||||
|         ) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def get_retriever( | ||||
|  |  | |||
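
Outside the injected component, the batched store can be wired up the same way. A minimal sketch, assuming a local persistent client; the path and collection name are placeholders, and wrapping the store in an index is shown only for illustration:

import chromadb
from llama_index import VectorStoreIndex

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore

chroma_client = chromadb.PersistentClient(path="local_data/chroma_db")  # placeholder path
chroma_collection = chroma_client.get_or_create_collection("example_collection")  # placeholder name

vector_store = BatchedChromaVectorStore(
    chroma_client=chroma_client, chroma_collection=chroma_collection
)
# Inserts larger than chroma_client.max_batch_size are now split into several
# collection.add calls by BatchedChromaVectorStore.add.
index = VectorStoreIndex.from_vector_store(vector_store)
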
|  | @ -5,7 +5,7 @@ description = "Private GPT" | |||
| authors = ["Zylon <hi@zylon.ai>"] | ||||
| 
 | ||||
| [tool.poetry.dependencies] | ||||
| python = ">=3.11,<3.13" | ||||
| python = ">=3.11,<3.12" | ||||
| fastapi = { extras = ["all"], version = "^0.103.1" } | ||||
| loguru = "^0.7.2" | ||||
| boto3 = "^1.28.56" | ||||
|  | @ -13,7 +13,7 @@ injector = "^0.21.0" | |||
| pyyaml = "^6.0.1" | ||||
| python-multipart = "^0.0.6" | ||||
| pypdf = "^3.16.2" | ||||
| llama-index = "v0.8.35" | ||||
| llama-index = "0.8.47" | ||||
| chromadb = "^0.4.13" | ||||
| watchdog = "^3.0.0" | ||||
| transformers = "^4.34.0" | ||||
|  |  | |||
|  | @ -0,0 +1,27 @@ | |||
| from unittest.mock import PropertyMock, patch | ||||
| 
 | ||||
| from llama_index import Document | ||||
| 
 | ||||
| from private_gpt.server.ingest.ingest_service import IngestService | ||||
| from tests.fixtures.mock_injector import MockInjector | ||||
| 
 | ||||
| 
 | ||||
| def test_save_many_nodes(injector: MockInjector) -> None: | ||||
|     """This is a specific test for a local Chromadb Vector Database setup. | ||||
| 
 | ||||
|     Extend it when we add support for other vector databases in VectorStoreComponent. | ||||
|     """ | ||||
|     with patch( | ||||
|         "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock | ||||
|     ) as max_batch_size: | ||||
|         # Make max batch size of Chromadb very small | ||||
|         max_batch_size.return_value = 10 | ||||
| 
 | ||||
|         ingest_service = injector.get(IngestService) | ||||
| 
 | ||||
|         documents = [] | ||||
|         for _i in range(100): | ||||
|             documents.append(Document(text="This is a sentence.")) | ||||
| 
 | ||||
|         ingested_docs = ingest_service._save_docs(documents) | ||||
|         assert len(ingested_docs) == len(documents) | ||||
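
The test forces a tiny batch limit by patching the client's read-only max_batch_size property with a PropertyMock. A standalone illustration of that mocking pattern, using a hypothetical FakeClient rather than the real chromadb API:

from unittest.mock import PropertyMock, patch


class FakeClient:
    @property
    def max_batch_size(self) -> int:
        # Stands in for a backend-defined limit; the value here is arbitrary.
        return 40_000


with patch.object(FakeClient, "max_batch_size", new_callable=PropertyMock) as mocked:
    mocked.return_value = 10
    assert FakeClient().max_batch_size == 10  # callers now see the small limit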