feat(rag): expose similarity_top_k and similarity_score to settings (#1771)
* Added RAG settings to settings.py, vector_store, and chat_service to expose similarity_top_k and similarity_score
* Updated settings in the vector and chat services per Ivan's request
* Updated code for mypy
parent 774e256052
commit 087cb0b7b7
private_gpt/server/chat/chat_service.py

@@ -8,6 +8,9 @@ from llama_index.core.chat_engine.types import (
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.postprocessor import (
+    SimilarityPostprocessor,
+)
 from llama_index.core.storage import StorageContext
 from llama_index.core.types import TokenGen
 from pydantic import BaseModel
@@ -20,6 +23,7 @@ from private_gpt.components.vector_store.vector_store_component import (
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk
+from private_gpt.settings.settings import Settings


 class Completion(BaseModel):
@@ -68,14 +72,18 @@ class ChatEngineInput:

 @singleton
 class ChatService:
+    settings: Settings
+
     @inject
     def __init__(
         self,
+        settings: Settings,
         llm_component: LLMComponent,
         vector_store_component: VectorStoreComponent,
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
     ) -> None:
+        self.settings = settings
         self.llm_component = llm_component
         self.embedding_component = embedding_component
         self.vector_store_component = vector_store_component
@@ -98,9 +106,12 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
+        settings = self.settings
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
-                index=self.index, context_filter=context_filter
+                index=self.index,
+                context_filter=context_filter,
+                similarity_top_k=self.settings.rag.similarity_top_k,
             )
             return ContextChatEngine.from_defaults(
                 system_prompt=system_prompt,
@@ -108,6 +119,9 @@ class ChatService:
                 llm=self.llm_component.llm,  # Takes no effect at the moment
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
+                    SimilarityPostprocessor(
+                        similarity_cutoff=settings.rag.similarity_value
+                    ),
                 ],
             )
         else:
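For context on the new postprocessor: SimilarityPostprocessor drops any retrieved node whose score falls below similarity_cutoff, and a None cutoff keeps everything, which is why leaving rag.similarity_value unset disables filtering. A minimal sketch of that behavior (the node texts and scores below are invented, not part of this commit):

# Minimal sketch of the cutoff behavior; texts and scores are made up.
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.schema import NodeWithScore, TextNode

nodes = [
    NodeWithScore(node=TextNode(text="closely related passage"), score=0.82),
    NodeWithScore(node=TextNode(text="loosely related passage"), score=0.40),
]

# Mirrors SimilarityPostprocessor(similarity_cutoff=settings.rag.similarity_value)
postprocessor = SimilarityPostprocessor(similarity_cutoff=0.45)
filtered = postprocessor.postprocess_nodes(nodes)

print([n.score for n in filtered])  # [0.82] -- the 0.40 node is dropped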
private_gpt/settings/settings.py

@@ -284,6 +284,17 @@ class UISettings(BaseModel):
     )


+class RagSettings(BaseModel):
+    similarity_top_k: int = Field(
+        2,
+        description="This value controls the number of documents returned by the RAG pipeline",
+    )
+    similarity_value: float = Field(
+        None,
+        description="If set, any documents retrieved from the RAG must meet a certain match score. Acceptable values are between 0 and 1.",
+    )
+
+
 class PostgresSettings(BaseModel):
     host: str = Field(
         "localhost",
@@ -379,6 +390,7 @@ class Settings(BaseModel):
     azopenai: AzureOpenAISettings
     vectorstore: VectorstoreSettings
     nodestore: NodeStoreSettings
+    rag: RagSettings
     qdrant: QdrantSettings | None = None
     postgres: PostgresSettings | None = None

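One nuance in RagSettings: similarity_value is typed as float but defaults to None via Field(None, ...). Pydantic does not validate field defaults unless explicitly told to, so the None passes through and signals "cutoff disabled". A hedged sketch of that default behavior (a standalone model for illustration, not the commit's code):

from pydantic import BaseModel, Field

class RagSettingsSketch(BaseModel):
    similarity_top_k: int = Field(2)
    similarity_value: float = Field(None)  # default is not validated, so None is allowed

print(RagSettingsSketch().similarity_value)      # None -> similarity filtering disabled
print(RagSettingsSketch(similarity_value=0.45))  # similarity_top_k=2 similarity_value=0.45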
settings.yaml

@@ -42,6 +42,12 @@ llm:
   tokenizer: mistralai/Mistral-7B-Instruct-v0.2
   temperature: 0.1      # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)

+rag:
+  similarity_top_k: 2
+  # This value controls how many "top" documents the RAG returns to use in the context.
+  # similarity_value: 0.45
+  # This value is disabled by default. If you enable this setting, the RAG will only use documents that meet a minimum similarity score.
+
 llamacpp:
   prompt_style: "mistral"
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
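Taken together, the two knobs act in sequence: similarity_top_k caps how many documents the retriever hands back, and similarity_value (when uncommented) then filters those further, so the chat engine can end up with fewer than top_k context documents. Illustrative arithmetic only, with invented scores:

# Invented retrieval scores, for illustration only.
scores = [0.82, 0.51, 0.40, 0.33]
top_k = 2          # rag.similarity_top_k
cutoff = 0.45      # rag.similarity_value, if enabled

candidates = sorted(scores, reverse=True)[:top_k]  # [0.82, 0.51]
kept = [s for s in candidates if s >= cutoff]      # [0.82, 0.51]
# With cutoff = 0.6, only the 0.82 document would reach the context.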