// File: rwadurian/backend/mpc-system/pkg/middleware/cors.go
// (135 lines, 3.7 KiB, Go)

package middleware
import (
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// CORSConfig holds the configuration consumed by the CORS middleware.
//
// The zero value allows no origins at all (every cross-origin request is
// passed through without CORS headers); see DefaultCORSConfig for a
// baseline to start from.
type CORSConfig struct {
	// AllowOrigins is a list of origins that are allowed to access the resource,
	// e.g. "https://app.example.com". Entries are matched by exact string
	// equality against the request's Origin header.
	// Use "*" to allow all origins (not recommended for production). In that
	// mode Access-Control-Allow-Origin is "*" and AllowCredentials is ignored,
	// because browsers reject a wildcard origin combined with credentials.
	AllowOrigins []string
	// AllowMethods is a list of HTTP methods allowed for CORS requests.
	// It is sent as Access-Control-Allow-Methods on preflight (OPTIONS) responses.
	AllowMethods []string
	// AllowHeaders is a list of headers that are allowed in CORS requests.
	// It is sent as Access-Control-Allow-Headers on preflight (OPTIONS) responses.
	AllowHeaders []string
	// ExposeHeaders is a list of response headers that the browser is allowed
	// to expose to client scripts (Access-Control-Expose-Headers). The header
	// is omitted when the list is empty.
	ExposeHeaders []string
	// AllowCredentials indicates whether credentials (cookies, auth headers) are
	// allowed, by sending Access-Control-Allow-Credentials: true. It has no
	// effect when AllowOrigins contains "*".
	AllowCredentials bool
	// MaxAge is the maximum time (in seconds) that preflight results can be
	// cached by the browser (Access-Control-Max-Age).
	MaxAge int
}
// DefaultCORSConfig returns a baseline CORS configuration.
//
// No origin is trusted out of the box: AllowOrigins is an empty (non-nil)
// list, so callers must add their front-end origins before passing the
// config to CORS. Beyond that it permits the usual REST verbs, the common
// request headers (including Authorization and X-Request-ID), exposes
// Content-Length and X-Request-ID to scripts, enables credentials, and lets
// browsers cache preflight results for 24 hours.
func DefaultCORSConfig() CORSConfig {
	var cfg CORSConfig

	// Empty but non-nil: nothing is allowed until explicitly configured.
	cfg.AllowOrigins = []string{}

	cfg.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}
	cfg.AllowHeaders = []string{
		"Origin", "Content-Type", "Accept",
		"Authorization", "X-Requested-With", "X-Request-ID",
	}
	cfg.ExposeHeaders = []string{"Content-Length", "X-Request-ID"}

	cfg.AllowCredentials = true
	cfg.MaxAge = 24 * 60 * 60 // seconds (one day)

	return cfg
}
// CORS creates a middleware that handles Cross-Origin Resource Sharing
// according to config.
//
// Behavior:
//   - Every response carries "Vary: Origin". Whether (and with which value)
//     Access-Control-Allow-Origin is emitted depends on the request's Origin
//     header, so without Vary a shared or browser cache could replay a
//     response produced for one origin (or for a non-CORS request) to another.
//   - Requests without an Origin header are not CORS requests and continue
//     down the chain without any CORS headers.
//   - Requests from an origin not listed in config.AllowOrigins continue
//     without CORS headers; the browser then blocks the response.
//   - If config.AllowOrigins contains "*", every origin is accepted and
//     Access-Control-Allow-Origin is "*". Credentials are never advertised in
//     that mode because browsers reject "*" combined with credentials.
//   - OPTIONS requests from an allowed origin are answered as preflights:
//     allowed methods, headers and max-age are written and the chain is
//     aborted with 204 No Content.
func CORS(config CORSConfig) gin.HandlerFunc {
	// Precompute the allowed-origin set once rather than on every request.
	allowedOrigins := make(map[string]struct{}, len(config.AllowOrigins))
	allowAllOrigins := false
	for _, origin := range config.AllowOrigins {
		if origin == "*" {
			allowAllOrigins = true
			break
		}
		allowedOrigins[origin] = struct{}{}
	}

	// Precompute static header values. gin's Context.Header deletes the
	// header when given an empty value, so empty lists simply omit it.
	allowMethodsHeader := strings.Join(config.AllowMethods, ", ")
	allowHeadersHeader := strings.Join(config.AllowHeaders, ", ")
	exposeHeadersHeader := strings.Join(config.ExposeHeaders, ", ")
	maxAgeHeader := strconv.Itoa(config.MaxAge)
	// A wildcard origin may not be combined with credentials.
	allowCredentials := config.AllowCredentials && !allowAllOrigins

	return func(c *gin.Context) {
		// Add (not Set) so Vary values contributed by other middleware survive.
		c.Writer.Header().Add("Vary", "Origin")

		origin := c.GetHeader("Origin")
		// If no origin header, this is not a CORS request.
		if origin == "" {
			c.Next()
			return
		}

		allowOrigin := "*"
		if !allowAllOrigins {
			if _, ok := allowedOrigins[origin]; !ok {
				// Origin not allowed, but still process the request; the
				// browser blocks the response because CORS headers are missing.
				c.Next()
				return
			}
			// Echo the specific origin (required when credentials are allowed).
			allowOrigin = origin
		}

		c.Header("Access-Control-Allow-Origin", allowOrigin)
		if allowCredentials {
			c.Header("Access-Control-Allow-Credentials", "true")
		}
		if exposeHeadersHeader != "" {
			c.Header("Access-Control-Expose-Headers", exposeHeadersHeader)
		}

		// Preflight: answer directly without invoking the route handlers.
		if c.Request.Method == http.MethodOptions {
			c.Header("Access-Control-Allow-Methods", allowMethodsHeader)
			c.Header("Access-Control-Allow-Headers", allowHeadersHeader)
			c.Header("Access-Control-Max-Age", maxAgeHeader)
			c.AbortWithStatus(http.StatusNoContent)
			return
		}

		c.Next()
	}
}
// AllowAllCORS is a permissive CORS middleware (for development only).
//
// It reflects the request's Origin back in Access-Control-Allow-Origin
// (falling back to "*" when no Origin header is present), advertises
// credentials, allows the same request headers and exposes the same
// response headers as DefaultCORSConfig, and answers every OPTIONS request
// with 204 No Content without calling the remaining handlers.
//
// WARNING: Do not use in production. Reflecting any origin together with
// Access-Control-Allow-Credentials lets arbitrary websites issue
// credentialed requests on behalf of a signed-in user.
func AllowAllCORS() gin.HandlerFunc {
	return func(c *gin.Context) {
		origin := c.GetHeader("Origin")
		if origin == "" {
			origin = "*"
		}
		// The allowed origin is reflected from the request, so caches must
		// key responses on Origin. Add (not Set) to keep other Vary values.
		c.Writer.Header().Add("Vary", "Origin")
		c.Header("Access-Control-Allow-Origin", origin)
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
		// Kept in line with DefaultCORSConfig so dev is never stricter than prod.
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Requested-With, X-Request-ID")
		c.Header("Access-Control-Expose-Headers", "Content-Length, X-Request-ID")
		c.Header("Access-Control-Allow-Credentials", "true")
		c.Header("Access-Control-Max-Age", "86400")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	}
}