fix: Remove global state (#1216)

* Remove all global settings state

* chore: remove autogenerated class

* chore: cleanup

* chore: merge conflicts
Pablo Orgaz 2023-11-12 22:20:36 +01:00 committed by GitHub
parent f394ca61bb
commit 022bd718e3
24 changed files with 286 additions and 190 deletions
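
In practice, the change swaps module-level `settings` imports for constructor injection and per-request service resolution. A minimal sketch of the component-side pattern, condensed from the `EmbeddingComponent` and `LLMComponent` hunks below (illustrative only, not part of this diff; the class name is made up):

from injector import inject, singleton

from private_gpt.settings.settings import Settings


@singleton
class ExampleComponent:  # hypothetical component mirroring LLMComponent/EmbeddingComponent
    @inject
    def __init__(self, settings: Settings) -> None:
        # Settings arrive through the injector instead of a global import,
        # so tests can bind an overridden Settings instance.
        self.llm_mode = settings.llm.mode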

View File

@@ -8,4 +8,4 @@ from private_gpt.settings.settings import settings
 # Set log_config=None to do not use the uvicorn logging configuration, and
 # use ours instead. For reference, see below:
 # https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
-uvicorn.run(app, host="0.0.0.0", port=settings.server.port, log_config=None)
+uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)

View File

@@ -3,7 +3,7 @@ from llama_index import MockEmbedding
 from llama_index.embeddings.base import BaseEmbedding
 from private_gpt.paths import models_cache_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings
 @singleton
@@ -11,7 +11,7 @@ class EmbeddingComponent:
     embedding_model: BaseEmbedding
     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding

View File

@@ -4,7 +4,7 @@ from llama_index.llms.base import LLM
 from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt
 from private_gpt.paths import models_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings
 @singleton
@@ -12,7 +12,7 @@ class LLMComponent:
     llm: LLM
     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.llms import LlamaCPP

View File

@@ -1,9 +1,19 @@
 from injector import Injector
+from private_gpt.settings.settings import Settings, unsafe_typed_settings
 def create_application_injector() -> Injector:
-    injector = Injector(auto_bind=True)
-    return injector
+    _injector = Injector(auto_bind=True)
+    _injector.binder.bind(Settings, to=unsafe_typed_settings)
+    return _injector
-root_injector: Injector = create_application_injector()
+"""
+Global injector for the application.
+Avoid using this reference, it will make your code harder to test.
+Instead, use the `request.state.injector` reference, which is bound to every request
+"""
+global_injector: Injector = create_application_injector()
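
The docstring above steers application code toward the per-request injector rather than this module-level reference. A rough sketch of what that looks like from a router, mirroring the router hunks further down (the router and path here are hypothetical, not part of this diff):

from fastapi import APIRouter, Request

from private_gpt.server.chat.chat_service import ChatService

example_router = APIRouter()  # hypothetical router, for illustration only


@example_router.get("/example")
def example_endpoint(request: Request) -> dict[str, str]:
    # bind_injector_to_request (see launcher.py below) attaches the
    # application injector to request.state for every request.
    service = request.state.injector.get(ChatService)
    return {"service": type(service).__name__}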

private_gpt/launcher.py Normal file (128 additions)
View File

@@ -0,0 +1,128 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any

from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector

from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


def create_app(root_injector: Injector) -> FastAPI:
    # Start the API
    with open(docs_path / "description.md") as description_file:
        description = description_file.read()

    tags_metadata = [
        {
            "name": "Ingestion",
            "description": "High-level APIs covering document ingestion -internally "
            "managing document parsing, splitting,"
            "metadata extraction, embedding generation and storage- and ingested "
            "documents CRUD."
            "Each ingested document is identified by an ID that can be used to filter the "
            "context"
            "used in *Contextual Completions* and *Context Chunks* APIs.",
        },
        {
            "name": "Contextual Completions",
            "description": "High-level APIs covering contextual Chat and Completions. They "
            "follow OpenAI's format, extending it to "
            "allow using the context coming from ingested documents to create the "
            "response. Internally"
            "manage context retrieval, prompt engineering and the response generation.",
        },
        {
            "name": "Context Chunks",
            "description": "Low-level API that given a query return relevant chunks of "
            "text coming from the ingested"
            "documents.",
        },
        {
            "name": "Embeddings",
            "description": "Low-level API to obtain the vector representation of a given "
            "text, using an Embeddings model."
            "Follows OpenAI's embeddings API format.",
        },
        {
            "name": "Health",
            "description": "Simple health API to make sure the server is up and running.",
        },
    ]

    async def bind_injector_to_request(request: Request) -> None:
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    def custom_openapi() -> dict[str, Any]:
        if app.openapi_schema:
            return app.openapi_schema
        openapi_schema = get_openapi(
            title="PrivateGPT",
            description=description,
            version="0.1.0",
            summary="PrivateGPT is a production-ready AI project that allows you to "
            "ask questions to your documents using the power of Large Language "
            "Models (LLMs), even in scenarios without Internet connection. "
            "100% private, no data leaves your execution environment at any point.",
            contact={
                "url": "https://github.com/imartinez/privateGPT",
            },
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            routes=app.routes,
            tags=tags_metadata,
        )
        openapi_schema["info"]["x-logo"] = {
            "url": "https://lh3.googleusercontent.com/drive-viewer"
            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
        }
        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore[method-assign]

    app.include_router(completions_router)
    app.include_router(chat_router)
    app.include_router(chunks_router)
    app.include_router(ingest_router)
    app.include_router(embeddings_router)
    app.include_router(health_router)

    settings = root_injector.get(Settings)
    if settings.server.cors.enabled:
        logger.debug("Setting up CORS middleware")
        app.add_middleware(
            CORSMiddleware,
            allow_credentials=settings.server.cors.allow_credentials,
            allow_origins=settings.server.cors.allow_origins,
            allow_origin_regex=settings.server.cors.allow_origin_regex,
            allow_methods=settings.server.cors.allow_methods,
            allow_headers=settings.server.cors.allow_headers,
        )

    if settings.ui.enabled:
        logger.debug("Importing the UI module")
        from private_gpt.ui.ui import PrivateGptUi

        ui = root_injector.get(PrivateGptUi)
        ui.mount_in_app(app, settings.ui.path)

    return app
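
Since the app is now produced by a factory that receives the injector, a test (or an alternative entrypoint) can build the app against its own bindings. A sketch of that usage, assuming a plain `Injector` configured like `create_application_injector`; the real tests below go through the `MockInjector` fixture instead:

from fastapi.testclient import TestClient
from injector import Injector

from private_gpt.launcher import create_app
from private_gpt.settings.settings import Settings, unsafe_typed_settings

test_injector = Injector(auto_bind=True)
test_injector.binder.bind(Settings, to=unsafe_typed_settings)

client = TestClient(create_app(test_injector))
response = client.get("/health")  # path assumed from the health_router
assert response.status_code == 200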

View File

@@ -1,124 +1,11 @@
 """FastAPI app creation, logger configuration and main API routes."""
-import logging
-from typing import Any
 import llama_index
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.openapi.utils import get_openapi
-from private_gpt.paths import docs_path
-from private_gpt.server.chat.chat_router import chat_router
-from private_gpt.server.chunks.chunks_router import chunks_router
-from private_gpt.server.completions.completions_router import completions_router
-from private_gpt.server.embeddings.embeddings_router import embeddings_router
-from private_gpt.server.health.health_router import health_router
-from private_gpt.server.ingest.ingest_router import ingest_router
-from private_gpt.settings.settings import settings
+from private_gpt.di import global_injector
+from private_gpt.launcher import create_app
-logger = logging.getLogger(__name__)
 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")
-# Start the API
-with open(docs_path / "description.md") as description_file:
-    description = description_file.read()
-tags_metadata = [
-    {
-        "name": "Ingestion",
-        "description": "High-level APIs covering document ingestion -internally "
-        "managing document parsing, splitting,"
-        "metadata extraction, embedding generation and storage- and ingested "
-        "documents CRUD."
-        "Each ingested document is identified by an ID that can be used to filter the "
-        "context"
-        "used in *Contextual Completions* and *Context Chunks* APIs.",
-    },
-    {
-        "name": "Contextual Completions",
-        "description": "High-level APIs covering contextual Chat and Completions. They "
-        "follow OpenAI's format, extending it to "
-        "allow using the context coming from ingested documents to create the "
-        "response. Internally"
-        "manage context retrieval, prompt engineering and the response generation.",
-    },
-    {
-        "name": "Context Chunks",
-        "description": "Low-level API that given a query return relevant chunks of "
-        "text coming from the ingested"
-        "documents.",
-    },
-    {
-        "name": "Embeddings",
-        "description": "Low-level API to obtain the vector representation of a given "
-        "text, using an Embeddings model."
-        "Follows OpenAI's embeddings API format.",
-    },
-    {
-        "name": "Health",
-        "description": "Simple health API to make sure the server is up and running.",
-    },
-]
-app = FastAPI()
-def custom_openapi() -> dict[str, Any]:
-    if app.openapi_schema:
-        return app.openapi_schema
-    openapi_schema = get_openapi(
-        title="PrivateGPT",
-        description=description,
-        version="0.1.0",
-        summary="PrivateGPT is a production-ready AI project that allows you to "
-        "ask questions to your documents using the power of Large Language "
-        "Models (LLMs), even in scenarios without Internet connection. "
-        "100% private, no data leaves your execution environment at any point.",
-        contact={
-            "url": "https://github.com/imartinez/privateGPT",
-        },
-        license_info={
-            "name": "Apache 2.0",
-            "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
-        },
-        routes=app.routes,
-        tags=tags_metadata,
-    )
-    openapi_schema["info"]["x-logo"] = {
-        "url": "https://lh3.googleusercontent.com/drive-viewer"
-        "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
-        "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
-    }
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-app.openapi = custom_openapi  # type: ignore[method-assign]
-app.include_router(completions_router)
-app.include_router(chat_router)
-app.include_router(chunks_router)
-app.include_router(ingest_router)
-app.include_router(embeddings_router)
-app.include_router(health_router)
-if settings.server.cors.enabled:
-    logger.debug("Setting up CORS middleware")
-    app.add_middleware(
-        CORSMiddleware,
-        allow_credentials=settings.server.cors.allow_credentials,
-        allow_origins=settings.server.cors.allow_origins,
-        allow_origin_regex=settings.server.cors.allow_origin_regex,
-        allow_methods=settings.server.cors.allow_methods,
-        allow_headers=settings.server.cors.allow_headers,
-    )
-if settings.ui.enabled:
-    logger.debug("Importing the UI module")
-    from private_gpt.ui.ui import PrivateGptUi
-    PrivateGptUi().mount_in_app(app)
+app = create_app(global_injector)

View File

@@ -13,4 +13,6 @@ def _absolute_or_from_project_root(path: str) -> Path:
 models_path: Path = PROJECT_ROOT_PATH / "models"
 models_cache_path: Path = models_path / "cache"
 docs_path: Path = PROJECT_ROOT_PATH / "docs"
-local_data_path: Path = _absolute_or_from_project_root(settings.data.local_data_folder)
+local_data_path: Path = _absolute_or_from_project_root(
+    settings().data.local_data_folder
+)

View File

@@ -1,9 +1,8 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.open_ai.openai_models import (
     OpenAICompletion,
@@ -52,7 +51,9 @@ class ChatBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
+def chat_completion(
+    request: Request, body: ChatBody
+) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
     If `use_context` is set to `true`, the model will use context coming
@@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
     "finish_reason":null}]}
     ```
     """
-    service = root_injector.get(ChatService)
+    service = request.state.injector.get(ChatService)
     all_messages = [
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]

View File

@@ -1,9 +1,8 @@
 from typing import Literal
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel, Field
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.utils.auth import authenticated
@@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):
 @chunks_router.post("/chunks", tags=["Context Chunks"])
-def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
+def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
     """Given a `text`, returns the most relevant chunks from the ingested documents.
     The returned information can be used to generate prompts that can be
@@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
     `/ingest/list` endpoint. If you want all ingested documents to be used,
     remove `context_filter` altogether.
     """
-    service = root_injector.get(ChunksService)
+    service = request.state.injector.get(ChunksService)
     results = service.retrieve_relevant(
         body.text, body.context_filter, body.limit, body.prev_next_chunks
     )

View File

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
@@ -41,7 +41,9 @@ class CompletionsBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
+def prompt_completion(
+    request: Request, body: CompletionsBody
+) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.
     Given a prompt, the model will return one predicted completion. If `use_context`
@@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
         include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
-    return chat_completion(chat_body)
+    return chat_completion(request, chat_body)

View File

@@ -1,9 +1,8 @@
 from typing import Literal
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
-from private_gpt.di import root_injector
 from private_gpt.server.embeddings.embeddings_service import (
     Embedding,
     EmbeddingsService,
@@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel):
 @embeddings_router.post("/embeddings", tags=["Embeddings"])
-def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse:
+def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
     """Get a vector representation of a given input.
     That vector representation can be easily consumed
     by machine learning models and algorithms.
     """
-    service = root_injector.get(EmbeddingsService)
+    service = request.state.injector.get(EmbeddingsService)
     input_texts = body.input if isinstance(body.input, list) else [body.input]
     embeddings = service.texts_embeddings(input_texts)
     return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)

View File

@@ -1,9 +1,8 @@
 from typing import Literal
-from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel
-from private_gpt.di import root_injector
 from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
 from private_gpt.server.utils.auth import authenticated
@@ -17,7 +16,7 @@ class IngestResponse(BaseModel):
 @ingest_router.post("/ingest", tags=["Ingestion"])
-def ingest(file: UploadFile) -> IngestResponse:
+def ingest(request: Request, file: UploadFile) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
     The context obtained from files is later used in
@@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse:
     can be used to filter the context used to create responses in
     `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
     ingested_documents = service.ingest(file.filename, file.file.read())
@@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse:
 @ingest_router.get("/ingest/list", tags=["Ingestion"])
-def list_ingested() -> IngestResponse:
+def list_ingested(request: Request) -> IngestResponse:
     """Lists already ingested Documents including their Document ID and metadata.
     Those IDs can be used to filter the context used to create responses
     in `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     ingested_documents = service.list_ingested()
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
-def delete_ingested(doc_id: str) -> None:
+def delete_ingested(request: Request, doc_id: str) -> None:
     """Delete the specified ingested Document.
     The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
     The document will be effectively deleted from your storage context.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     service.delete(doc_id)

View File

@@ -38,13 +38,13 @@ logger = logging.getLogger(__name__)
 def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
     """Check if the request is authenticated."""
-    if not secrets.compare_digest(authorization, settings.server.auth.secret):
+    if not secrets.compare_digest(authorization, settings().server.auth.secret):
         # If the "Authorization" header is not the expected one, raise an exception.
         raise NOT_AUTHENTICATED
     return True
-if not settings.server.auth.enabled:
+if not settings().server.auth.enabled:
     logger.debug(
         "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
     )
@@ -62,7 +62,7 @@ else:
         _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
     ) -> bool:
         """Check if the request is authenticated."""
-        assert settings.server.auth.enabled
+        assert settings().server.auth.enabled
         if not _simple_authentication:
             raise NOT_AUTHENTICATED
         return True

View File

@@ -2,7 +2,7 @@ from typing import Literal
 from pydantic import BaseModel, Field
-from private_gpt.settings.settings_loader import load_active_profiles
+from private_gpt.settings.settings_loader import load_active_settings
 class CorsSettings(BaseModel):
@@ -114,4 +114,29 @@ class Settings(BaseModel):
     openai: OpenAISettings
-settings = Settings(**load_active_profiles())
+"""
+This is visible just for DI or testing purposes.
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_settings = load_active_settings()
+"""
+This is visible just for DI or testing purposes.
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_typed_settings = Settings(**unsafe_settings)
+def settings() -> Settings:
+    """Get the current loaded settings from the DI container.
+    This method exists to keep compatibility with the existing code,
+    that require global access to the settings.
+    For regular components use dependency injection instead.
+    """
+    from private_gpt.di import global_injector
+    return global_injector.get(Settings)
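
For module-level code that cannot take constructor injection (the uvicorn entrypoint, the model download script, and the auth helpers changed elsewhere in this diff), the `settings()` helper above keeps old call sites working with one extra pair of parentheses. A small before/after sketch based on those hunks:

# Before this commit: a module-level object built at import time.
# from private_gpt.settings.settings import settings
# port = settings.server.port

# After this commit: the same data, resolved from the global injector on demand.
from private_gpt.settings.settings import settings

port = settings().server.port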

View File

@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 import sys
+from collections.abc import Iterable
 from pathlib import Path
 from typing import Any
@@ -28,7 +29,11 @@ active_profiles: list[str] = unique_list(
 )
-def load_profile(profile: str) -> dict[str, Any]:
+def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
+    return functools.reduce(deep_update, settings, {})
+def load_settings_from_profile(profile: str) -> dict[str, Any]:
     if profile == "default":
         profile_file_name = "settings.yaml"
     else:
@@ -42,9 +47,11 @@ def load_profile(profile: str) -> dict[str, Any]:
     return config
-def load_active_profiles() -> dict[str, Any]:
+def load_active_settings() -> dict[str, Any]:
     """Load active profiles and merge them."""
     logger.info("Starting application with profiles=%s", active_profiles)
-    loaded_profiles = [load_profile(profile) for profile in active_profiles]
-    merged: dict[str, Any] = functools.reduce(deep_update, loaded_profiles, {})
+    loaded_profiles = [
+        load_settings_from_profile(profile) for profile in active_profiles
+    ]
+    merged: dict[str, Any] = merge_settings(loaded_profiles)
     return merged

View File

@@ -8,10 +8,11 @@ from typing import Any, TextIO
 import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
+from injector import inject, singleton
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole
 from pydantic import BaseModel
-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.chat.chat_service import ChatService, CompletionGen
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -48,11 +49,18 @@ class Source(BaseModel):
         return curated_sources
+@singleton
 class PrivateGptUi:
-    def __init__(self) -> None:
-        self._ingest_service = root_injector.get(IngestService)
-        self._chat_service = root_injector.get(ChatService)
-        self._chunks_service = root_injector.get(ChunksService)
+    @inject
+    def __init__(
+        self,
+        ingest_service: IngestService,
+        chat_service: ChatService,
+        chunks_service: ChunksService,
+    ) -> None:
+        self._ingest_service = ingest_service
+        self._chat_service = chat_service
+        self._chunks_service = chunks_service
         # Cache the UI blocks
         self._ui_block = None
@@ -198,7 +206,7 @@ class PrivateGptUi:
         _ = gr.ChatInterface(
             self._chat,
             chatbot=gr.Chatbot(
-                label=f"LLM: {settings.llm.mode}",
+                label=f"LLM: {settings().llm.mode}",
                 show_copy_button=True,
                 render=False,
                 avatar_images=(
@@ -217,16 +225,15 @@ class PrivateGptUi:
         self._ui_block = self._build_ui_blocks()
         return self._ui_block
-    def mount_in_app(self, app: FastAPI) -> None:
+    def mount_in_app(self, app: FastAPI, path: str) -> None:
         blocks = self.get_ui_blocks()
         blocks.queue()
-        base_path = settings.ui.path
-        logger.info("Mounting the gradio UI, at path=%s", base_path)
-        gr.mount_gradio_app(app, blocks, path=base_path)
+        logger.info("Mounting the gradio UI, at path=%s", path)
+        gr.mount_gradio_app(app, blocks, path=path)
 if __name__ == "__main__":
-    ui = PrivateGptUi()
+    ui = global_injector.get(PrivateGptUi)
     _blocks = ui.get_ui_blocks()
     _blocks.queue()
     _blocks.launch(debug=False, show_api=False)

View File

@@ -2,13 +2,13 @@ import argparse
 import logging
 from pathlib import Path
-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.server.ingest.ingest_watcher import IngestWatcher
 logger = logging.getLogger(__name__)
-ingest_service = root_injector.get(IngestService)
+ingest_service = global_injector.get(IngestService)
 parser = argparse.ArgumentParser(prog="ingest_folder.py")
 parser.add_argument("folder", help="Folder to ingest")

View File

@@ -9,9 +9,9 @@ from private_gpt.settings.settings import settings
 os.makedirs(models_path, exist_ok=True)
 embedding_path = models_path / "embedding"
-print(f"Downloading embedding {settings.local.embedding_hf_model_name}")
+print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
 snapshot_download(
-    repo_id=settings.local.embedding_hf_model_name,
+    repo_id=settings().local.embedding_hf_model_name,
     cache_dir=models_cache_path,
     local_dir=embedding_path,
 )
@@ -20,8 +20,8 @@ print("Downloading models for local execution...")
 # Download LLM and create a symlink to the model file
 hf_hub_download(
-    repo_id=settings.local.llm_hf_repo_id,
-    filename=settings.local.llm_hf_model_file,
+    repo_id=settings().local.llm_hf_repo_id,
+    filename=settings().local.llm_hf_model_file,
     cache_dir=models_cache_path,
     local_dir=models_path,
 )

View File

@@ -5,8 +5,12 @@ server:
     # Dummy secrets used for tests
     secret: "foo bar; dummy secret"
 data:
   local_data_folder: local_data/tests
 llm:
   mode: mock
+ui:
+  enabled: false

View File

@@ -1,15 +1,14 @@
 import pytest
-from fastapi import FastAPI
 from fastapi.testclient import TestClient
-from private_gpt.main import app
+from private_gpt.launcher import create_app
+from tests.fixtures.mock_injector import MockInjector
 @pytest.fixture()
-def current_test_app() -> FastAPI:
-    return app
-@pytest.fixture()
-def test_client() -> TestClient:
-    return TestClient(app)
+def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
+    if request is not None and hasattr(request, "param"):
+        injector.bind_settings(request.param or {})
+    app_under_test = create_app(injector.test_injector)
+    return TestClient(app_under_test)

View File

@@ -1,10 +1,13 @@
 from collections.abc import Callable
+from typing import Any
 from unittest.mock import MagicMock
 import pytest
 from injector import Provider, ScopeDecorator, singleton
 from private_gpt.di import create_application_injector
+from private_gpt.settings.settings import Settings, unsafe_settings
+from private_gpt.settings.settings_loader import merge_settings
 from private_gpt.utils.typing import T
@@ -24,6 +27,12 @@ class MockInjector:
         self.test_injector.binder.bind(interface, to=mock, scope=scope)
         return mock  # type: ignore
+    def bind_settings(self, settings: dict[str, Any]) -> Settings:
+        merged = merge_settings([unsafe_settings, settings])
+        new_settings = Settings(**merged)
+        self.test_injector.binder.bind(Settings, new_settings)
+        return new_settings
     def get(self, interface: type[T]) -> T:
         return self.test_injector.get(interface)

View File

@@ -8,7 +8,7 @@ NOTE: We are not testing the switch based on the config in
 from typing import Annotated
 import pytest
-from fastapi import Depends, FastAPI
+from fastapi import Depends
 from fastapi.testclient import TestClient
 from private_gpt.server.utils.auth import (
@@ -29,15 +29,16 @@ def _copy_simple_authenticated(
 @pytest.fixture(autouse=True)
-def _patch_authenticated_dependency(current_test_app: FastAPI):
+def _patch_authenticated_dependency(test_client: TestClient):
     # Patch the server to use simple authentication
-    current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated
+    test_client.app.dependency_overrides[authenticated] = _copy_simple_authenticated
     # Call the actual test
     yield
     # Remove the patch for other tests
-    current_test_app.dependency_overrides = {}
+    test_client.app.dependency_overrides = {}
 def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None:
@@ -50,6 +51,6 @@ def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None:
     assert response_fail.status_code == 401
     response_success = test_client.get(
-        "/v1/ingest/list", headers={"Authorization": settings.server.auth.secret}
+        "/v1/ingest/list", headers={"Authorization": settings().server.auth.secret}
     )
     assert response_success.status_code == 200

View File

@@ -1,5 +1,12 @@
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings, settings
+from tests.fixtures.mock_injector import MockInjector
 def test_settings_are_loaded_and_merged() -> None:
-    assert settings.server.env_name == "test"
+    assert settings().server.env_name == "test"
+def test_settings_can_be_overriden(injector: MockInjector) -> None:
+    injector.bind_settings({"server": {"env_name": "overriden"}})
+    mocked_settings = injector.get(Settings)
+    assert mocked_settings.server.env_name == "overriden"

tests/ui/test_ui.py Normal file (10 additions)
View File

@@ -0,0 +1,10 @@
import pytest
from fastapi.testclient import TestClient


@pytest.mark.parametrize(
    "test_client", [{"ui": {"enabled": True, "path": "/ui"}}], indirect=True
)
def test_ui_starts_in_the_given_endpoint(test_client: TestClient) -> None:
    response = test_client.get("/ui")
    assert response.status_code == 200