From bdd8c8748b5f045b43ba2005867aa9bbcbd0f862 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iv=C3=A1n=20Mart=C3=ADnez?=
Date: Mon, 8 May 2023 23:41:57 +0200
Subject: [PATCH] Update dependencies. Remove custom gpt4all_j wrapper.

---
 gpt4all_j.py     | 160 -------------------------------------------------
 privateGPT.py    |  24 +++++--
 requirements.txt |   9 ++-
 3 files changed, 23 insertions(+), 170 deletions(-)
 delete mode 100644 gpt4all_j.py

diff --git a/gpt4all_j.py b/gpt4all_j.py
deleted file mode 100644
index 33aee40..0000000
--- a/gpt4all_j.py
+++ /dev/null
@@ -1,160 +0,0 @@
-"""Wrapper for the GPT4All-J model."""
-from functools import partial
-from typing import Any, Dict, List, Mapping, Optional, Set
-
-from pydantic import Extra, Field, root_validator
-
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-from langchain.llms.utils import enforce_stop_tokens
-
-
-class GPT4All_J(LLM):
-    r"""Wrapper around GPT4All-J language models.
-
-    To use, you should have the ``pygpt4all`` python package installed, the
-    pre-trained model file, and the model's config information.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.llms import GPT4All_J
-            model = GPT4All_J(model="./models/gpt4all-model.bin")
-
-            # Simplest invocation
-            response = model("Once upon a time, ")
-    """
-
-    model: str
-    """Path to the pre-trained GPT4All model file."""
-
-    n_threads: Optional[int] = Field(4, alias="n_threads")
-    """Number of threads to use."""
-
-    n_predict: Optional[int] = 256
-    """The maximum number of tokens to generate."""
-
-    temp: Optional[float] = 0.8
-    """The temperature to use for sampling."""
-
-    top_p: Optional[float] = 0.95
-    """The top-p value to use for sampling."""
-
-    top_k: Optional[int] = 40
-    """The top-k value to use for sampling."""
-
-    echo: Optional[bool] = False
-    """Whether to echo the prompt."""
-
-    stop: Optional[List[str]] = []
-    """A list of strings to stop generation when encountered."""
-
-    repeat_last_n: Optional[int] = 64
-    "Last n tokens to penalize"
-
-    repeat_penalty: Optional[float] = 1.3
-    """The penalty to apply to repeated tokens."""
-
-    n_batch: int = Field(1, alias="n_batch")
-    """Batch size for prompt processing."""
-
-    streaming: bool = False
-    """Whether to stream the results or not."""
-
-    client: Any = None  #: :meta private:
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "seed": self.seed,
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "n_batch": self.n_batch,
-            "repeat_last_n": self.repeat_last_n,
-            "repeat_penalty": self.repeat_penalty,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
-    @staticmethod
-    def _llama_param_names() -> Set[str]:
-        """Get the identifying parameters."""
-        return {}
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that the python package exists in the environment."""
-        try:
-            from pygpt4all.models.gpt4all_j import GPT4All_J as GPT4AllModel
-
-            llama_keys = cls._llama_param_names()
-            model_kwargs = {k: v for k, v in values.items() if k in llama_keys}
-            values["client"] = GPT4AllModel(
-                model_path=values["model"],
-                **model_kwargs,
-            )
-
-        except ImportError:
-            raise ValueError(
-                "Could not import pygpt4all python package. "
-                "Please install it with `pip install pygpt4all`."
-            )
-        return values
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "model": self.model,
-            **self._default_params,
-            **{
-                k: v
-                for k, v in self.__dict__.items()
-                if k in GPT4All_J._llama_param_names()
-            },
-        }
-
-    @property
-    def _llm_type(self) -> str:
-        """Return the type of llm."""
-        return "gpt4all"
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        r"""Call out to GPT4All's generate method.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: A list of strings to stop generation when encountered.
-
-        Returns:
-            The string generated by the model.
-
-        Example:
-            .. code-block:: python
-
-                prompt = "Once upon a time, "
-                response = model(prompt, n_predict=55)
-        """
-        if run_manager:
-            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
-            text = self.client.generate(
-                prompt,
-                new_text_callback=text_callback
-            )
-        else:
-            text = self.client.generate(prompt)
-        if stop is not None:
-            text = enforce_stop_tokens(text, stop)
-        return text
\ No newline at end of file
diff --git a/privateGPT.py b/privateGPT.py
index d45d592..817a5e3 100644
--- a/privateGPT.py
+++ b/privateGPT.py
@@ -1,8 +1,8 @@
-from gpt4all_j import GPT4All_J
 from langchain.chains import RetrievalQA
 from langchain.embeddings import LlamaCppEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
+from langchain.llms import GPT4All
 
 def main():
     # Load stored vectorstore
@@ -12,14 +12,28 @@ def main():
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]
-    llm = GPT4All_J(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', callbacks=callbacks, verbose=False)
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+    llm = GPT4All(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', backend='gptj', callbacks=callbacks, verbose=False)
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
     # Interactive questions and answers
     while True:
-        query = input("Enter a query: ")
+        query = input("\nEnter a query: ")
         if query == "exit":
             break
+
+        # Get the answer from the chain
+        res = qa(query)
+        answer, docs = res['result'], res['source_documents']
+
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        # Print the relevant sources used for the answer
+        for document in docs:
+            print("\n> " + document.metadata["source"] + ":")
+            print(document.page_content)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 95a2610..c7049b9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
-langchain==0.0.154
-pygptj==1.0.10
-pygpt4all==1.0.1
-chromadb==0.3.21
-llama-cpp-python==0.1.41
+langchain==0.0.162
+pygpt4all==1.1.0
+chromadb==0.3.22
+llama-cpp-python==0.1.47