diff --git a/private_gpt/components/embedding/custom/__init__.py b/private_gpt/components/embedding/custom/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/private_gpt/components/embedding/custom/sagemaker.py b/private_gpt/components/embedding/custom/sagemaker.py
new file mode 100644
index 0000000..a1dddd0
--- /dev/null
+++ b/private_gpt/components/embedding/custom/sagemaker.py
@@ -0,0 +1,82 @@
+# mypy: ignore-errors
+import json
+from typing import Any
+
+import boto3
+from llama_index.embeddings.base import BaseEmbedding
+from pydantic import Field, PrivateAttr
+
+
+class SagemakerEmbedding(BaseEmbedding):
+    """Sagemaker Embedding Endpoint.
+
+    To use, you must supply the endpoint name from your deployed
+    Sagemaker embedding model & the region where it is deployed.
+
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Sagemaker endpoint.
+    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
+    """
+
+    endpoint_name: str = Field(description="")
+
+    _boto_client: Any = boto3.client(
+        "sagemaker-runtime",
+    )  # TODO make it an optional field
+
+    _async_not_implemented_warned: bool = PrivateAttr(default=False)
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "SagemakerEmbedding"
+
+    def _async_not_implemented_warn_once(self) -> None:
+        if not self._async_not_implemented_warned:
+            print("Async embedding not available, falling back to sync method.")
+            self._async_not_implemented_warned = True
+
+    def _embed(self, sentences: list[str]) -> list[list[float]]:
+        request_params = {
+            "inputs": sentences,
+        }
+
+        resp = self._boto_client.invoke_endpoint(
+            EndpointName=self.endpoint_name,
+            Body=json.dumps(request_params),
+            ContentType="application/json",
+        )
+
+        response_body = resp["Body"]
+        response_str = response_body.read().decode("utf-8")
+        response_json = json.loads(response_str)
+
+        return response_json["vectors"]
+
+    def _get_query_embedding(self, query: str) -> list[float]:
+        """Get query embedding."""
+        return self._embed([query])[0]
+
+    async def _aget_query_embedding(self, query: str) -> list[float]:
+        # Warn the user that sync is being used
+        self._async_not_implemented_warn_once()
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> list[float]:
+        # Warn the user that sync is being used
+        self._async_not_implemented_warn_once()
+        return self._get_text_embedding(text)
+
+    def _get_text_embedding(self, text: str) -> list[float]:
+        """Get text embedding."""
+        return self._embed([text])[0]
+
+    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
+        """Get text embeddings."""
+        return self._embed(texts)
diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py
index 079e72b..f71be0a 100644
--- a/private_gpt/components/embedding/embedding_component.py
+++ b/private_gpt/components/embedding/embedding_component.py
@@ -13,13 +13,22 @@ class EmbeddingComponent:
     @inject
     def __init__(self) -> None:
         match settings.llm.mode:
-            case "local" | "sagemaker":
+            case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding
 
                 self.embedding_model = HuggingFaceEmbedding(
                     model_name=settings.local.embedding_hf_model_name,
                     cache_folder=str(models_cache_path),
                 )
+            case "sagemaker":
+
+                from private_gpt.components.embedding.custom.sagemaker import (
+                    SagemakerEmbedding,
+                )
+
+                self.embedding_model = SagemakerEmbedding(
+                    endpoint_name=settings.sagemaker.embedding_endpoint_name,
+                )
             case "openai":
                 from llama_index import OpenAIEmbedding
 
diff --git a/private_gpt/components/llm/custom/sagemaker.py b/private_gpt/components/llm/custom/sagemaker.py
index 284ee2c..2eedb1d 100644
--- a/private_gpt/components/llm/custom/sagemaker.py
+++ b/private_gpt/components/llm/custom/sagemaker.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import json
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -30,7 +30,6 @@ from llama_index.llms.llama_utils import (
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
-    from typing import Any
 
     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 2c32897..cad6ed6 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -36,7 +36,7 @@ class LLMComponent:
                 from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
 
                 self.llm = SagemakerLLM(
-                    endpoint_name=settings.sagemaker.endpoint_name,
+                    endpoint_name=settings.sagemaker.llm_endpoint_name,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 73e5c92..16b1d4b 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -28,7 +28,8 @@ class LocalSettings(BaseModel):
 
 
 class SagemakerSettings(BaseModel):
-    endpoint_name: str
+    llm_endpoint_name: str
+    embedding_endpoint_name: str
 
 
 class OpenAISettings(BaseModel):
diff --git a/settings-docker.yaml b/settings-docker.yaml
index 12fea9f..49b3961 100644
--- a/settings-docker.yaml
+++ b/settings-docker.yaml
@@ -11,7 +11,8 @@ local:
   embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
 
 sagemaker:
-  endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:}
+  llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
+  embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
 
 ui:
   enabled: true
diff --git a/settings-sagemaker.yaml b/settings-sagemaker.yaml
new file mode 100644
index 0000000..774b8cb
--- /dev/null
+++ b/settings-sagemaker.yaml
@@ -0,0 +1,14 @@
+server:
+  env_name: ${APP_ENV:prod}
+  port: ${PORT:8001}
+
+ui:
+  enabled: true
+  path: /
+
+llm:
+  mode: sagemaker
+
+sagemaker:
+  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
+  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
\ No newline at end of file
diff --git a/settings.yaml b/settings.yaml
index 5cdb556..fba278d 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -18,7 +18,8 @@ local:
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
 
 sagemaker:
-  endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
+  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
+  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
 
 openai:
   api_key: ${OPENAI_API_KEY:}
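For reviewers who want to smoke-test the new component outside the dependency-injection wiring, a minimal sketch follows. The endpoint name is a hypothetical placeholder (not one of the endpoints from this patch); it assumes boto3 can resolve AWS credentials and a region on its own, and it uses llama_index's public BaseEmbedding wrappers (get_query_embedding, get_text_embedding_batch), which delegate to the private _get_* hooks implemented above:

# A minimal smoke test for the new SagemakerEmbedding, run outside the DI
# container. Assumptions: boto3 can resolve credentials and a region, and
# "my-embedding-endpoint" is a placeholder for a real deployed endpoint
# that answers {"inputs": [...]} with a JSON body {"vectors": [[...], ...]}.
from private_gpt.components.embedding.custom.sagemaker import SagemakerEmbedding

embedding_model = SagemakerEmbedding(
    endpoint_name="my-embedding-endpoint",  # placeholder, not from this PR
)

# Both public wrappers route through _embed(), so each call issues one
# invoke_endpoint request against the configured SageMaker endpoint.
query_vector = embedding_model.get_query_embedding("What is PrivateGPT?")
doc_vectors = embedding_model.get_text_embedding_batch(["first doc", "second doc"])

print(f"query dims: {len(query_vector)}, docs embedded: {len(doc_vectors)}")

Note that async callers (the _aget_* methods) fall back to this same synchronous path, emitting the one-time warning defined in _async_not_implemented_warn_once.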