Update dependencies. Remove custom gpt4all_j wrapper.
parent 92244a90b4
commit bdd8c8748b
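This removes the hand-rolled `GPT4All_J` wrapper: the pinned langchain 0.0.162 ships a built-in `GPT4All` LLM class whose `backend='gptj'` option covers GPT4All-J models, so the 160-line local copy below is redundant. A minimal sketch of the replacement path (class, model path, and parameters as they appear in this diff; the prompt is illustrative):

    from langchain.llms import GPT4All

    # Built-in wrapper replaces the deleted GPT4All_J class. backend='gptj'
    # selects the GPT4All-J (GPT-J) model family instead of the default
    # llama backend.
    llm = GPT4All(
        model="./models/ggml-gpt4all-j-v1.3-groovy.bin",
        backend="gptj",
        verbose=False,
    )
    print(llm("Once upon a time, "))  # illustrative prompt, plain invocation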

gpt4all_j.py
@@ -1,160 +0,0 @@
| """Wrapper for the GPT4All-J model.""" | ||||
| from functools import partial | ||||
| from typing import Any, Dict, List, Mapping, Optional, Set | ||||
| 
 | ||||
| from pydantic import Extra, Field, root_validator | ||||
| 
 | ||||
| from langchain.callbacks.manager import CallbackManagerForLLMRun | ||||
| from langchain.llms.base import LLM | ||||
| from langchain.llms.utils import enforce_stop_tokens | ||||
| 
 | ||||
| 
 | ||||
| class GPT4All_J(LLM): | ||||
|     r"""Wrapper around GPT4All-J language models. | ||||
| 
 | ||||
|     To use, you should have the ``pygpt4all`` python package installed, the | ||||
|     pre-trained model file, and the model's config information. | ||||
| 
 | ||||
|     Example: | ||||
|         .. code-block:: python | ||||
| 
 | ||||
|             from langchain.llms import GPT4All_J | ||||
|             model = GPT4All_J(model="./models/gpt4all-model.bin") | ||||
| 
 | ||||
|             # Simplest invocation | ||||
|             response = model("Once upon a time, ") | ||||
|     """ | ||||
| 
 | ||||
|     model: str | ||||
|     """Path to the pre-trained GPT4All model file.""" | ||||
| 
 | ||||
|     n_threads: Optional[int] = Field(4, alias="n_threads") | ||||
|     """Number of threads to use.""" | ||||
| 
 | ||||
|     n_predict: Optional[int] = 256 | ||||
|     """The maximum number of tokens to generate.""" | ||||
| 
 | ||||
|     temp: Optional[float] = 0.8 | ||||
|     """The temperature to use for sampling.""" | ||||
| 
 | ||||
|     top_p: Optional[float] = 0.95 | ||||
|     """The top-p value to use for sampling.""" | ||||
| 
 | ||||
|     top_k: Optional[int] = 40 | ||||
|     """The top-k value to use for sampling.""" | ||||
| 
 | ||||
|     echo: Optional[bool] = False | ||||
|     """Whether to echo the prompt.""" | ||||
| 
 | ||||
|     stop: Optional[List[str]] = [] | ||||
|     """A list of strings to stop generation when encountered.""" | ||||
| 
 | ||||
|     repeat_last_n: Optional[int] = 64 | ||||
|     "Last n tokens to penalize" | ||||
| 
 | ||||
|     repeat_penalty: Optional[float] = 1.3 | ||||
|     """The penalty to apply to repeated tokens.""" | ||||
| 
 | ||||
|     n_batch: int = Field(1, alias="n_batch") | ||||
|     """Batch size for prompt processing.""" | ||||
| 
 | ||||
|     streaming: bool = False | ||||
|     """Whether to stream the results or not.""" | ||||
| 
 | ||||
|     client: Any = None  #: :meta private: | ||||
| 
 | ||||
|     class Config: | ||||
|         """Configuration for this pydantic object.""" | ||||
| 
 | ||||
|         extra = Extra.forbid | ||||
| 
 | ||||
|     @property | ||||
|     def _default_params(self) -> Dict[str, Any]: | ||||
|         """Get the identifying parameters.""" | ||||
|         return { | ||||
|             "seed": self.seed, | ||||
|             "n_predict": self.n_predict, | ||||
|             "n_threads": self.n_threads, | ||||
|             "n_batch": self.n_batch, | ||||
|             "repeat_last_n": self.repeat_last_n, | ||||
|             "repeat_penalty": self.repeat_penalty, | ||||
|             "top_k": self.top_k, | ||||
|             "top_p": self.top_p, | ||||
|             "temp": self.temp, | ||||
|         } | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _llama_param_names() -> Set[str]: | ||||
|         """Get the identifying parameters.""" | ||||
|         return {} | ||||
| 
 | ||||
|     @root_validator() | ||||
|     def validate_environment(cls, values: Dict) -> Dict: | ||||
|         """Validate that the python package exists in the environment.""" | ||||
|         try: | ||||
|             from pygpt4all.models.gpt4all_j import GPT4All_J as GPT4AllModel | ||||
|              | ||||
|             llama_keys = cls._llama_param_names() | ||||
|             model_kwargs = {k: v for k, v in values.items() if k in llama_keys} | ||||
|             values["client"] = GPT4AllModel( | ||||
|                 model_path=values["model"], | ||||
|                 **model_kwargs, | ||||
|             ) | ||||
| 
 | ||||
|         except ImportError: | ||||
|             raise ValueError( | ||||
|                 "Could not import pygpt4all python package. " | ||||
|                 "Please install it with `pip install pygpt4all`." | ||||
|             ) | ||||
|         return values | ||||
| 
 | ||||
|     @property | ||||
|     def _identifying_params(self) -> Mapping[str, Any]: | ||||
|         """Get the identifying parameters.""" | ||||
|         return { | ||||
|             "model": self.model, | ||||
|             **self._default_params, | ||||
|             **{ | ||||
|                 k: v | ||||
|                 for k, v in self.__dict__.items() | ||||
|                 if k in GPT4All_J._llama_param_names() | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     @property | ||||
|     def _llm_type(self) -> str: | ||||
|         """Return the type of llm.""" | ||||
|         return "gpt4all" | ||||
| 
 | ||||
|     def _call( | ||||
|         self, | ||||
|         prompt: str, | ||||
|         stop: Optional[List[str]] = None, | ||||
|         run_manager: Optional[CallbackManagerForLLMRun] = None, | ||||
|     ) -> str: | ||||
|         r"""Call out to GPT4All's generate method. | ||||
| 
 | ||||
|         Args: | ||||
|             prompt: The prompt to pass into the model. | ||||
|             stop: A list of strings to stop generation when encountered. | ||||
| 
 | ||||
|         Returns: | ||||
|             The string generated by the model. | ||||
| 
 | ||||
|         Example: | ||||
|             .. code-block:: python | ||||
| 
 | ||||
|                 prompt = "Once upon a time, " | ||||
|                 response = model(prompt, n_predict=55) | ||||
|         """ | ||||
|         if run_manager: | ||||
|             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose) | ||||
|             text = self.client.generate( | ||||
|                 prompt, | ||||
|                 new_text_callback=text_callback | ||||
|             ) | ||||
|         else: | ||||
|             text = self.client.generate(prompt) | ||||
|         if stop is not None: | ||||
|             text = enforce_stop_tokens(text, stop) | ||||
|         return text | ||||
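The streaming plumbing the deleted `_call` wired up by hand (a `partial` over `run_manager.on_llm_new_token` passed as `new_text_callback`) is handled by the built-in class through langchain's callback mechanism; the script change below passes a `StreamingStdOutCallbackHandler` for the same effect. A hedged sketch of that streaming usage in isolation (imports and arguments as used in this diff; the prompt is illustrative):

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms import GPT4All

    # Tokens are written to stdout as they are generated, replacing the
    # hand-wired new_text_callback path from the deleted wrapper.
    llm = GPT4All(
        model="./models/ggml-gpt4all-j-v1.3-groovy.bin",
        backend="gptj",
        callbacks=[StreamingStdOutCallbackHandler()],
        verbose=False,
    )
    llm("Once upon a time, ")  # streams while it generates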

@@ -1,8 +1,8 @@
-from gpt4all_j import GPT4All_J
 from langchain.chains import RetrievalQA
 from langchain.embeddings import LlamaCppEmbeddings
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
+from langchain.llms import GPT4All
 
 def main():
     # Load stored vectorstore
@@ -12,14 +12,28 @@ def main():
     retriever = db.as_retriever()
     # Prepare the LLM
     callbacks = [StreamingStdOutCallbackHandler()]
-    llm = GPT4All_J(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', callbacks=callbacks, verbose=False)
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+    llm = GPT4All(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', backend='gptj', callbacks=callbacks, verbose=False)
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
     # Interactive questions and answers
     while True:
-        query = input("Enter a query: ")
+        query = input("\nEnter a query: ")
         if query == "exit":
             break
-        qa.run(query)
-
+        # Get the answer from the chain
+        res = qa(query)
+        answer, docs = res['result'], res['source_documents']
+
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        # Print the relevant sources used for the answer
+        for document in docs:
+            print("\n> " + document.metadata["source"] + ":")
+            print(document.page_content)
+
 if __name__ == "__main__":
     main()
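Because the chain is now built with `return_source_documents=True`, calling it returns a dict rather than a bare string, which is why the loop switches from `qa.run(query)` to `qa(query)`. A short sketch of consuming that result outside the interactive loop (keys and attributes as used above; the query text is illustrative):

    # qa is the RetrievalQA chain constructed in main().
    res = qa("What is this corpus about?")  # illustrative query
    print(res['result'])  # the generated answer
    for document in res['source_documents']:
        # each retrieved chunk remembers which file it came from
        print(document.metadata['source'], document.page_content[:80])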

requirements.txt
@@ -1,5 +1,4 @@
-langchain==0.0.154
-pygptj==1.0.10
-pygpt4all==1.0.1
-chromadb==0.3.21
-llama-cpp-python==0.1.41
+langchain==0.0.162
+pygpt4all==1.1.0
+chromadb==0.3.22
+llama-cpp-python==0.1.47