Feature/sagemaker embedding (#1161)

* Sagemaker deployed embedding model support

---------

Co-authored-by: Pablo Orgaz <pabloogc@gmail.com>
Iván Martínez 2023-11-05 16:16:49 +01:00 committed by GitHub
parent f29df84301
commit ad512e3c42
9 changed files with 114 additions and 7 deletions

private_gpt/components/embedding/custom/sagemaker.py Normal file

@@ -0,0 +1,82 @@
# mypy: ignore-errors
import json
from typing import Any

import boto3
from llama_index.embeddings.base import BaseEmbedding
from pydantic import Field, PrivateAttr


class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(description="")

    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        request_params = {
            "inputs": sentences,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_json = json.loads(response_str)

        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        return self._embed(texts)
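For reference (not part of the diff), the request/response contract that _embed assumes can be exercised directly with boto3. A minimal sketch, with a placeholder endpoint name: the deployed model must accept an {"inputs": [...]} JSON payload and answer with a {"vectors": [...]} array.

    import json

    import boto3

    # Placeholder endpoint name; credentials are resolved by boto3 as described
    # in the docstring above.
    client = boto3.client("sagemaker-runtime")
    resp = client.invoke_endpoint(
        EndpointName="my-embedding-endpoint",
        Body=json.dumps({"inputs": ["What is PrivateGPT?"]}),
        ContentType="application/json",
    )
    vectors = json.loads(resp["Body"].read().decode("utf-8"))["vectors"]
    print(len(vectors), len(vectors[0]))  # one embedding; dimension depends on the model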

private_gpt/components/embedding/embedding_component.py

@@ -13,13 +13,22 @@ class EmbeddingComponent:
     @inject
     def __init__(self) -> None:
         match settings.llm.mode:
-            case "local" | "sagemaker":
+            case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding

                 self.embedding_model = HuggingFaceEmbedding(
                     model_name=settings.local.embedding_hf_model_name,
                     cache_folder=str(models_cache_path),
                 )
+            case "sagemaker":
+                from private_gpt.components.embedding.custom.sagemaker import (
+                    SagemakerEmbedding,
+                )
+
+                self.embedding_model = SagemakerEmbedding(
+                    endpoint_name=settings.sagemaker.embedding_endpoint_name,
+                )
             case "openai":
                 from llama_index import OpenAIEmbedding

private_gpt/components/llm/custom/sagemaker.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import json
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -30,7 +30,6 @@ from llama_index.llms.llama_utils import (
 if TYPE_CHECKING:
     from collections.abc import Sequence
-    from typing import Any

     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (

private_gpt/components/llm/llm_component.py

@@ -36,7 +36,7 @@ class LLMComponent:
             from private_gpt.components.llm.custom.sagemaker import SagemakerLLM

             self.llm = SagemakerLLM(
-                endpoint_name=settings.sagemaker.endpoint_name,
+                endpoint_name=settings.sagemaker.llm_endpoint_name,
             )
         case "openai":
             from llama_index.llms import OpenAI

private_gpt/settings/settings.py

@@ -28,7 +28,8 @@ class LocalSettings(BaseModel):
 class SagemakerSettings(BaseModel):
-    endpoint_name: str
+    llm_endpoint_name: str
+    embedding_endpoint_name: str


 class OpenAISettings(BaseModel):

settings.yaml

@@ -11,7 +11,8 @@ local:
   embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}

 sagemaker:
-  endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:}
+  llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
+  embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}

 ui:
   enabled: true
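The ${NAME:default} values above are resolved from environment variables when settings are loaded, with the text after the colon as the fallback (empty here). A minimal sketch of that expansion rule, assuming the settings loader behaves equivalently to this helper:

    import os
    import re

    _ENV_VAR = re.compile(r"\$\{(\w+):([^}]*)\}")

    def expand_env(value: str) -> str:
        # Replace ${NAME:default} with os.environ["NAME"], or the default if unset.
        return _ENV_VAR.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

    print(expand_env("${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}"))  # endpoint name, or "" if unset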

settings-sagemaker.yaml Normal file

@@ -0,0 +1,14 @@
server:
  env_name: ${APP_ENV:prod}
  port: ${PORT:8001}

ui:
  enabled: true
  path: /

llm:
  mode: sagemaker

sagemaker:
  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
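With this profile file in place, the Sagemaker setup can be selected at startup. Assuming the project's profile mechanism, which overlays settings-<profile>.yaml on top of settings.yaml based on the PGPT_PROFILES environment variable, the server would be started with something like:

    PGPT_PROFILES=sagemaker poetry run python -m private_gpt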

@@ -18,7 +18,8 @@ local:
   embedding_hf_model_name: BAAI/bge-small-en-v1.5

 sagemaker:
-  endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
+  llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
+  embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479

 openai:
   api_key: ${OPENAI_API_KEY:}