fix: sagemaker config and chat methods (#1142)
parent b0e258265f
commit a517a588c4
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import json
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@ from llama_index.llms import (
     CustomLLM,
     LLMMetadata,
 )
-from llama_index.llms.base import llm_completion_callback
+from llama_index.llms.base import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
 from llama_index.llms.llama_utils import (
     completion_to_prompt as generic_completion_to_prompt,
 )
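
The two generic_utils helpers added here are what let a completion-only backend expose a chat interface: they wrap completion output into an assistant chat message. The snippet below is a rough, assumed approximation of that conversion for the non-streaming case (to_chat_response_sketch and the MessageRole usage are illustrative, not the library source):

    # Assumed approximation of completion_response_to_chat_response, for illustration only.
    from llama_index.llms import ChatMessage, ChatResponse, CompletionResponse, MessageRole

    def to_chat_response_sketch(completion: CompletionResponse) -> ChatResponse:
        # The completion text becomes the content of a single assistant message.
        return ChatResponse(
            message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text),
            raw=completion.raw,
        )
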
@@ -22,8 +29,14 @@ from llama_index.llms.llama_utils import (
 )

 if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import Any
+
     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (
+        ChatMessage,
+        ChatResponse,
+        ChatResponseGen,
         CompletionResponseGen,
     )

@@ -247,3 +260,17 @@ class SagemakerLLM(CustomLLM):
                 yield CompletionResponse(delta=delta, text=text, raw=data)

         return get_stream()
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        return completion_response_to_chat_response(completion_response)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        return stream_completion_response_to_chat_response(completion_response)
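
The new chat and stream_chat methods reuse the existing completion path: messages_to_prompt flattens the conversation into one prompt, complete or stream_complete runs it against the endpoint, and the generic_utils helpers turn the result back into chat responses. A hedged usage sketch, where the endpoint name and example messages are placeholders rather than part of this commit:

    # Usage sketch for the new chat API; requires a real Sagemaker endpoint and AWS credentials.
    from llama_index.llms import ChatMessage, MessageRole

    llm = SagemakerLLM(endpoint_name="my-llm-endpoint")  # hypothetical endpoint name

    # chat(): messages -> prompt -> complete() -> ChatResponse
    response = llm.chat([ChatMessage(role=MessageRole.USER, content="Who are you?")])
    print(response.message.content)

    # stream_chat(): same path, but yields incremental ChatResponse deltas
    for chunk in llm.stream_chat([ChatMessage(role=MessageRole.USER, content="Tell me more")]):
        print(chunk.delta, end="", flush=True)

The remaining hunk, against LLMComponent, drops the explicit formatter arguments from the constructor call; presumably the LLM now falls back to the llama_utils defaults imported above.
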
@@ -37,8 +37,6 @@ class LLMComponent:

                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.endpoint_name,
-                    messages_to_prompt=messages_to_prompt,
-                    completion_to_prompt=completion_to_prompt,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
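
After this change the component only passes the endpoint name, so message and completion formatting is left to SagemakerLLM itself. A small sanity check built on that assumption (the endpoint name is a placeholder, and the [INST] markers reflect llama_utils' usual Llama-2 prompt format):

    # Hedged check that the LLM's own default formatter still produces a
    # Llama-style prompt when the component passes no formatters.
    from llama_index.llms import ChatMessage, MessageRole

    llm = SagemakerLLM(endpoint_name="my-llm-endpoint")  # hypothetical endpoint name
    prompt = llm.messages_to_prompt([ChatMessage(role=MessageRole.USER, content="Hi")])
    print(prompt)  # expected to contain llama_utils' [INST] ... [/INST] wrapping
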