From aa70d3d9f0b1a6b6244dbd26aff59f2504cee115 Mon Sep 17 00:00:00 2001 From: lopagela Date: Sun, 12 Nov 2023 19:05:00 +0100 Subject: [PATCH] Add simple Basic auth (#1203) * Add simple Basic auth To enable the basic authentication, one must set `server.auth.enabled` to true. The static string defined in `server.auth.secret` must be set in the header `Authorization`. The health check endpoint will always be accessible, no matter the API auth configuration. * Fix linting and type check * Fighting with mypy being too restrictive Had to disable mypy in the `auth` as we are not using the same signature for the authenticated method. mypy was complaining that the signatures of `authenticated` must be identical, no matter in which logical branch we are. Given that fastapi is accomodating itself of method signatures (it will inject the dependencies in the method call), this warning of mypy is actually preventing us to do something legit. mypy doc: https://mypy.readthedocs.io/en/stable/common_issues.html * Write tests to verify that the simple auth is working --- private_gpt/server/chat/chat_router.py | 5 +- private_gpt/server/chunks/chunks_router.py | 5 +- .../server/completions/completions_router.py | 5 +- .../server/embeddings/embeddings_router.py | 5 +- private_gpt/server/health/health_router.py | 1 + private_gpt/server/ingest/ingest_router.py | 5 +- private_gpt/server/utils/__init__.py | 0 private_gpt/server/utils/auth.py | 68 +++++++++++++++++++ private_gpt/settings/settings.py | 24 ++++++- settings-test.yaml | 4 ++ settings.yaml | 6 ++ tests/fixtures/fast_api_test_client.py | 6 ++ tests/server/ingest/test_ingest_routes.py | 21 ++++++ tests/server/utils/test_auth.py | 6 ++ tests/server/utils/test_simple_auth.py | 55 +++++++++++++++ 15 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 private_gpt/server/utils/__init__.py create mode 100644 private_gpt/server/utils/auth.py create mode 100644 tests/server/utils/test_auth.py create mode 100644 tests/server/utils/test_simple_auth.py diff --git a/private_gpt/server/chat/chat_router.py b/private_gpt/server/chat/chat_router.py index 4a0cfd4..bd7034b 100644 --- a/private_gpt/server/chat/chat_router.py +++ b/private_gpt/server/chat/chat_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from llama_index.llms import ChatMessage, MessageRole from pydantic import BaseModel from starlette.responses import StreamingResponse @@ -12,8 +12,9 @@ from private_gpt.open_ai.openai_models import ( to_openai_sse_stream, ) from private_gpt.server.chat.chat_service import ChatService +from private_gpt.server.utils.auth import authenticated -chat_router = APIRouter(prefix="/v1") +chat_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class ChatBody(BaseModel): diff --git a/private_gpt/server/chunks/chunks_router.py b/private_gpt/server/chunks/chunks_router.py index f249a94..d965d98 100644 --- a/private_gpt/server/chunks/chunks_router.py +++ b/private_gpt/server/chunks/chunks_router.py @@ -1,13 +1,14 @@ from typing import Literal -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, Field from private_gpt.di import root_injector from private_gpt.open_ai.extensions.context_filter import ContextFilter from private_gpt.server.chunks.chunks_service import Chunk, ChunksService +from private_gpt.server.utils.auth import authenticated -chunks_router = APIRouter(prefix="/v1") +chunks_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class ChunksBody(BaseModel): diff --git a/private_gpt/server/completions/completions_router.py b/private_gpt/server/completions/completions_router.py index d174ec0..4840047 100644 --- a/private_gpt/server/completions/completions_router.py +++ b/private_gpt/server/completions/completions_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from starlette.responses import StreamingResponse @@ -8,8 +8,9 @@ from private_gpt.open_ai.openai_models import ( OpenAIMessage, ) from private_gpt.server.chat.chat_router import ChatBody, chat_completion +from private_gpt.server.utils.auth import authenticated -completions_router = APIRouter(prefix="/v1") +completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class CompletionsBody(BaseModel): diff --git a/private_gpt/server/embeddings/embeddings_router.py b/private_gpt/server/embeddings/embeddings_router.py index 53d6143..f5236c6 100644 --- a/private_gpt/server/embeddings/embeddings_router.py +++ b/private_gpt/server/embeddings/embeddings_router.py @@ -1,6 +1,6 @@ from typing import Literal -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from private_gpt.di import root_injector @@ -8,8 +8,9 @@ from private_gpt.server.embeddings.embeddings_service import ( Embedding, EmbeddingsService, ) +from private_gpt.server.utils.auth import authenticated -embeddings_router = APIRouter(prefix="/v1") +embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class EmbeddingsBody(BaseModel): diff --git a/private_gpt/server/health/health_router.py b/private_gpt/server/health/health_router.py index 8c27e86..4a30b76 100644 --- a/private_gpt/server/health/health_router.py +++ b/private_gpt/server/health/health_router.py @@ -3,6 +3,7 @@ from typing import Literal from fastapi import APIRouter from pydantic import BaseModel, Field +# Not authentication or authorization required to get the health status. health_router = APIRouter() diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py index eb74752..d682de7 100644 --- a/private_gpt/server/ingest/ingest_router.py +++ b/private_gpt/server/ingest/ingest_router.py @@ -1,12 +1,13 @@ from typing import Literal -from fastapi import APIRouter, HTTPException, UploadFile +from fastapi import APIRouter, Depends, HTTPException, UploadFile from pydantic import BaseModel from private_gpt.di import root_injector from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService +from private_gpt.server.utils.auth import authenticated -ingest_router = APIRouter(prefix="/v1") +ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class IngestResponse(BaseModel): diff --git a/private_gpt/server/utils/__init__.py b/private_gpt/server/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/private_gpt/server/utils/auth.py b/private_gpt/server/utils/auth.py new file mode 100644 index 0000000..371e794 --- /dev/null +++ b/private_gpt/server/utils/auth.py @@ -0,0 +1,68 @@ +"""Authentication mechanism for the API. + +Define a simple mechanism to authenticate requests. +More complex authentication mechanisms can be defined here, and be placed in the +`authenticated` method (being a 'bean' injected in fastapi routers). + +Authorization can also be made after the authentication, and depends on +the authentication. Authorization should not be implemented in this file. + +Authorization can be done by following fastapi's guides: +* https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ +* https://fastapi.tiangolo.com/tutorial/security/ +* https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/ +""" +# mypy: ignore-errors +# Disabled mypy error: All conditional function variants must have identical signatures +# We are changing the implementation of the authenticated method, based on +# the config. If the auth is not enabled, we are not defining the complex method +# with its dependencies. +import logging +import secrets +from typing import Annotated + +from fastapi import Depends, Header, HTTPException + +from private_gpt.settings.settings import settings + +# 401 signify that the request requires authentication. +# 403 signify that the authenticated user is not authorized to perform the operation. +NOT_AUTHENTICATED = HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'}, +) + +logger = logging.getLogger(__name__) + + +def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool: + """Check if the request is authenticated.""" + if not secrets.compare_digest(authorization, settings.server.auth.secret): + # If the "Authorization" header is not the expected one, raise an exception. + raise NOT_AUTHENTICATED + return True + + +if not settings.server.auth.enabled: + logger.debug( + "Defining a dummy authentication mechanism for fastapi, always authenticating requests" + ) + + # Define a dummy authentication method that always returns True. + def authenticated() -> bool: + """Check if the request is authenticated.""" + return True + +else: + logger.info("Defining the given authentication mechanism for the API") + + # Method to be used as a dependency to check if the request is authenticated. + def authenticated( + _simple_authentication: Annotated[bool, Depends(_simple_authentication)] + ) -> bool: + """Check if the request is authenticated.""" + assert settings.server.auth.enabled + if not _simple_authentication: + raise NOT_AUTHENTICATED + return True diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 0d17ffe..9529cac 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -15,7 +15,8 @@ class CorsSettings(BaseModel): enabled: bool = Field( description="Flag indicating if CORS headers are set or not." - "If set to True, the CORS headers will be set to allow all origins, methods and headers." + "If set to True, the CORS headers will be set to allow all origins, methods and headers.", + default=False, ) allow_credentials: bool = Field( description="Indicate that cookies should be supported for cross-origin requests", @@ -41,6 +42,23 @@ class CorsSettings(BaseModel): ) +class AuthSettings(BaseModel): + """Authentication configuration. + + The implementation of the authentication strategy must + """ + + enabled: bool = Field( + description="Flag indicating if authentication is enabled or not.", + default=False, + ) + secret: str = Field( + description="The secret to be used for authentication. " + "It can be any non-blank string. For HTTP basic authentication, " + "this value should be the whole 'Authorization' header that is expected" + ) + + class ServerSettings(BaseModel): env_name: str = Field( description="Name of the environment (prod, staging, local...)" @@ -49,6 +67,10 @@ class ServerSettings(BaseModel): cors: CorsSettings = Field( description="CORS configuration", default=CorsSettings(enabled=False) ) + auth: AuthSettings = Field( + description="Authentication configuration", + default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"), + ) class DataSettings(BaseModel): diff --git a/settings-test.yaml b/settings-test.yaml index 9e93886..965a0ef 100644 --- a/settings-test.yaml +++ b/settings-test.yaml @@ -1,5 +1,9 @@ server: env_name: test + auth: + enabled: false + # Dummy secrets used for tests + secret: "foo bar; dummy secret" data: local_data_folder: local_data/tests diff --git a/settings.yaml b/settings.yaml index ebd0e64..807ad53 100644 --- a/settings.yaml +++ b/settings.yaml @@ -6,6 +6,12 @@ server: allow_origins: ["*"] allow_methods: ["*"] allow_headers: ["*"] + auth: + enabled: false + # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())' + # 'secret' is the username and 'key' is the password for basic auth by default + # If the auth is enabled, this value must be set in the "Authorization" header of the request. + secret: "Basic c2VjcmV0OmtleQ==" data: local_data_folder: local_data/private_gpt diff --git a/tests/fixtures/fast_api_test_client.py b/tests/fixtures/fast_api_test_client.py index 428f9a3..b91dfec 100644 --- a/tests/fixtures/fast_api_test_client.py +++ b/tests/fixtures/fast_api_test_client.py @@ -1,9 +1,15 @@ import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient from private_gpt.main import app +@pytest.fixture() +def current_test_app() -> FastAPI: + return app + + @pytest.fixture() def test_client() -> TestClient: return TestClient(app) diff --git a/tests/server/ingest/test_ingest_routes.py b/tests/server/ingest/test_ingest_routes.py index b70efcd..a844ee6 100644 --- a/tests/server/ingest/test_ingest_routes.py +++ b/tests/server/ingest/test_ingest_routes.py @@ -1,5 +1,8 @@ +import tempfile from pathlib import Path +from fastapi.testclient import TestClient + from tests.fixtures.ingest_helper import IngestHelper @@ -13,3 +16,21 @@ def test_ingest_accepts_pdf_files(ingest_helper: IngestHelper) -> None: path = Path(__file__).parents[0] / "test.pdf" ingest_result = ingest_helper.ingest_file(path) assert len(ingest_result.data) == 1 + + +def test_ingest_list_returns_something_after_ingestion( + test_client: TestClient, ingest_helper: IngestHelper +) -> None: + response_before = test_client.get("/v1/ingest/list") + count_ingest_before = len(response_before.json()["data"]) + with tempfile.NamedTemporaryFile("w", suffix=".txt") as test_file: + test_file.write("Foo bar; hello there!") + test_file.flush() + test_file.seek(0) + ingest_result = ingest_helper.ingest_file(Path(test_file.name)) + assert len(ingest_result.data) == 1, "The temp doc should have been ingested" + response_after = test_client.get("/v1/ingest/list") + count_ingest_after = len(response_after.json()["data"]) + assert ( + count_ingest_after == count_ingest_before + 1 + ), "The temp doc should be returned" diff --git a/tests/server/utils/test_auth.py b/tests/server/utils/test_auth.py new file mode 100644 index 0000000..7c42dd2 --- /dev/null +++ b/tests/server/utils/test_auth.py @@ -0,0 +1,6 @@ +from fastapi.testclient import TestClient + + +def test_default_does_not_require_auth(test_client: TestClient) -> None: + response_before = test_client.get("/v1/ingest/list") + assert response_before.status_code == 200 diff --git a/tests/server/utils/test_simple_auth.py b/tests/server/utils/test_simple_auth.py new file mode 100644 index 0000000..6c304a5 --- /dev/null +++ b/tests/server/utils/test_simple_auth.py @@ -0,0 +1,55 @@ +"""Tests to validate that the simple authentication mechanism is working. + +NOTE: We are not testing the switch based on the config in + `private_gpt.server.utils.auth`. This is not done because of the way the code + is currently architecture (it is hard to patch the `settings` and the app while + the tests are directly importing them). +""" +from typing import Annotated + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +from private_gpt.server.utils.auth import ( + NOT_AUTHENTICATED, + _simple_authentication, + authenticated, +) +from private_gpt.settings.settings import settings + + +def _copy_simple_authenticated( + _simple_authentication: Annotated[bool, Depends(_simple_authentication)] +) -> bool: + """Check if the request is authenticated.""" + if not _simple_authentication: + raise NOT_AUTHENTICATED + return True + + +@pytest.fixture(autouse=True) +def _patch_authenticated_dependency(current_test_app: FastAPI): + # Patch the server to use simple authentication + current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated + + # Call the actual test + yield + + # Remove the patch for other tests + current_test_app.dependency_overrides = {} + + +def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None: + response = test_client.get("/v1/ingest/list") + assert response.status_code == 401 + + +def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None: + response_fail = test_client.get("/v1/ingest/list") + assert response_fail.status_code == 401 + + response_success = test_client.get( + "/v1/ingest/list", headers={"Authorization": settings.server.auth.secret} + ) + assert response_success.status_code == 200