Feature/sagemaker embedding (#1161)
* Sagemaker deployed embedding model support --------- Co-authored-by: Pablo Orgaz <pabloogc@gmail.com>
This commit is contained in:
parent
f29df84301
commit
ad512e3c42
|
@ -0,0 +1,82 @@
|
||||||
|
# mypy: ignore-errors
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from llama_index.embeddings.base import BaseEmbedding
|
||||||
|
from pydantic import Field, PrivateAttr
|
||||||
|
|
||||||
|
|
||||||
|
class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    # Fix: the original Field had an empty description.
    endpoint_name: str = Field(
        description="Name of the deployed Sagemaker embedding endpoint to invoke."
    )

    # NOTE(review): created once at class-definition time and shared by all
    # instances; region/credentials come from boto3's default resolution chain.
    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    # Tracks whether the "async not implemented" warning was already emitted,
    # so it is printed at most once per instance.
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used for (de)serialization."""
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Print the sync-fallback warning, at most once per instance."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        """Invoke the Sagemaker endpoint and return one vector per sentence.

        The endpoint receives ``{"inputs": [...]}`` as JSON and is expected to
        respond with ``{"vectors": [[...], ...]}`` — TODO confirm against the
        deployed model's contract.
        """
        # Robustness: avoid a pointless endpoint round-trip for empty input.
        if not sentences:
            return []

        request_params = {
            "inputs": sentences,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        response_json = json.loads(response_str)

        return response_json["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Async query embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Async text embedding; delegates to the sync implementation."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        return self._embed(texts)
|
|
@ -13,13 +13,22 @@ class EmbeddingComponent:
|
||||||
@inject
|
@inject
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
match settings.llm.mode:
|
match settings.llm.mode:
|
||||||
case "local" | "sagemaker":
|
case "local":
|
||||||
from llama_index.embeddings import HuggingFaceEmbedding
|
from llama_index.embeddings import HuggingFaceEmbedding
|
||||||
|
|
||||||
self.embedding_model = HuggingFaceEmbedding(
|
self.embedding_model = HuggingFaceEmbedding(
|
||||||
model_name=settings.local.embedding_hf_model_name,
|
model_name=settings.local.embedding_hf_model_name,
|
||||||
cache_folder=str(models_cache_path),
|
cache_folder=str(models_cache_path),
|
||||||
)
|
)
|
||||||
|
case "sagemaker":
|
||||||
|
|
||||||
|
from private_gpt.components.embedding.custom.sagemaker import (
|
||||||
|
SagemakerEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedding_model = SagemakerEmbedding(
|
||||||
|
endpoint_name=settings.sagemaker.embedding_endpoint_name,
|
||||||
|
)
|
||||||
case "openai":
|
case "openai":
|
||||||
from llama_index import OpenAIEmbedding
|
from llama_index import OpenAIEmbedding
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import boto3 # type: ignore
|
import boto3 # type: ignore
|
||||||
from llama_index.bridge.pydantic import Field
|
from llama_index.bridge.pydantic import Field
|
||||||
|
@ -30,7 +30,6 @@ from llama_index.llms.llama_utils import (
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_index.callbacks import CallbackManager
|
from llama_index.callbacks import CallbackManager
|
||||||
from llama_index.llms import (
|
from llama_index.llms import (
|
||||||
|
|
|
@ -36,7 +36,7 @@ class LLMComponent:
|
||||||
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
|
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
|
||||||
|
|
||||||
self.llm = SagemakerLLM(
|
self.llm = SagemakerLLM(
|
||||||
endpoint_name=settings.sagemaker.endpoint_name,
|
endpoint_name=settings.sagemaker.llm_endpoint_name,
|
||||||
)
|
)
|
||||||
case "openai":
|
case "openai":
|
||||||
from llama_index.llms import OpenAI
|
from llama_index.llms import OpenAI
|
||||||
|
|
|
@ -28,7 +28,8 @@ class LocalSettings(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SagemakerSettings(BaseModel):
|
class SagemakerSettings(BaseModel):
|
||||||
endpoint_name: str
|
llm_endpoint_name: str
|
||||||
|
embedding_endpoint_name: str
|
||||||
|
|
||||||
|
|
||||||
class OpenAISettings(BaseModel):
|
class OpenAISettings(BaseModel):
|
||||||
|
|
|
@ -11,7 +11,8 @@ local:
|
||||||
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
|
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
|
||||||
|
|
||||||
sagemaker:
|
sagemaker:
|
||||||
endpoint_name: ${PGPT_SAGEMAKER_ENDPOINT_NAME:}
|
llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
|
||||||
|
embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
|
||||||
|
|
||||||
ui:
|
ui:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
server:
|
||||||
|
env_name: ${APP_ENV:prod}
|
||||||
|
port: ${PORT:8001}
|
||||||
|
|
||||||
|
ui:
|
||||||
|
enabled: true
|
||||||
|
path: /
|
||||||
|
|
||||||
|
llm:
|
||||||
|
mode: sagemaker
|
||||||
|
|
||||||
|
sagemaker:
|
||||||
|
llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
||||||
|
embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|
|
@ -18,7 +18,8 @@ local:
|
||||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||||
|
|
||||||
sagemaker:
|
sagemaker:
|
||||||
endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
|
||||||
|
embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
|
||||||
|
|
||||||
openai:
|
openai:
|
||||||
api_key: ${OPENAI_API_KEY:}
|
api_key: ${OPENAI_API_KEY:}
|
||||||
|
|
Loading…
Reference in New Issue