fix: chromadb max batch size (#1087)
This commit is contained in:
		
							parent
							
								
									b46c1087e2
								
							
						
					
					
						commit
						f5a9bf4e37
					
				|  | @ -261,24 +261,6 @@ files = [ | ||||||
| tests = ["pytest (>=3.2.1,!=3.3.0)"] | tests = ["pytest (>=3.2.1,!=3.3.0)"] | ||||||
| typecheck = ["mypy"] | typecheck = ["mypy"] | ||||||
| 
 | 
 | ||||||
| [[package]] |  | ||||||
| name = "beautifulsoup4" |  | ||||||
| version = "4.12.2" |  | ||||||
| description = "Screen-scraping library" |  | ||||||
| optional = false |  | ||||||
| python-versions = ">=3.6.0" |  | ||||||
| files = [ |  | ||||||
|     {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, |  | ||||||
|     {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, |  | ||||||
| ] |  | ||||||
| 
 |  | ||||||
| [package.dependencies] |  | ||||||
| soupsieve = ">1.2" |  | ||||||
| 
 |  | ||||||
| [package.extras] |  | ||||||
| html5lib = ["html5lib"] |  | ||||||
| lxml = ["lxml"] |  | ||||||
| 
 |  | ||||||
| [[package]] | [[package]] | ||||||
| name = "black" | name = "black" | ||||||
| version = "22.12.0" | version = "22.12.0" | ||||||
|  | @ -643,10 +625,7 @@ files = [ | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [package.dependencies] | [package.dependencies] | ||||||
| numpy = [ | numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} | ||||||
|     {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}, |  | ||||||
|     {version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""}, |  | ||||||
| ] |  | ||||||
| 
 | 
 | ||||||
| [package.extras] | [package.extras] | ||||||
| bokeh = ["bokeh", "selenium"] | bokeh = ["bokeh", "selenium"] | ||||||
|  | @ -736,13 +715,13 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "dataclasses-json" | name = "dataclasses-json" | ||||||
| version = "0.6.1" | version = "0.5.14" | ||||||
| description = "Easily serialize dataclasses to and from JSON." | description = "Easily serialize dataclasses to and from JSON." | ||||||
| optional = false | optional = false | ||||||
| python-versions = ">=3.7,<4.0" | python-versions = ">=3.7,<3.13" | ||||||
| files = [ | files = [ | ||||||
|     {file = "dataclasses_json-0.6.1-py3-none-any.whl", hash = "sha256:1bd8418a61fe3d588bb0079214d7fb71d44937da40742b787256fd53b26b6c80"}, |     {file = "dataclasses_json-0.5.14-py3-none-any.whl", hash = "sha256:5ec6fed642adb1dbdb4182badb01e0861badfd8fda82e3b67f44b2d1e9d10d21"}, | ||||||
|     {file = "dataclasses_json-0.6.1.tar.gz", hash = "sha256:a53c220c35134ce08211a1057fd0e5bf76dc5331627c6b241cacbc570a89faae"}, |     {file = "dataclasses_json-0.5.14.tar.gz", hash = "sha256:d82896a94c992ffaf689cd1fafc180164e2abdd415b8f94a7f78586af5886236"}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [package.dependencies] | [package.dependencies] | ||||||
|  | @ -1130,7 +1109,7 @@ files = [ | ||||||
|     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, |     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, | ||||||
|     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, |     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, | ||||||
|     {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, |     {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, | ||||||
|     {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"}, |     {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, |     {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, |     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, |     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, | ||||||
|  | @ -1140,7 +1119,6 @@ files = [ | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, |     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, |     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, |     {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, | ||||||
|     {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"}, |  | ||||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, |     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, | ||||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, |     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, | ||||||
|     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, |     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, | ||||||
|  | @ -1760,28 +1738,27 @@ test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"] | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "llama-index" | name = "llama-index" | ||||||
| version = "0.8.35" | version = "0.8.47" | ||||||
| description = "Interface between LLMs and your data" | description = "Interface between LLMs and your data" | ||||||
| optional = false | optional = false | ||||||
| python-versions = "*" | python-versions = ">=3.8.1,<3.12" | ||||||
| files = [ | files = [ | ||||||
|     {file = "llama_index-0.8.35-py3-none-any.whl", hash = "sha256:f2f1670320e75a9643b6dc96662038f777866ed543994d18f71ab54329e295ae"}, |     {file = "llama_index-0.8.47-py3-none-any.whl", hash = "sha256:7a0e5154637524fb59b30bd3a349fba2ec6092cf2972276da9dfa38bbe82d721"}, | ||||||
|     {file = "llama_index-0.8.35.tar.gz", hash = "sha256:a8767be9d36ebd538a37e18b0c7f46bb19d9d7ec490ef7582640f75d5d9b5259"}, |     {file = "llama_index-0.8.47.tar.gz", hash = "sha256:f824e7bcf9b6cf3fb98de59d722695a8db327c83b6b7d30071d931b56c14904f"}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [package.dependencies] | [package.dependencies] | ||||||
| beautifulsoup4 = "*" | dataclasses-json = ">=0.5.7,<0.6.0" | ||||||
| dataclasses-json = "*" |  | ||||||
| fsspec = ">=2023.5.0" | fsspec = ">=2023.5.0" | ||||||
| langchain = ">=0.0.293" | langchain = ">=0.0.303" | ||||||
| nest-asyncio = "*" | nest-asyncio = ">=1.5.8,<2.0.0" | ||||||
| nltk = "*" | nltk = ">=3.8.1,<4.0.0" | ||||||
| numpy = "*" | numpy = "*" | ||||||
| openai = ">=0.26.4" | openai = ">=0.26.4" | ||||||
| pandas = "*" | pandas = "*" | ||||||
| sqlalchemy = ">=2.0.15" | SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} | ||||||
| tenacity = ">=8.2.0,<9.0.0" | tenacity = ">=8.2.0,<9.0.0" | ||||||
| tiktoken = "*" | tiktoken = ">=0.3.3" | ||||||
| typing-extensions = ">=4.5.0" | typing-extensions = ">=4.5.0" | ||||||
| typing-inspect = ">=0.8.0" | typing-inspect = ">=0.8.0" | ||||||
| urllib3 = "<2" | urllib3 = "<2" | ||||||
|  | @ -2397,10 +2374,7 @@ files = [ | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [package.dependencies] | [package.dependencies] | ||||||
| numpy = [ | numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""} | ||||||
|     {version = ">=1.23.2", markers = "python_version == \"3.11\""}, |  | ||||||
|     {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, |  | ||||||
| ] |  | ||||||
| python-dateutil = ">=2.8.2" | python-dateutil = ">=2.8.2" | ||||||
| pytz = ">=2020.1" | pytz = ">=2020.1" | ||||||
| tzdata = ">=2022.1" | tzdata = ">=2022.1" | ||||||
|  | @ -3469,11 +3443,6 @@ files = [ | ||||||
|     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, |     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, | ||||||
|     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, |     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, | ||||||
|     {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, |     {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, | ||||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, |  | ||||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, |  | ||||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, |  | ||||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, |  | ||||||
|     {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, |  | ||||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, |     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, | ||||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, |     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, | ||||||
|     {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, |     {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, | ||||||
|  | @ -3690,17 +3659,6 @@ files = [ | ||||||
|     {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, |     {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [[package]] |  | ||||||
| name = "soupsieve" |  | ||||||
| version = "2.5" |  | ||||||
| description = "A modern CSS selector implementation for Beautiful Soup." |  | ||||||
| optional = false |  | ||||||
| python-versions = ">=3.8" |  | ||||||
| files = [ |  | ||||||
|     {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, |  | ||||||
|     {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, |  | ||||||
| ] |  | ||||||
| 
 |  | ||||||
| [[package]] | [[package]] | ||||||
| name = "sqlalchemy" | name = "sqlalchemy" | ||||||
| version = "2.0.22" | version = "2.0.22" | ||||||
|  | @ -3760,7 +3718,7 @@ files = [ | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| [package.dependencies] | [package.dependencies] | ||||||
| greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} | greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} | ||||||
| typing-extensions = ">=4.2.0" | typing-extensions = ">=4.2.0" | ||||||
| 
 | 
 | ||||||
| [package.extras] | [package.extras] | ||||||
|  | @ -4738,5 +4696,5 @@ multidict = ">=4.0" | ||||||
| 
 | 
 | ||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = ">=3.11,<3.13" | python-versions = ">=3.11,<3.12" | ||||||
| content-hash = "c1fa5accdcd9cd81430839398e16d2596b43b1c314b5d2a8a76aa05bbb83a39c" | content-hash = "56b78ce6a8a6dfbe42b490bcf4ffbf2820f10d5ce70a28c61ee4e357172dab33" | ||||||
|  |  | ||||||
|  | @ -0,0 +1,87 @@ | ||||||
|  | from typing import Any | ||||||
|  | 
 | ||||||
|  | from llama_index.schema import BaseNode, MetadataMode | ||||||
|  | from llama_index.vector_stores import ChromaVectorStore | ||||||
|  | from llama_index.vector_stores.chroma import chunk_list | ||||||
|  | from llama_index.vector_stores.utils import node_to_metadata_dict | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class BatchedChromaVectorStore(ChromaVectorStore):
    """ChromaVectorStore variant that inserts nodes in size-limited batches.

    ``add`` splits the incoming nodes into chunks no larger than the
    client's ``max_batch_size`` and writes each chunk to the collection
    separately, so a single oversized insertion is never sent.

    Args:
        chroma_client (from chromadb.api.API):
            API instance; its ``max_batch_size`` is read on every ``add``.
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance, passed on to the parent class.

    """

    # Client kept on the store so ``add`` can look up the batch limit.
    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        """Initialise the parent store, then remember the Chroma client.

        ``host``, ``port``, ``ssl`` and ``headers`` are forwarded to the
        parent constructor unchanged; a missing (or empty)
        ``collection_kwargs`` is forwarded as an empty dict.
        """
        parent_collection_kwargs = collection_kwargs or {}
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=parent_collection_kwargs,
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Insert nodes into the collection one batch at a time.

        Args:
            nodes: nodes with embeddings to store.

        Returns:
            The node ids of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection is not set.

        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")
        if not self._collection:
            raise ValueError("Collection not initialized")

        batch_limit = self.chroma_client.max_batch_size

        inserted_ids = []
        for batch in chunk_list(nodes, batch_limit):
            # Build the four parallel lists the collection expects.
            batch_embeddings = [node.get_embedding() for node in batch]
            batch_metadatas = [
                node_to_metadata_dict(
                    node, remove_text=True, flat_metadata=self.flat_metadata
                )
                for node in batch
            ]
            batch_ids = [node.node_id for node in batch]
            batch_documents = [
                node.get_content(metadata_mode=MetadataMode.NONE) for node in batch
            ]

            self._collection.add(
                embeddings=batch_embeddings,
                ids=batch_ids,
                metadatas=batch_metadatas,
                documents=batch_documents,
            )
            inserted_ids.extend(batch_ids)

        return inserted_ids
|  | @ -4,9 +4,9 @@ import chromadb | ||||||
| from injector import inject, singleton | from injector import inject, singleton | ||||||
| from llama_index import VectorStoreIndex | from llama_index import VectorStoreIndex | ||||||
| from llama_index.indices.vector_store import VectorIndexRetriever | from llama_index.indices.vector_store import VectorIndexRetriever | ||||||
| from llama_index.vector_stores import ChromaVectorStore |  | ||||||
| from llama_index.vector_stores.types import VectorStore | from llama_index.vector_stores.types import VectorStore | ||||||
| 
 | 
 | ||||||
|  | from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore | ||||||
| from private_gpt.open_ai.extensions.context_filter import ContextFilter | from private_gpt.open_ai.extensions.context_filter import ContextFilter | ||||||
| from private_gpt.paths import local_data_path | from private_gpt.paths import local_data_path | ||||||
| 
 | 
 | ||||||
|  | @ -36,14 +36,16 @@ class VectorStoreComponent: | ||||||
| 
 | 
 | ||||||
|     @inject |     @inject | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
|         db = chromadb.PersistentClient( |         chroma_client = chromadb.PersistentClient( | ||||||
|             path=str((local_data_path / "chroma_db").absolute()) |             path=str((local_data_path / "chroma_db").absolute()) | ||||||
|         ) |         ) | ||||||
|         chroma_collection = db.get_or_create_collection( |         chroma_collection = chroma_client.get_or_create_collection( | ||||||
|             "make_this_parameterizable_per_api_call" |             "make_this_parameterizable_per_api_call" | ||||||
|         )  # TODO |         )  # TODO | ||||||
| 
 | 
 | ||||||
|         self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection) |         self.vector_store = BatchedChromaVectorStore( | ||||||
|  |             chroma_client=chroma_client, chroma_collection=chroma_collection | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def get_retriever( |     def get_retriever( | ||||||
|  |  | ||||||
|  | @ -5,7 +5,7 @@ description = "Private GPT" | ||||||
| authors = ["Zylon <hi@zylon.ai>"] | authors = ["Zylon <hi@zylon.ai>"] | ||||||
| 
 | 
 | ||||||
| [tool.poetry.dependencies] | [tool.poetry.dependencies] | ||||||
| python = ">=3.11,<3.13" | python = ">=3.11,<3.12" | ||||||
| fastapi = { extras = ["all"], version = "^0.103.1" } | fastapi = { extras = ["all"], version = "^0.103.1" } | ||||||
| loguru = "^0.7.2" | loguru = "^0.7.2" | ||||||
| boto3 = "^1.28.56" | boto3 = "^1.28.56" | ||||||
|  | @ -13,7 +13,7 @@ injector = "^0.21.0" | ||||||
| pyyaml = "^6.0.1" | pyyaml = "^6.0.1" | ||||||
| python-multipart = "^0.0.6" | python-multipart = "^0.0.6" | ||||||
| pypdf = "^3.16.2" | pypdf = "^3.16.2" | ||||||
| llama-index = "v0.8.35" | llama-index = "0.8.47" | ||||||
| chromadb = "^0.4.13" | chromadb = "^0.4.13" | ||||||
| watchdog = "^3.0.0" | watchdog = "^3.0.0" | ||||||
| transformers = "^4.34.0" | transformers = "^4.34.0" | ||||||
|  |  | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | from unittest.mock import PropertyMock, patch | ||||||
|  | 
 | ||||||
|  | from llama_index import Document | ||||||
|  | 
 | ||||||
|  | from private_gpt.server.ingest.ingest_service import IngestService | ||||||
|  | from tests.fixtures.mock_injector import MockInjector | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def test_save_many_nodes(injector: MockInjector) -> None:
    """Save 100 documents while Chromadb's max batch size is patched to 10.

    This is a specific test for a local Chromadb Vector Database setup.
    Extend it when we add support for other vector databases in
    VectorStoreComponent.
    """
    batch_size_patch = patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    )
    with batch_size_patch as max_batch_size:
        # A tiny limit means the 100 documents cannot fit in one batch.
        max_batch_size.return_value = 10

        ingest_service = injector.get(IngestService)
        documents = [Document(text="This is a sentence.") for _ in range(100)]

        ingested_docs = ingest_service._save_docs(documents)
        assert len(ingested_docs) == len(documents)
		Loading…
	
		Reference in New Issue