fix: sagemaker config and chat methods (#1142)
parent b0e258265f
commit a517a588c4
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import json
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@ from llama_index.llms import (
     CustomLLM,
     LLMMetadata,
 )
-from llama_index.llms.base import llm_completion_callback
+from llama_index.llms.base import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
 from llama_index.llms.llama_utils import (
     completion_to_prompt as generic_completion_to_prompt,
 )
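The two generic_utils adapters imported above convert completion-style responses into chat responses. Below is a rough sketch of the assumed behavior; the real implementations live in llama_index.llms.generic_utils, and the _sketch names are illustrative only.

# Rough sketch of what the two adapters are assumed to do; not copied
# from llama_index.llms.generic_utils.
from llama_index.llms import ChatMessage, ChatResponse, CompletionResponse


def to_chat_response_sketch(response: CompletionResponse) -> ChatResponse:
    # Wrap the completion text in an assistant message.
    return ChatResponse(
        message=ChatMessage(role="assistant", content=response.text),
        raw=response.raw,
    )


def to_chat_response_gen_sketch(gen):
    # Re-wrap each streamed CompletionResponse as a ChatResponse,
    # preserving the incremental delta.
    for response in gen:
        yield ChatResponse(
            message=ChatMessage(role="assistant", content=response.text),
            delta=response.delta,
            raw=response.raw,
        )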
@@ -22,8 +29,14 @@ from llama_index.llms.llama_utils import (
 )
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
+    from typing import Any
 
     from llama_index.callbacks import CallbackManager
-    from llama_index.llms import CompletionResponseGen
+    from llama_index.llms import (
+        ChatMessage,
+        ChatResponse,
+        ChatResponseGen,
+        CompletionResponseGen,
+    )
 
@@ -247,3 +260,17 @@ class SagemakerLLM(CustomLLM):
                 yield CompletionResponse(delta=delta, text=text, raw=data)
 
         return get_stream()
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        return completion_response_to_chat_response(completion_response)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        return stream_completion_response_to_chat_response(completion_response)
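A hypothetical usage sketch of the new chat methods (the endpoint name and message content are placeholders, not values from this commit): chat() formats the messages via messages_to_prompt, delegates to complete(), and adapts the result into a ChatResponse; stream_chat() runs the same pipeline through stream_complete(), yielding incremental deltas.

# Hypothetical usage; endpoint name and messages are placeholders.
from llama_index.llms import ChatMessage

llm = SagemakerLLM(endpoint_name="my-endpoint")  # placeholder endpoint

# chat(): messages -> prompt -> complete() -> ChatResponse.
response = llm.chat([ChatMessage(role="user", content="Hello!")])
print(response.message.content)

# stream_chat(): same pipeline through stream_complete(), yielding deltas.
for chunk in llm.stream_chat([ChatMessage(role="user", content="Hello!")]):
    print(chunk.delta, end="")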
@@ -37,8 +37,6 @@ class LLMComponent:
 
                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.endpoint_name,
-                    messages_to_prompt=messages_to_prompt,
-                    completion_to_prompt=completion_to_prompt,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
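The settings.sagemaker.endpoint_name read above implies a settings model shaped roughly as follows. This is a hypothetical pydantic-style sketch inferred from the diff, not the project's actual settings module.

# Hypothetical settings shape inferred from settings.sagemaker.endpoint_name;
# the real project's settings module may differ.
from pydantic import BaseModel


class SagemakerSettings(BaseModel):
    endpoint_name: str  # name of the deployed SageMaker inference endpoint


class Settings(BaseModel):
    sagemaker: SagemakerSettings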