Feature/sagemaker embedding (#1161)
* Sagemaker deployed embedding model support --------- Co-authored-by: Pablo Orgaz <pabloogc@gmail.com>
This commit is contained in:
parent
f29df84301
commit
ad512e3c42
|
@ -0,0 +1,82 @@
|
|||
# mypy: ignore-errors
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
|
||||
class SagemakerEmbedding(BaseEmbedding):
|
||||
"""Sagemaker Embedding Endpoint.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
Sagemaker embedding model & the region where it is deployed.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Sagemaker endpoint.
|
||||
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
|
||||
"""
|
||||
|
||||
endpoint_name: str = Field(description="")
|
||||
|
||||
_boto_client: Any = boto3.client(
|
||||
"sagemaker-runtime",
|
||||
) # TODO make it an optional field
|
||||
|
||||
_async_not_implemented_warned: bool = PrivateAttr(default=False)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "SagemakerEmbedding"
|
||||
|
||||
def _async_not_implemented_warn_once(self) -> None:
|
||||
if not self._async_not_implemented_warned:
|
||||
print("Async embedding not available, falling back to sync method.")
|
||||
self._async_not_implemented_warned = True
|
||||
|
||||
def _embed(self, sentences: list[str]) -> list[list[float]]:
|
||||
request_params = {
|
||||
"inputs": sentences,
|
||||
}
|
||||
|
||||
resp = self._boto_client.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=json.dumps(request_params),
|
||||
ContentType="application/json",
|
||||
)
|
||||
|
||||
response_body = resp["Body"]
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
response_json = json.loads(response_str)
|
||||
|
||||
return response_json["vectors"]
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
"""Get query embedding."""
|
||||
return self._embed([query])[0]
|
||||
|
||||
async def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
# Warn the user that sync is being used
|
||||
self._async_not_implemented_warn_once()
|
||||
return self._get_query_embedding(query)
|
||||
|
||||
async def _aget_text_embedding(self, text: str) -> list[float]:
|
||||
# Warn the user that sync is being used
|
||||
self._async_not_implemented_warn_once()
|
||||
return self._get_text_embedding(text)
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
"""Get text embedding."""
|
||||
return self._embed([text])[0]
|
||||
|
||||
def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Get text embeddings."""
|
||||
return self._embed(texts)
|
|
@ -13,13 +13,22 @@ class EmbeddingComponent:
|
|||
@inject
|
||||
def __init__(self) -> None:
|
||||
match settings.llm.mode:
|
||||
case "local" | "sagemaker":
|
||||
case "local":
|
||||
from llama_index.embeddings import HuggingFaceEmbedding
|
||||
|
||||
self.embedding_model = HuggingFaceEmbedding(
|
||||
model_name=settings.local.embedding_hf_model_name,
|
||||
cache_folder=str(models_cache_path),
|
||||
)
|
||||
case "sagemaker":
|
||||
|
||||
from private_gpt.components.embedding.custom.sagemaker import (
|
||||
SagemakerEmbedding,
|
||||
)
|
||||
|
||||
self.embedding_model = SagemakerEmbedding(
|
||||
endpoint_name=settings.sagemaker.embedding_endpoint_name,
|
||||
)
|
||||
case "openai":
|
||||
from llama_index import OpenAIEmbedding
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import boto3 # type: ignore
|
||||
from llama_index.bridge.pydantic import Field
|
||||
|
@ -30,7 +30,6 @@ from llama_index.llms.llama_utils import (
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from llama_index.callbacks import CallbackManager
|
||||
from llama_index.llms import (
|
||||
|
|
|
@ -36,7 +36,7 @@ class LLMComponent:
|
|||
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
|
||||
|
||||
self.llm = SagemakerLLM(
|
||||
endpoint_name=settings.sagemaker.endpoint_name,
|
||||
endpoint_name=settings.sagemaker.llm_endpoint_name,
|
||||
)
|
||||
case "openai":
|
||||
from llama_index.llms import OpenAI
|
||||
|
|
|
@ -28,7 +28,8 @@ class LocalSettings(BaseModel):
|
|||
|
||||
|
||||
class SagemakerSettings(BaseModel):
|
||||
endpoint_name: str
|
||||
llm_endpoint_name: str
|
||||
embedding_endpoint_name: str
|
||||
|
||||
|
||||
class OpenAISettings(BaseModel):
|
||||
|
|
|
@ -11,7 +11,8 @@ local:
|
|||
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
|
||||
|
||||
sagemaker:
|
||||
endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:}
|
||||
llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
|
||||
embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
|
||||
|
||||
ui:
|
||||
enabled: true
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
server:
|
||||
env_name: ${APP_ENV:prod}
|
||||
port: ${PORT:8001}
|
||||
|
||||
ui:
|
||||
enabled: true
|
||||
path: /
|
||||
|
||||
llm:
|
||||
mode: sagemaker
|
||||
|
||||
sagemaker:
|
||||
llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
||||
embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|
|
@ -18,7 +18,8 @@ local:
|
|||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
|
||||
sagemaker:
|
||||
endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
||||
llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
||||
embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|
||||
|
||||
openai:
|
||||
api_key: ${OPENAI_API_KEY:}
|
||||
|
|
Loading…
Reference in New Issue