diff --git a/private_gpt/server/chat/chat_router.py b/private_gpt/server/chat/chat_router.py index 4a0cfd4..bd7034b 100644 --- a/private_gpt/server/chat/chat_router.py +++ b/private_gpt/server/chat/chat_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from llama_index.llms import ChatMessage, MessageRole from pydantic import BaseModel from starlette.responses import StreamingResponse @@ -12,8 +12,9 @@ from private_gpt.open_ai.openai_models import ( to_openai_sse_stream, ) from private_gpt.server.chat.chat_service import ChatService +from private_gpt.server.utils.auth import authenticated -chat_router = APIRouter(prefix="/v1") +chat_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class ChatBody(BaseModel): diff --git a/private_gpt/server/chunks/chunks_router.py b/private_gpt/server/chunks/chunks_router.py index f249a94..d965d98 100644 --- a/private_gpt/server/chunks/chunks_router.py +++ b/private_gpt/server/chunks/chunks_router.py @@ -1,13 +1,14 @@ from typing import Literal -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, Field from private_gpt.di import root_injector from private_gpt.open_ai.extensions.context_filter import ContextFilter from private_gpt.server.chunks.chunks_service import Chunk, ChunksService +from private_gpt.server.utils.auth import authenticated -chunks_router = APIRouter(prefix="/v1") +chunks_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class ChunksBody(BaseModel): diff --git a/private_gpt/server/completions/completions_router.py b/private_gpt/server/completions/completions_router.py index d174ec0..4840047 100644 --- a/private_gpt/server/completions/completions_router.py +++ b/private_gpt/server/completions/completions_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from starlette.responses import 
StreamingResponse @@ -8,8 +8,9 @@ from private_gpt.open_ai.openai_models import ( OpenAIMessage, ) from private_gpt.server.chat.chat_router import ChatBody, chat_completion +from private_gpt.server.utils.auth import authenticated -completions_router = APIRouter(prefix="/v1") +completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class CompletionsBody(BaseModel): diff --git a/private_gpt/server/embeddings/embeddings_router.py b/private_gpt/server/embeddings/embeddings_router.py index 53d6143..f5236c6 100644 --- a/private_gpt/server/embeddings/embeddings_router.py +++ b/private_gpt/server/embeddings/embeddings_router.py @@ -1,6 +1,6 @@ from typing import Literal -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from private_gpt.di import root_injector @@ -8,8 +8,9 @@ from private_gpt.server.embeddings.embeddings_service import ( Embedding, EmbeddingsService, ) +from private_gpt.server.utils.auth import authenticated -embeddings_router = APIRouter(prefix="/v1") +embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class EmbeddingsBody(BaseModel): diff --git a/private_gpt/server/health/health_router.py b/private_gpt/server/health/health_router.py index 8c27e86..4a30b76 100644 --- a/private_gpt/server/health/health_router.py +++ b/private_gpt/server/health/health_router.py @@ -3,6 +3,7 @@ from typing import Literal from fastapi import APIRouter from pydantic import BaseModel, Field +# No authentication or authorization is required to get the health status.
health_router = APIRouter() diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py index eb74752..d682de7 100644 --- a/private_gpt/server/ingest/ingest_router.py +++ b/private_gpt/server/ingest/ingest_router.py @@ -1,12 +1,13 @@ from typing import Literal -from fastapi import APIRouter, HTTPException, UploadFile +from fastapi import APIRouter, Depends, HTTPException, UploadFile from pydantic import BaseModel from private_gpt.di import root_injector from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService +from private_gpt.server.utils.auth import authenticated -ingest_router = APIRouter(prefix="/v1") +ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class IngestResponse(BaseModel): diff --git a/private_gpt/server/utils/__init__.py b/private_gpt/server/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/private_gpt/server/utils/auth.py b/private_gpt/server/utils/auth.py new file mode 100644 index 0000000..371e794 --- /dev/null +++ b/private_gpt/server/utils/auth.py @@ -0,0 +1,68 @@ +"""Authentication mechanism for the API. + +Define a simple mechanism to authenticate requests. +More complex authentication mechanisms can be defined here, and be placed in the +`authenticated` method (being a 'bean' injected in fastapi routers). + +Authorization can also be made after the authentication, and depends on +the authentication. Authorization should not be implemented in this file. + +Authorization can be done by following fastapi's guides: +* https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ +* https://fastapi.tiangolo.com/tutorial/security/ +* https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/ +""" +# mypy: ignore-errors +# Disabled mypy error: All conditional function variants must have identical signatures +# We are changing the implementation of the authenticated method, based on +# the config. 
If the auth is not enabled, we are not defining the complex method +# with its dependencies. +import logging +import secrets +from typing import Annotated + +from fastapi import Depends, Header, HTTPException + +from private_gpt.settings.settings import settings + +# 401 signifies that the request requires authentication. +# 403 signifies that the authenticated user is not authorized to perform the operation. +NOT_AUTHENTICATED = HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'}, +) + +logger = logging.getLogger(__name__) + + +def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool: + """Check if the request is authenticated.""" + if not secrets.compare_digest(authorization, settings.server.auth.secret): + # If the "Authorization" header is not the expected one, raise an exception. + raise NOT_AUTHENTICATED + return True + + +if not settings.server.auth.enabled: + logger.debug( + "Defining a dummy authentication mechanism for fastapi, always authenticating requests" + ) + + # Define a dummy authentication method that always returns True. + def authenticated() -> bool: + """Check if the request is authenticated.""" + return True + +else: + logger.info("Defining the given authentication mechanism for the API") + + # Method to be used as a dependency to check if the request is authenticated.
+ def authenticated( + _simple_authentication: Annotated[bool, Depends(_simple_authentication)] + ) -> bool: + """Check if the request is authenticated.""" + assert settings.server.auth.enabled + if not _simple_authentication: + raise NOT_AUTHENTICATED + return True diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 0d17ffe..9529cac 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -15,7 +15,8 @@ class CorsSettings(BaseModel): enabled: bool = Field( description="Flag indicating if CORS headers are set or not." - "If set to True, the CORS headers will be set to allow all origins, methods and headers." + "If set to True, the CORS headers will be set to allow all origins, methods and headers.", + default=False, ) allow_credentials: bool = Field( description="Indicate that cookies should be supported for cross-origin requests", @@ -41,6 +42,23 @@ class CorsSettings(BaseModel): ) +class AuthSettings(BaseModel): + """Authentication configuration. + + The authentication strategy is implemented in `private_gpt.server.utils.auth`. + """ + + enabled: bool = Field( + description="Flag indicating if authentication is enabled or not.", + default=False, + ) + secret: str = Field( + description="The secret to be used for authentication. " + "It can be any non-blank string.
For HTTP basic authentication, " + "this value should be the whole 'Authorization' header that is expected" + ) + + class ServerSettings(BaseModel): env_name: str = Field( description="Name of the environment (prod, staging, local...)" @@ -49,6 +67,10 @@ class ServerSettings(BaseModel): cors: CorsSettings = Field( description="CORS configuration", default=CorsSettings(enabled=False) ) + auth: AuthSettings = Field( + description="Authentication configuration", + default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"), + ) class DataSettings(BaseModel): diff --git a/settings-test.yaml b/settings-test.yaml index 9e93886..965a0ef 100644 --- a/settings-test.yaml +++ b/settings-test.yaml @@ -1,5 +1,9 @@ server: env_name: test + auth: + enabled: false + # Dummy secrets used for tests + secret: "foo bar; dummy secret" data: local_data_folder: local_data/tests diff --git a/settings.yaml b/settings.yaml index ebd0e64..807ad53 100644 --- a/settings.yaml +++ b/settings.yaml @@ -6,6 +6,12 @@ server: allow_origins: ["*"] allow_methods: ["*"] allow_headers: ["*"] + auth: + enabled: false + # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())' + # 'secret' is the username and 'key' is the password for basic auth by default + # If the auth is enabled, this value must be set in the "Authorization" header of the request. 
+ secret: "Basic c2VjcmV0OmtleQ==" data: local_data_folder: local_data/private_gpt diff --git a/tests/fixtures/fast_api_test_client.py b/tests/fixtures/fast_api_test_client.py index 428f9a3..b91dfec 100644 --- a/tests/fixtures/fast_api_test_client.py +++ b/tests/fixtures/fast_api_test_client.py @@ -1,9 +1,15 @@ import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient from private_gpt.main import app +@pytest.fixture() +def current_test_app() -> FastAPI: + return app + + @pytest.fixture() def test_client() -> TestClient: return TestClient(app) diff --git a/tests/server/ingest/test_ingest_routes.py b/tests/server/ingest/test_ingest_routes.py index b70efcd..a844ee6 100644 --- a/tests/server/ingest/test_ingest_routes.py +++ b/tests/server/ingest/test_ingest_routes.py @@ -1,5 +1,8 @@ +import tempfile from pathlib import Path +from fastapi.testclient import TestClient + from tests.fixtures.ingest_helper import IngestHelper @@ -13,3 +16,21 @@ def test_ingest_accepts_pdf_files(ingest_helper: IngestHelper) -> None: path = Path(__file__).parents[0] / "test.pdf" ingest_result = ingest_helper.ingest_file(path) assert len(ingest_result.data) == 1 + + +def test_ingest_list_returns_something_after_ingestion( + test_client: TestClient, ingest_helper: IngestHelper +) -> None: + response_before = test_client.get("/v1/ingest/list") + count_ingest_before = len(response_before.json()["data"]) + with tempfile.NamedTemporaryFile("w", suffix=".txt") as test_file: + test_file.write("Foo bar; hello there!") + test_file.flush() + test_file.seek(0) + ingest_result = ingest_helper.ingest_file(Path(test_file.name)) + assert len(ingest_result.data) == 1, "The temp doc should have been ingested" + response_after = test_client.get("/v1/ingest/list") + count_ingest_after = len(response_after.json()["data"]) + assert ( + count_ingest_after == count_ingest_before + 1 + ), "The temp doc should be returned" diff --git a/tests/server/utils/test_auth.py 
b/tests/server/utils/test_auth.py new file mode 100644 index 0000000..7c42dd2 --- /dev/null +++ b/tests/server/utils/test_auth.py @@ -0,0 +1,6 @@ +from fastapi.testclient import TestClient + + +def test_default_does_not_require_auth(test_client: TestClient) -> None: + response_before = test_client.get("/v1/ingest/list") + assert response_before.status_code == 200 diff --git a/tests/server/utils/test_simple_auth.py b/tests/server/utils/test_simple_auth.py new file mode 100644 index 0000000..6c304a5 --- /dev/null +++ b/tests/server/utils/test_simple_auth.py @@ -0,0 +1,55 @@ +"""Tests to validate that the simple authentication mechanism is working. + +NOTE: We are not testing the switch based on the config in + `private_gpt.server.utils.auth`. This is not done because of the way the code + is currently architected (it is hard to patch the `settings` and the app while + the tests are directly importing them). +""" +from typing import Annotated + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +from private_gpt.server.utils.auth import ( + NOT_AUTHENTICATED, + _simple_authentication, + authenticated, +) +from private_gpt.settings.settings import settings + + +def _copy_simple_authenticated( + _simple_authentication: Annotated[bool, Depends(_simple_authentication)] +) -> bool: + """Check if the request is authenticated.""" + if not _simple_authentication: + raise NOT_AUTHENTICATED + return True + + +@pytest.fixture(autouse=True) +def _patch_authenticated_dependency(current_test_app: FastAPI): + # Patch the server to use simple authentication + current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated + + # Call the actual test + yield + + # Remove the patch for other tests + current_test_app.dependency_overrides = {} + + +def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None: + response = test_client.get("/v1/ingest/list") + assert response.status_code == 401 + + +def
test_default_auth_working_when_enabled_200(test_client: TestClient) -> None: + response_fail = test_client.get("/v1/ingest/list") + assert response_fail.status_code == 401 + + response_success = test_client.get( + "/v1/ingest/list", headers={"Authorization": settings.server.auth.secret} + ) + assert response_success.status_code == 200