fix: sagemaker config and chat methods (#1142)

This commit is contained in:
Pablo Orgaz 2023-10-30 21:54:41 +01:00 committed by GitHub
parent b0e258265f
commit a517a588c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import io
import json
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import boto3 # type: ignore
from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@ from llama_index.llms import (
CustomLLM,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.llms.base import (
llm_chat_callback,
llm_completion_callback,
)
from llama_index.llms.generic_utils import (
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.llms.llama_utils import (
completion_to_prompt as generic_completion_to_prompt,
)
@@ -22,8 +29,14 @@ from llama_index.llms.llama_utils import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponseGen,
)
@@ -247,3 +260,17 @@ class SagemakerLLM(CustomLLM):
yield CompletionResponse(delta=delta, text=text, raw=data)
return get_stream()
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
    """Run one synchronous chat turn by delegating to the completion path.

    The message history is rendered into a single prompt string and the
    resulting completion is adapted back into a chat response.
    """
    # `formatted=True` signals that the prompt already carries the model's
    # chat template, so complete() must not re-apply prompt formatting.
    rendered = self.messages_to_prompt(messages)
    return completion_response_to_chat_response(
        self.complete(rendered, formatted=True, **kwargs)
    )
@llm_chat_callback()
def stream_chat(
    self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
    """Stream a chat response by delegating to the streaming completion path.

    Yields chat-response deltas adapted from the underlying completion
    stream.
    """
    # Same template contract as chat(): the rendered prompt is already
    # formatted, so stream_complete() receives formatted=True.
    rendered = self.messages_to_prompt(messages)
    return stream_completion_response_to_chat_response(
        self.stream_complete(rendered, formatted=True, **kwargs)
    )

View File

@@ -37,8 +37,6 @@ class LLMComponent:
self.llm = SagemakerLLM(
endpoint_name=settings.sagemaker.endpoint_name,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
)
case "openai":
from llama_index.llms import OpenAI