diff --git a/README.md b/README.md index 7e92115..35c9ee1 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,6 @@ you want to give a hand: ### Features - Implement concurrency lock to avoid errors when there are several calls to the local LlamaCPP model - API key-based request control to the API -- CORS support - Support for Sagemaker - Support Function calling - Add md5 to check files already ingested diff --git a/private_gpt/main.py b/private_gpt/main.py index 2cf6bf9..519f205 100644 --- a/private_gpt/main.py +++ b/private_gpt/main.py @@ -4,6 +4,7 @@ from typing import Any import llama_index from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi from private_gpt.paths import docs_path @@ -104,6 +105,17 @@ app.include_router(ingest_router) app.include_router(embeddings_router) app.include_router(health_router) +if settings.server.cors.enabled: + logger.debug("Setting up CORS middleware") + app.add_middleware( + CORSMiddleware, + allow_credentials=settings.server.cors.allow_credentials, + allow_origins=settings.server.cors.allow_origins, + allow_origin_regex=settings.server.cors.allow_origin_regex, + allow_methods=settings.server.cors.allow_methods, + allow_headers=settings.server.cors.allow_headers, + ) + if settings.ui.enabled: logger.debug("Importing the UI module") diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 16b1d4b..d1c6690 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -3,11 +3,50 @@ from pydantic import BaseModel, Field from private_gpt.settings.settings_loader import load_active_profiles +class CorsSettings(BaseModel): + """CORS configuration. + + For more details on the CORS configuration, see: + # * https://fastapi.tiangolo.com/tutorial/cors/ + # * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + """ + + enabled: bool = Field( + description="Flag indicating if CORS headers are set or not." + "If set to True, the CORS headers will be set to allow all origins, methods and headers." + ) + allow_credentials: bool = Field( + description="Indicate that cookies should be supported for cross-origin requests", + default=False, + ) + allow_origins: list[str] = Field( + description="A list of origins that should be permitted to make cross-origin requests.", + default=[], + ) + allow_origin_regex: list[str] = Field( + description="A regex string to match against origins that should be permitted to make cross-origin requests.", + default=None, + ) + allow_methods: list[str] = Field( + description="A list of HTTP methods that should be allowed for cross-origin requests.", + default=[ + "GET", + ], + ) + allow_headers: list[str] = Field( + description="A list of HTTP request headers that should be supported for cross-origin requests.", + default=[], + ) + + class ServerSettings(BaseModel): env_name: str = Field( description="Name of the environment (prod, staging, local...)" ) - port: int = Field("Port of PrivateGPT FastAPI server, defaults to 8001") + port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001") + cors: CorsSettings = Field( + description="CORS configuration", default=CorsSettings(enabled=False) + ) class DataSettings(BaseModel): diff --git a/settings.yaml b/settings.yaml index fba278d..c658705 100644 --- a/settings.yaml +++ b/settings.yaml @@ -1,6 +1,11 @@ server: env_name: ${APP_ENV:prod} port: ${PORT:8001} + cors: + enabled: false + allow_origins: ["*"] + allow_methods: ["*"] + allow_headers: ["*"] data: local_data_folder: local_data/private_gpt