Feature/sagemaker embedding (#1161)
* Sagemaker deployed embedding model support --------- Co-authored-by: Pablo Orgaz <pabloogc@gmail.com>
This commit is contained in:
		
							parent
							
								
									f29df84301
								
							
						
					
					
						commit
						ad512e3c42
					
				|  | @ -0,0 +1,82 @@ | ||||||
|  | # mypy: ignore-errors | ||||||
|  | import json | ||||||
|  | from typing import Any | ||||||
|  | 
 | ||||||
|  | import boto3 | ||||||
|  | from llama_index.embeddings.base import BaseEmbedding | ||||||
|  | from pydantic import Field, PrivateAttr | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(
        description="Name of the deployed Sagemaker embedding endpoint to invoke."
    )

    # Created lazily, once per instance, instead of eagerly at class-definition
    # time: the previous class-level `boto3.client(...)` call ran on module
    # import (requiring AWS configuration just to import this file) and shared
    # a single client across every instance. This resolves the original TODO.
    _boto_client: Any = PrivateAttr(
        default_factory=lambda: boto3.client("sagemaker-runtime")
    )

    # Guard so the "async not implemented" notice is emitted at most once.
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used by llama-index for (de)serialization."""
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Warn (once) that async embedding falls back to the sync path."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        """Invoke the Sagemaker endpoint and return one vector per sentence.

        The endpoint is expected to accept a JSON body of the form
        ``{"inputs": [...]}`` and to reply with a JSON object containing a
        ``"vectors"`` key holding one embedding per input sentence.

        Raises whatever botocore raises on transport/endpoint errors, plus
        KeyError if the response lacks the ``"vectors"`` key.
        """
        request_params = {
            "inputs": sentences,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_json = json.loads(response_str)

        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Async query embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Async text embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings (one vector per input text)."""
        return self._embed(texts)
|  | @ -13,13 +13,22 @@ class EmbeddingComponent: | ||||||
|     @inject |     @inject | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
|         match settings.llm.mode: |         match settings.llm.mode: | ||||||
|             case "local" | "sagemaker": |             case "local": | ||||||
|                 from llama_index.embeddings import HuggingFaceEmbedding |                 from llama_index.embeddings import HuggingFaceEmbedding | ||||||
| 
 | 
 | ||||||
|                 self.embedding_model = HuggingFaceEmbedding( |                 self.embedding_model = HuggingFaceEmbedding( | ||||||
|                     model_name=settings.local.embedding_hf_model_name, |                     model_name=settings.local.embedding_hf_model_name, | ||||||
|                     cache_folder=str(models_cache_path), |                     cache_folder=str(models_cache_path), | ||||||
|                 ) |                 ) | ||||||
|  |             case "sagemaker": | ||||||
|  | 
 | ||||||
|  |                 from private_gpt.components.embedding.custom.sagemaker import ( | ||||||
|  |                     SagemakerEmbedding, | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |                 self.embedding_model = SagemakerEmbedding( | ||||||
|  |                     endpoint_name=settings.sagemaker.embedding_endpoint_name, | ||||||
|  |                 ) | ||||||
|             case "openai": |             case "openai": | ||||||
|                 from llama_index import OpenAIEmbedding |                 from llama_index import OpenAIEmbedding | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ from __future__ import annotations | ||||||
| import io | import io | ||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING, Any | ||||||
| 
 | 
 | ||||||
| import boto3  # type: ignore | import boto3  # type: ignore | ||||||
| from llama_index.bridge.pydantic import Field | from llama_index.bridge.pydantic import Field | ||||||
|  | @ -30,7 +30,6 @@ from llama_index.llms.llama_utils import ( | ||||||
| 
 | 
 | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from collections.abc import Sequence |     from collections.abc import Sequence | ||||||
|     from typing import Any |  | ||||||
| 
 | 
 | ||||||
|     from llama_index.callbacks import CallbackManager |     from llama_index.callbacks import CallbackManager | ||||||
|     from llama_index.llms import ( |     from llama_index.llms import ( | ||||||
|  |  | ||||||
|  | @ -36,7 +36,7 @@ class LLMComponent: | ||||||
|                 from private_gpt.components.llm.custom.sagemaker import SagemakerLLM |                 from private_gpt.components.llm.custom.sagemaker import SagemakerLLM | ||||||
| 
 | 
 | ||||||
|                 self.llm = SagemakerLLM( |                 self.llm = SagemakerLLM( | ||||||
|                     endpoint_name=settings.sagemaker.endpoint_name, |                     endpoint_name=settings.sagemaker.llm_endpoint_name, | ||||||
|                 ) |                 ) | ||||||
|             case "openai": |             case "openai": | ||||||
|                 from llama_index.llms import OpenAI |                 from llama_index.llms import OpenAI | ||||||
|  |  | ||||||
|  | @ -28,7 +28,8 @@ class LocalSettings(BaseModel): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SagemakerSettings(BaseModel): | class SagemakerSettings(BaseModel): | ||||||
|     endpoint_name: str |     llm_endpoint_name: str | ||||||
|  |     embedding_endpoint_name: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class OpenAISettings(BaseModel): | class OpenAISettings(BaseModel): | ||||||
|  |  | ||||||
|  | @ -11,7 +11,8 @@ local: | ||||||
|   embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5} |   embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5} | ||||||
| 
 | 
 | ||||||
| sagemaker: | sagemaker: | ||||||
|   endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:} |   llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:} | ||||||
|  |   embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:} | ||||||
| 
 | 
 | ||||||
| ui: | ui: | ||||||
|   enabled: true |   enabled: true | ||||||
|  |  | ||||||
|  | @ -0,0 +1,14 @@ | ||||||
# Settings profile for running private-gpt against AWS Sagemaker-deployed
# LLM and embedding endpoints (selected via `llm.mode: sagemaker`).
server:
  env_name: ${APP_ENV:prod}
  port: ${PORT:8001}

ui:
  enabled: true
  path: /

llm:
  # Routes both LLM and embedding components to the Sagemaker endpoints below.
  mode: sagemaker

sagemaker:
  # Names of the deployed Sagemaker inference endpoints to invoke.
  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|  | @ -18,7 +18,8 @@ local: | ||||||
|   embedding_hf_model_name: BAAI/bge-small-en-v1.5 |   embedding_hf_model_name: BAAI/bge-small-en-v1.5 | ||||||
| 
 | 
 | ||||||
| sagemaker: | sagemaker: | ||||||
|   endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140 |   llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140 | ||||||
|  |   embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479 | ||||||
| 
 | 
 | ||||||
| openai: | openai: | ||||||
|   api_key: ${OPENAI_API_KEY:} |   api_key: ${OPENAI_API_KEY:} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue