Update dependencies. Remove custom gpt4all_j wrapper.
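The custom GPT4All_J wrapper deleted in this commit is superseded by LangChain's built-in GPT4All class, which selects the GPT-J model family through its backend argument. A minimal sketch of the replacement usage, assuming the pinned langchain==0.0.162 and the model path used elsewhere in this diff:

    from langchain.llms import GPT4All

    # Built-in wrapper replaces the custom gpt4all_j.py deleted below; the
    # model path and backend value are the ones used in this commit.
    llm = GPT4All(
        model="./models/ggml-gpt4all-j-v1.3-groovy.bin",
        backend="gptj",  # GPT4All-J models run on the gptj backend
        verbose=False,
    )
    print(llm("Once upon a time, "))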
parent 92244a90b4
commit bdd8c8748b

gpt4all_j.py (160 lines deleted)
@@ -1,160 +0,0 @@
| """Wrapper for the GPT4All-J model.""" |  | ||||||
| from functools import partial |  | ||||||
| from typing import Any, Dict, List, Mapping, Optional, Set |  | ||||||
| 
 |  | ||||||
| from pydantic import Extra, Field, root_validator |  | ||||||
| 
 |  | ||||||
| from langchain.callbacks.manager import CallbackManagerForLLMRun |  | ||||||
| from langchain.llms.base import LLM |  | ||||||
| from langchain.llms.utils import enforce_stop_tokens |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class GPT4All_J(LLM): |  | ||||||
|     r"""Wrapper around GPT4All-J language models. |  | ||||||
| 
 |  | ||||||
|     To use, you should have the ``pygpt4all`` python package installed, the |  | ||||||
|     pre-trained model file, and the model's config information. |  | ||||||
| 
 |  | ||||||
|     Example: |  | ||||||
|         .. code-block:: python |  | ||||||
| 
 |  | ||||||
|             from langchain.llms import GPT4All_J |  | ||||||
|             model = GPT4All_J(model="./models/gpt4all-model.bin") |  | ||||||
| 
 |  | ||||||
|             # Simplest invocation |  | ||||||
|             response = model("Once upon a time, ") |  | ||||||
|     """ |  | ||||||
-
-    model: str
-    """Path to the pre-trained GPT4All model file."""
-
-    seed: Optional[int] = 0
-    """Seed for the random number generator (read by _default_params)."""
-
-    n_threads: Optional[int] = Field(4, alias="n_threads")
-    """Number of threads to use."""
-
-    n_predict: Optional[int] = 256
-    """The maximum number of tokens to generate."""
-
-    temp: Optional[float] = 0.8
-    """The temperature to use for sampling."""
-
-    top_p: Optional[float] = 0.95
-    """The top-p value to use for sampling."""
-
-    top_k: Optional[int] = 40
-    """The top-k value to use for sampling."""
-
-    echo: Optional[bool] = False
-    """Whether to echo the prompt."""
-
-    stop: Optional[List[str]] = []
-    """A list of strings to stop generation when encountered."""
-
-    repeat_last_n: Optional[int] = 64
-    """The number of last tokens to penalize for repetition."""
-
-    repeat_penalty: Optional[float] = 1.3
-    """The penalty to apply to repeated tokens."""
-
-    n_batch: int = Field(1, alias="n_batch")
-    """Batch size for prompt processing."""
-
-    streaming: bool = False
-    """Whether to stream the results or not."""
-
-    client: Any = None  #: :meta private:
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default generation parameters."""
-        return {
-            "seed": self.seed,
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "n_batch": self.n_batch,
-            "repeat_last_n": self.repeat_last_n,
-            "repeat_penalty": self.repeat_penalty,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
-    @staticmethod
-    def _llama_param_names() -> Set[str]:
-        """Names of constructor parameters to pass through to the model."""
-        return set()
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that the python package exists in the environment."""
-        try:
-            from pygpt4all.models.gpt4all_j import GPT4All_J as GPT4AllModel
-
-            llama_keys = cls._llama_param_names()
-            model_kwargs = {k: v for k, v in values.items() if k in llama_keys}
-            values["client"] = GPT4AllModel(
-                model_path=values["model"],
-                **model_kwargs,
-            )
-
-        except ImportError:
-            raise ValueError(
-                "Could not import pygpt4all python package. "
-                "Please install it with `pip install pygpt4all`."
-            )
-        return values
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "model": self.model,
-            **self._default_params,
-            **{
-                k: v
-                for k, v in self.__dict__.items()
-                if k in GPT4All_J._llama_param_names()
-            },
-        }
-
-    @property
-    def _llm_type(self) -> str:
-        """Return the type of llm."""
-        return "gpt4all"
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        r"""Call out to GPT4All's generate method.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: A list of strings to stop generation when encountered.
-
-        Returns:
-            The string generated by the model.
-
-        Example:
-            .. code-block:: python
-
-                prompt = "Once upon a time, "
-                response = model(prompt)
-        """
-        if run_manager:
-            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
-            text = self.client.generate(prompt, new_text_callback=text_callback)
-        else:
-            text = self.client.generate(prompt)
-        if stop is not None:
-            text = enforce_stop_tokens(text, stop)
-        return text
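For reference, the streaming that the deleted wrapper wired up by hand (forwarding tokens through run_manager.on_llm_new_token) is provided by the built-in class once callback handlers are attached. A minimal sketch, assuming the same model file as the rest of this diff:

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms import GPT4All

    # Tokens are echoed to stdout as they are generated, replacing the
    # wrapper's manual new_text_callback plumbing.
    llm = GPT4All(
        model="./models/ggml-gpt4all-j-v1.3-groovy.bin",
        backend="gptj",
        callbacks=[StreamingStdOutCallbackHandler()],
        verbose=False,
    )
    llm("Once upon a time, ")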
@@ -1,8 +1,8 @@
-from gpt4all_j import GPT4All_J
 from langchain.chains import RetrievalQA
 from langchain.embeddings import LlamaCppEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
+from langchain.llms import GPT4All
 
 
 def main():
     # Load stored vectorstore
@@ -12,14 +12,28 @@ def main():
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]
-    llm = GPT4All_J(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', callbacks=callbacks, verbose=False)
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+    llm = GPT4All(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', backend='gptj', callbacks=callbacks, verbose=False)
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
     # Interactive questions and answers
     while True:
-        query = input("Enter a query: ")
+        query = input("\nEnter a query: ")
         if query == "exit":
             break
-        qa.run(query)
+
+        # Get the answer from the chain
+        res = qa(query)
+        answer, docs = res['result'], res['source_documents']
+
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        # Print the relevant sources used for the answer
+        for document in docs:
+            print("\n> " + document.metadata["source"] + ":")
+            print(document.page_content)
 
 
 if __name__ == "__main__":
     main()
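The hunk above begins at line 12, so the setup that produces db is not part of this diff. A hedged sketch of what that elided setup plausibly looks like, built only from the imports visible in this file; the persist directory and embeddings model path are assumptions, not values taken from the commit:

    from langchain.embeddings import LlamaCppEmbeddings
    from langchain.vectorstores import Chroma

    # Hypothetical reconstruction of the elided lines 1-11: both the
    # embeddings model path and the persist directory are assumed names.
    embeddings = LlamaCppEmbeddings(model_path="./models/ggml-model-q4_0.bin")
    db = Chroma(persist_directory="db", embedding_function=embeddings)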
@@ -1,5 +1,4 @@
-langchain==0.0.154
-pygptj==1.0.10
-pygpt4all==1.0.1
-chromadb==0.3.21
-llama-cpp-python==0.1.41
+langchain==0.0.162
+pygpt4all==1.1.0
+chromadb==0.3.22
+llama-cpp-python==0.1.47