Add basic CORS (#1198)
This commit is contained in:
		
							parent
							
								
									a666fd5b73
								
							
						
					
					
						commit
						8487440a6f
					
				|  | @ -129,7 +129,6 @@ you want to give a hand: | ||||||
| ### Features | ### Features | ||||||
| - Implement concurrency lock to avoid errors when there are several calls to the local LlamaCPP model | - Implement concurrency lock to avoid errors when there are several calls to the local LlamaCPP model | ||||||
| - API key-based request control to the API | - API key-based request control to the API | ||||||
| - CORS support |  | ||||||
| - Support for Sagemaker | - Support for Sagemaker | ||||||
| - Support Function calling | - Support Function calling | ||||||
| - Add md5 to check files already ingested | - Add md5 to check files already ingested | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ from typing import Any | ||||||
| 
 | 
 | ||||||
| import llama_index | import llama_index | ||||||
| from fastapi import FastAPI | from fastapi import FastAPI | ||||||
|  | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from fastapi.openapi.utils import get_openapi | from fastapi.openapi.utils import get_openapi | ||||||
| 
 | 
 | ||||||
| from private_gpt.paths import docs_path | from private_gpt.paths import docs_path | ||||||
|  | @ -104,6 +105,17 @@ app.include_router(ingest_router) | ||||||
| app.include_router(embeddings_router) | app.include_router(embeddings_router) | ||||||
| app.include_router(health_router) | app.include_router(health_router) | ||||||
| 
 | 
 | ||||||
|  | if settings.server.cors.enabled: | ||||||
|  |     logger.debug("Setting up CORS middleware") | ||||||
|  |     app.add_middleware( | ||||||
|  |         CORSMiddleware, | ||||||
|  |         allow_credentials=settings.server.cors.allow_credentials, | ||||||
|  |         allow_origins=settings.server.cors.allow_origins, | ||||||
|  |         allow_origin_regex=settings.server.cors.allow_origin_regex, | ||||||
|  |         allow_methods=settings.server.cors.allow_methods, | ||||||
|  |         allow_headers=settings.server.cors.allow_headers, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| if settings.ui.enabled: | if settings.ui.enabled: | ||||||
|     logger.debug("Importing the UI module") |     logger.debug("Importing the UI module") | ||||||
|  |  | ||||||
|  | @ -3,11 +3,50 @@ from pydantic import BaseModel, Field | ||||||
| from private_gpt.settings.settings_loader import load_active_profiles | from private_gpt.settings.settings_loader import load_active_profiles | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class CorsSettings(BaseModel): | ||||||
|  |     """CORS configuration. | ||||||
|  | 
 | ||||||
|  |     For more details on the CORS configuration, see: | ||||||
|  |     # * https://fastapi.tiangolo.com/tutorial/cors/ | ||||||
|  |     # * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     enabled: bool = Field( | ||||||
|  |         description="Flag indicating if CORS headers are set or not." | ||||||
|  |         "If set to True, the CORS headers will be set to allow all origins, methods and headers." | ||||||
|  |     ) | ||||||
|  |     allow_credentials: bool = Field( | ||||||
|  |         description="Indicate that cookies should be supported for cross-origin requests", | ||||||
|  |         default=False, | ||||||
|  |     ) | ||||||
|  |     allow_origins: list[str] = Field( | ||||||
|  |         description="A list of origins that should be permitted to make cross-origin requests.", | ||||||
|  |         default=[], | ||||||
|  |     ) | ||||||
|  |     allow_origin_regex: list[str] = Field( | ||||||
|  |         description="A regex string to match against origins that should be permitted to make cross-origin requests.", | ||||||
|  |         default=None, | ||||||
|  |     ) | ||||||
|  |     allow_methods: list[str] = Field( | ||||||
|  |         description="A list of HTTP methods that should be allowed for cross-origin requests.", | ||||||
|  |         default=[ | ||||||
|  |             "GET", | ||||||
|  |         ], | ||||||
|  |     ) | ||||||
|  |     allow_headers: list[str] = Field( | ||||||
|  |         description="A list of HTTP request headers that should be supported for cross-origin requests.", | ||||||
|  |         default=[], | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ServerSettings(BaseModel): | class ServerSettings(BaseModel): | ||||||
|     env_name: str = Field( |     env_name: str = Field( | ||||||
|         description="Name of the environment (prod, staging, local...)" |         description="Name of the environment (prod, staging, local...)" | ||||||
|     ) |     ) | ||||||
|     port: int = Field("Port of PrivateGPT FastAPI server, defaults to 8001") |     port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001") | ||||||
|  |     cors: CorsSettings = Field( | ||||||
|  |         description="CORS configuration", default=CorsSettings(enabled=False) | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DataSettings(BaseModel): | class DataSettings(BaseModel): | ||||||
|  |  | ||||||
|  | @ -1,6 +1,11 @@ | ||||||
| server: | server: | ||||||
|   env_name: ${APP_ENV:prod} |   env_name: ${APP_ENV:prod} | ||||||
|   port: ${PORT:8001} |   port: ${PORT:8001} | ||||||
|  |   cors: | ||||||
|  |     enabled: false | ||||||
|  |     allow_origins: ["*"] | ||||||
|  |     allow_methods: ["*"] | ||||||
|  |     allow_headers: ["*"] | ||||||
| 
 | 
 | ||||||
| data: | data: | ||||||
|   local_data_folder: local_data/private_gpt |   local_data_folder: local_data/private_gpt | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue