fix: sagemaker config and chat methods (#1142)

parent b0e258265f
commit a517a588c4

@@ -4,7 +4,7 @@ from __future__ import annotations
import io
import json
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import boto3  # type: ignore
from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@ from llama_index.llms import (
    CustomLLM,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.llama_utils import (
    completion_to_prompt as generic_completion_to_prompt,
)
@@ -22,8 +29,14 @@ from llama_index.llms.llama_utils import (
)

if TYPE_CHECKING:
    from collections.abc import Sequence
    from typing import Any

    from llama_index.callbacks import CallbackManager
    from llama_index.llms import (
        ChatMessage,
        ChatResponse,
        ChatResponseGen,
        CompletionResponseGen,
    )

@@ -247,3 +260,17 @@ class SagemakerLLM(CustomLLM):
                        yield CompletionResponse(delta=delta, text=text, raw=data)

        return get_stream()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)
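
For context, the new methods simply adapt the existing completion path to the chat interface. Below is a minimal usage sketch (not part of the diff); the endpoint name and messages are hypothetical, and actually running it requires a deployed SageMaker endpoint plus AWS credentials.

# Illustrative only: exercising the chat methods added above.
# "my-llm-endpoint" is a placeholder for a real SageMaker endpoint name.
from llama_index.llms import ChatMessage

llm = SagemakerLLM(endpoint_name="my-llm-endpoint")

# chat() turns the message list into a prompt via self.messages_to_prompt()
# and delegates to complete(), returning a ChatResponse.
response = llm.chat([ChatMessage(role="user", content="What can you do?")])
print(response.message.content)

# stream_chat() wraps stream_complete() and yields ChatResponse chunks with a .delta.
for chunk in llm.stream_chat([ChatMessage(role="user", content="Summarise that.")]):
    print(chunk.delta, end="")
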
@@ -37,8 +37,6 @@ class LLMComponent:

                self.llm = SagemakerLLM(
                    endpoint_name=settings.sagemaker.endpoint_name,
                    messages_to_prompt=messages_to_prompt,
                    completion_to_prompt=completion_to_prompt,
                )
            case "openai":
                from llama_index.llms import OpenAI
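
On the component side, only the endpoint name is passed now; prompt formatting is handled inside SagemakerLLM itself, presumably via the llama_utils-based defaults imported at the top of the first file. A small sketch, again with a placeholder endpoint name:

# Sketch of the simplified wiring: the component no longer supplies
# messages_to_prompt / completion_to_prompt.
from llama_index.llms import ChatMessage

llm = SagemakerLLM(endpoint_name="my-llm-endpoint")  # placeholder name

# The LLM's own messages_to_prompt() is what chat()/stream_chat() use internally.
prompt = llm.messages_to_prompt([ChatMessage(role="user", content="Hi")])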