fix: sagemaker config and chat methods (#1142)
This commit is contained in:
		
							parent
							
								
									b0e258265f
								
							
						
					
					
						commit
						a517a588c4
					
				|  | @ -4,7 +4,7 @@ from __future__ import annotations | ||||||
| import io | import io | ||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| from typing import TYPE_CHECKING, Any | from typing import TYPE_CHECKING | ||||||
| 
 | 
 | ||||||
| import boto3  # type: ignore | import boto3  # type: ignore | ||||||
| from llama_index.bridge.pydantic import Field | from llama_index.bridge.pydantic import Field | ||||||
|  | @ -13,7 +13,14 @@ from llama_index.llms import ( | ||||||
|     CustomLLM, |     CustomLLM, | ||||||
|     LLMMetadata, |     LLMMetadata, | ||||||
| ) | ) | ||||||
| from llama_index.llms.base import llm_completion_callback | from llama_index.llms.base import ( | ||||||
|  |     llm_chat_callback, | ||||||
|  |     llm_completion_callback, | ||||||
|  | ) | ||||||
|  | from llama_index.llms.generic_utils import ( | ||||||
|  |     completion_response_to_chat_response, | ||||||
|  |     stream_completion_response_to_chat_response, | ||||||
|  | ) | ||||||
| from llama_index.llms.llama_utils import ( | from llama_index.llms.llama_utils import ( | ||||||
|     completion_to_prompt as generic_completion_to_prompt, |     completion_to_prompt as generic_completion_to_prompt, | ||||||
| ) | ) | ||||||
|  | @ -22,8 +29,14 @@ from llama_index.llms.llama_utils import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  |     from collections.abc import Sequence | ||||||
|  |     from typing import Any | ||||||
|  | 
 | ||||||
|     from llama_index.callbacks import CallbackManager |     from llama_index.callbacks import CallbackManager | ||||||
|     from llama_index.llms import ( |     from llama_index.llms import ( | ||||||
|  |         ChatMessage, | ||||||
|  |         ChatResponse, | ||||||
|  |         ChatResponseGen, | ||||||
|         CompletionResponseGen, |         CompletionResponseGen, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | @ -247,3 +260,17 @@ class SagemakerLLM(CustomLLM): | ||||||
|                         yield CompletionResponse(delta=delta, text=text, raw=data) |                         yield CompletionResponse(delta=delta, text=text, raw=data) | ||||||
| 
 | 
 | ||||||
|         return get_stream() |         return get_stream() | ||||||
|  | 
 | ||||||
|  |     @llm_chat_callback() | ||||||
|  |     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: | ||||||
|  |         prompt = self.messages_to_prompt(messages) | ||||||
|  |         completion_response = self.complete(prompt, formatted=True, **kwargs) | ||||||
|  |         return completion_response_to_chat_response(completion_response) | ||||||
|  | 
 | ||||||
|  |     @llm_chat_callback() | ||||||
|  |     def stream_chat( | ||||||
|  |         self, messages: Sequence[ChatMessage], **kwargs: Any | ||||||
|  |     ) -> ChatResponseGen: | ||||||
|  |         prompt = self.messages_to_prompt(messages) | ||||||
|  |         completion_response = self.stream_complete(prompt, formatted=True, **kwargs) | ||||||
|  |         return stream_completion_response_to_chat_response(completion_response) | ||||||
|  |  | ||||||
|  | @ -37,8 +37,6 @@ class LLMComponent: | ||||||
| 
 | 
 | ||||||
|                 self.llm = SagemakerLLM( |                 self.llm = SagemakerLLM( | ||||||
|                     endpoint_name=settings.sagemaker.endpoint_name, |                     endpoint_name=settings.sagemaker.endpoint_name, | ||||||
|                     messages_to_prompt=messages_to_prompt, |  | ||||||
|                     completion_to_prompt=completion_to_prompt, |  | ||||||
|                 ) |                 ) | ||||||
|             case "openai": |             case "openai": | ||||||
|                 from llama_index.llms import OpenAI |                 from llama_index.llms import OpenAI | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue