package middleware import ( "net/http" "strings" "github.com/gin-gonic/gin" "github.com/rwadurian/mpc-system/pkg/jwt" "github.com/rwadurian/mpc-system/pkg/logger" "go.uber.org/zap" ) // AuthConfig holds configuration for authentication middleware type AuthConfig struct { JWTService *jwt.JWTService SkipPaths []string // Paths to skip authentication (e.g., /health, /auth/*) AllowAnonymous bool // If true, allow requests without token (user info will be nil) } // ContextKey is a custom type for context keys to avoid collisions type ContextKey string const ( // UserContextKey is the key for storing user info in gin context UserContextKey ContextKey = "user" // ClaimsContextKey is the key for storing JWT claims in gin context ClaimsContextKey ContextKey = "claims" ) // UserInfo represents authenticated user information type UserInfo struct { UserID string Username string } // BearerAuth creates a middleware that validates Bearer tokens // Extracts token from Authorization header: "Bearer " func BearerAuth(config AuthConfig) gin.HandlerFunc { return func(c *gin.Context) { // Check if path should be skipped path := c.Request.URL.Path for _, skipPath := range config.SkipPaths { if matchPath(skipPath, path) { c.Next() return } } // Extract token from Authorization header authHeader := c.GetHeader("Authorization") if authHeader == "" { if config.AllowAnonymous { c.Next() return } c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "missing authorization header", }) return } // Check Bearer prefix parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "invalid authorization header format, expected: Bearer ", }) return } token := parts[1] // Validate access token claims, err := config.JWTService.ValidateAccessToken(token) if err != nil { logger.Debug("Token validation failed", zap.Error(err), zap.String("path", path)) statusCode := http.StatusUnauthorized message := "invalid token" if err == jwt.ErrExpiredToken { message = "token expired" } c.AbortWithStatusJSON(statusCode, gin.H{ "error": "unauthorized", "message": message, }) return } // Store user info in context userInfo := &UserInfo{ UserID: claims.Subject, Username: claims.Username, } c.Set(string(UserContextKey), userInfo) c.Set(string(ClaimsContextKey), claims) logger.Debug("Request authenticated", zap.String("user_id", userInfo.UserID), zap.String("username", userInfo.Username), zap.String("path", path)) c.Next() } } // GetUser extracts UserInfo from gin context func GetUser(c *gin.Context) *UserInfo { if user, exists := c.Get(string(UserContextKey)); exists { if userInfo, ok := user.(*UserInfo); ok { return userInfo } } return nil } // RequireUser is a middleware that ensures user is authenticated // Use this after BearerAuth with AllowAnonymous=true func RequireUser() gin.HandlerFunc { return func(c *gin.Context) { user := GetUser(c) if user == nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "authentication required", }) return } c.Next() } } // RequireOwnership is a middleware that ensures the authenticated user // matches the resource owner (identified by a path parameter) func RequireOwnership(paramName string) gin.HandlerFunc { return func(c *gin.Context) { user := GetUser(c) if user == nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "authentication required", }) return } resourceOwner := c.Param(paramName) if resourceOwner != "" && resourceOwner != user.UserID && resourceOwner != user.Username { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "forbidden", "message": "access denied to this resource", }) return } c.Next() } } // matchPath checks if a pattern matches a path // Supports wildcard suffix: "/auth/*" matches "/auth/login", "/auth/refresh" func matchPath(pattern, path string) bool { // Exact match if pattern == path { return true } // Wildcard match if strings.HasSuffix(pattern, "/*") { prefix := strings.TrimSuffix(pattern, "/*") return strings.HasPrefix(path, prefix+"/") || path == prefix } // Prefix match (for backward compatibility) if strings.HasSuffix(pattern, "*") { prefix := strings.TrimSuffix(pattern, "*") return strings.HasPrefix(path, prefix) } return false }