package middleware import ( "net/http" "strconv" "strings" "github.com/gin-gonic/gin" ) // CORSConfig holds configuration for CORS middleware type CORSConfig struct { // AllowOrigins is a list of origins that are allowed to access the resource // Use "*" to allow all origins (not recommended for production) AllowOrigins []string // AllowMethods is a list of HTTP methods allowed for CORS requests AllowMethods []string // AllowHeaders is a list of headers that are allowed in CORS requests AllowHeaders []string // ExposeHeaders is a list of headers that the browser is allowed to access ExposeHeaders []string // AllowCredentials indicates whether credentials (cookies, auth headers) are allowed AllowCredentials bool // MaxAge is the maximum time (in seconds) that preflight results can be cached MaxAge int } // DefaultCORSConfig returns a default CORS configuration func DefaultCORSConfig() CORSConfig { return CORSConfig{ AllowOrigins: []string{}, AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, AllowHeaders: []string{ "Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-Request-ID", }, ExposeHeaders: []string{"Content-Length", "X-Request-ID"}, AllowCredentials: true, MaxAge: 86400, // 24 hours } } // CORS creates a middleware that handles Cross-Origin Resource Sharing func CORS(config CORSConfig) gin.HandlerFunc { // Precompute allowed origins map for fast lookup allowedOrigins := make(map[string]bool) allowAllOrigins := false for _, origin := range config.AllowOrigins { if origin == "*" { allowAllOrigins = true break } allowedOrigins[origin] = true } // Precompute header values allowMethodsHeader := strings.Join(config.AllowMethods, ", ") allowHeadersHeader := strings.Join(config.AllowHeaders, ", ") exposeHeadersHeader := strings.Join(config.ExposeHeaders, ", ") maxAgeHeader := strconv.Itoa(config.MaxAge) return func(c *gin.Context) { origin := c.GetHeader("Origin") // If no origin header, this is not a CORS request if origin == "" { c.Next() return } // Check if origin is allowed var allowOrigin string if allowAllOrigins { allowOrigin = "*" } else if allowedOrigins[origin] { allowOrigin = origin } else { // Origin not allowed, but still process the request // The browser will block the response based on missing headers c.Next() return } // Set CORS headers c.Header("Access-Control-Allow-Origin", allowOrigin) if config.AllowCredentials && !allowAllOrigins { c.Header("Access-Control-Allow-Credentials", "true") } if exposeHeadersHeader != "" { c.Header("Access-Control-Expose-Headers", exposeHeadersHeader) } // Handle preflight request if c.Request.Method == http.MethodOptions { c.Header("Access-Control-Allow-Methods", allowMethodsHeader) c.Header("Access-Control-Allow-Headers", allowHeadersHeader) c.Header("Access-Control-Max-Age", maxAgeHeader) c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } // AllowAllCORS is a permissive CORS middleware (for development only) // WARNING: Do not use in production func AllowAllCORS() gin.HandlerFunc { return func(c *gin.Context) { origin := c.GetHeader("Origin") if origin == "" { origin = "*" } c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, X-Requested-With") c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Max-Age", "86400") if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } }