diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py index 77e8c3d..1f05223 100644 --- a/private_gpt/components/embedding/embedding_component.py +++ b/private_gpt/components/embedding/embedding_component.py @@ -55,8 +55,17 @@ class EmbeddingComponent: "OpenAI dependencies not found, install with `poetry install --extras embeddings-openai`" ) from e - openai_settings = settings.openai.api_key - self.embedding_model = OpenAIEmbedding(api_key=openai_settings) + api_base = ( + settings.openai.embedding_api_base or settings.openai.api_base + ) + api_key = settings.openai.embedding_api_key or settings.openai.api_key + model = settings.openai.embedding_model + + self.embedding_model = OpenAIEmbedding( + api_base=api_base, + api_key=api_key, + model=model, + ) case "ollama": try: from llama_index.embeddings.ollama import ( # type: ignore diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index bd83fb8..28ece45 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -209,6 +209,15 @@ class OpenAISettings(BaseModel): 120.0, description="Time elapsed until openailike server times out the request. Default is 120s. Format is float. ", ) + embedding_api_base: str = Field( + None, + description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.", + ) + embedding_api_key: str + embedding_model: str = Field( + "text-embedding-ada-002", + description="OpenAI embedding Model to use. Example: 'text-embedding-3-large'.", + ) class OllamaSettings(BaseModel): diff --git a/settings.yaml b/settings.yaml index 06fcd63..b524ef6 100644 --- a/settings.yaml +++ b/settings.yaml @@ -95,6 +95,7 @@ sagemaker: openai: api_key: ${OPENAI_API_KEY:} model: gpt-3.5-turbo + embedding_api_key: ${OPENAI_API_KEY:} ollama: llm_model: llama2