diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py index 41b4e1f..61fe6fa 100644 --- a/private_gpt/components/embedding/embedding_component.py +++ b/private_gpt/components/embedding/embedding_component.py @@ -13,14 +13,19 @@ class EmbeddingComponent: @inject def __init__(self) -> None: match settings.llm.mode: - case "mock": - # Not a random number, is the dimensionality used by - # the default embedding model - self.embedding_model = MockEmbedding(384) - case _: + case "local": from llama_index.embeddings import HuggingFaceEmbedding self.embedding_model = HuggingFaceEmbedding( model_name=settings.local.embedding_hf_model_name, cache_folder=str(models_cache_path), ) + case "openai": + from llama_index import OpenAIEmbedding + + openai_settings = settings.openai.api_key + self.embedding_model = OpenAIEmbedding(api_key=openai_settings) + case "mock": + # Not a random number, is the dimensionality used by + # the default embedding model + self.embedding_model = MockEmbedding(384)