Feature/sagemaker embedding (#1161)

* Sagemaker deployed embedding model support

Co-authored-by: Pablo Orgaz <pabloogc@gmail.com>

parent: f29df84301
commit: ad512e3c42
				|  | @ -0,0 +1,82 @@ | |||
| # mypy: ignore-errors | ||||
| import json | ||||
| from typing import Any | ||||
| 
 | ||||
| import boto3 | ||||
| from llama_index.embeddings.base import BaseEmbedding | ||||
| from pydantic import Field, PrivateAttr | ||||
| 
 | ||||
| 
 | ||||
class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    # Fix: the original Field had an empty description="".
    endpoint_name: str = Field(
        description="Name of the deployed Sagemaker embedding endpoint to invoke."
    )

    # NOTE(review): the client is created at class-definition (import) time and
    # is shared by all instances; region/credentials come from the default
    # boto3 resolution chain.
    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    # Ensures the "async not implemented" warning is emitted at most once
    # per instance.
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier used by llama-index."""
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Warn (once per instance) that async falls back to sync embedding."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        """Invoke the Sagemaker endpoint and return one vector per sentence.

        The endpoint is expected to accept a JSON body of the form
        ``{"inputs": [...]}`` and respond with ``{"vectors": [[...], ...]}``.
        """
        request_params = {
            "inputs": sentences,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        # resp["Body"] is a streaming body; read fully and decode as UTF-8 JSON.
        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_json = json.loads(response_str)

        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Async query embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Async text embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        return self._embed(texts)
|  | @ -13,13 +13,22 @@ class EmbeddingComponent: | |||
|     @inject | ||||
|     def __init__(self) -> None: | ||||
|         match settings.llm.mode: | ||||
|             case "local" | "sagemaker": | ||||
|             case "local": | ||||
|                 from llama_index.embeddings import HuggingFaceEmbedding | ||||
| 
 | ||||
|                 self.embedding_model = HuggingFaceEmbedding( | ||||
|                     model_name=settings.local.embedding_hf_model_name, | ||||
|                     cache_folder=str(models_cache_path), | ||||
|                 ) | ||||
|             case "sagemaker": | ||||
| 
 | ||||
|                 from private_gpt.components.embedding.custom.sagemaker import ( | ||||
|                     SagemakerEmbedding, | ||||
|                 ) | ||||
| 
 | ||||
|                 self.embedding_model = SagemakerEmbedding( | ||||
|                     endpoint_name=settings.sagemaker.embedding_endpoint_name, | ||||
|                 ) | ||||
|             case "openai": | ||||
|                 from llama_index import OpenAIEmbedding | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,7 +4,7 @@ from __future__ import annotations | |||
| import io | ||||
| import json | ||||
| import logging | ||||
| from typing import TYPE_CHECKING | ||||
| from typing import TYPE_CHECKING, Any | ||||
| 
 | ||||
| import boto3  # type: ignore | ||||
| from llama_index.bridge.pydantic import Field | ||||
|  | @ -30,7 +30,6 @@ from llama_index.llms.llama_utils import ( | |||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from collections.abc import Sequence | ||||
|     from typing import Any | ||||
| 
 | ||||
|     from llama_index.callbacks import CallbackManager | ||||
|     from llama_index.llms import ( | ||||
|  |  | |||
|  | @ -36,7 +36,7 @@ class LLMComponent: | |||
|                 from private_gpt.components.llm.custom.sagemaker import SagemakerLLM | ||||
| 
 | ||||
|                 self.llm = SagemakerLLM( | ||||
|                     endpoint_name=settings.sagemaker.endpoint_name, | ||||
|                     endpoint_name=settings.sagemaker.llm_endpoint_name, | ||||
|                 ) | ||||
|             case "openai": | ||||
|                 from llama_index.llms import OpenAI | ||||
|  |  | |||
|  | @ -28,7 +28,8 @@ class LocalSettings(BaseModel): | |||
| 
 | ||||
| 
 | ||||
class SagemakerSettings(BaseModel):
    """Settings for Sagemaker-deployed models (split per model kind)."""

    # Name of the deployed Sagemaker endpoint serving the LLM.
    llm_endpoint_name: str
    # Name of the deployed Sagemaker endpoint serving the embedding model.
    embedding_endpoint_name: str
| 
 | ||||
| 
 | ||||
| class OpenAISettings(BaseModel): | ||||
|  |  | |||
|  | @ -11,7 +11,8 @@ local: | |||
|   embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5} | ||||
| 
 | ||||
| sagemaker: | ||||
|   endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:} | ||||
|   llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:} | ||||
|   embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:} | ||||
| 
 | ||||
| ui: | ||||
|   enabled: true | ||||
|  |  | |||
|  | @ -0,0 +1,14 @@ | |||
# Settings profile that routes both the LLM and the embedding model to
# Sagemaker-deployed endpoints.
server:
  env_name: ${APP_ENV:prod}
  port: ${PORT:8001}

ui:
  enabled: true
  path: /

llm:
  # Select the Sagemaker backend (instead of local / openai).
  mode: sagemaker

sagemaker:
  # Endpoint serving the LLM.
  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
  # Endpoint serving the embedding model.
  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|  | @ -18,7 +18,8 @@ local: | |||
|   embedding_hf_model_name: BAAI/bge-small-en-v1.5 | ||||
| 
 | ||||
| sagemaker: | ||||
|   endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140 | ||||
|   llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140 | ||||
|   embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479 | ||||
| 
 | ||||
| openai: | ||||
|   api_key: ${OPENAI_API_KEY:} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue