rwadurian/backend/mpc-system/pkg/middleware/auth.go

187 lines
4.6 KiB
Go

package middleware
import (
	"errors"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/rwadurian/mpc-system/pkg/jwt"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"go.uber.org/zap"
)
// AuthConfig configures the BearerAuth middleware.
type AuthConfig struct {
	// JWTService validates the access tokens presented by clients.
	JWTService *jwt.JWTService

	// SkipPaths lists path patterns that bypass authentication entirely,
	// e.g. "/health" or "/auth/*" (see matchPath for the pattern syntax).
	SkipPaths []string

	// AllowAnonymous lets requests that carry no Authorization header through
	// without storing any user in the context. Requests that do send the
	// header are still validated.
	AllowAnonymous bool
}
// ContextKey is a dedicated string type for the keys under which this package
// stores authentication data. The middleware converts these keys to plain
// strings with string(...) before calling c.Set/c.Get, so the distinct type
// documents intent rather than isolating the keys from other string keys.
type ContextKey string

// Keys used to store authentication results in the gin context.
const (
	// UserContextKey holds the *UserInfo of the authenticated caller.
	UserContextKey ContextKey = "user"

	// ClaimsContextKey holds the validated JWT claims.
	ClaimsContextKey ContextKey = "claims"
)
// UserInfo is the identity BearerAuth derives from a validated access token:
// UserID is taken from the token subject and Username from its username claim.
type UserInfo struct {
	UserID   string
	Username string
}
// BearerAuth returns a gin middleware that authenticates requests using an
// "Authorization: Bearer <token>" header.
//
// Behavior:
//   - If the request path matches any pattern in config.SkipPaths (see
//     matchPath), the request is passed on without inspection.
//   - A missing Authorization header is allowed through when
//     config.AllowAnonymous is true (no user is stored). Otherwise the request
//     is aborted with 401.
//   - A header whose scheme is not "Bearer" (compared case-insensitively), or
//     that carries no token, is aborted with 401.
//   - The token is checked with config.JWTService.ValidateAccessToken. On
//     failure the request is aborted with 401, and the message reports
//     "token expired" when the error is (or wraps) jwt.ErrExpiredToken.
//   - On success a *UserInfo is stored under UserContextKey and the claims
//     under ClaimsContextKey, then the next handler runs.
func BearerAuth(config AuthConfig) gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path

		// Public endpoints bypass authentication entirely.
		for _, skipPath := range config.SkipPaths {
			if matchPath(skipPath, path) {
				c.Next()
				return
			}
		}

		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			if config.AllowAnonymous {
				// No user is stored; downstream handlers see GetUser(c) == nil.
				c.Next()
				return
			}
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   "unauthorized",
				"message": "missing authorization header",
			})
			return
		}

		// Expect "<scheme> <token>". The scheme is matched case-insensitively.
		// Surrounding whitespace on the token is ignored, so "Bearer  abc" is
		// accepted. "Bearer " with no token is rejected here as malformed
		// rather than being handed to the validator as an empty string.
		var token string
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
			token = strings.TrimSpace(parts[1])
		}
		if token == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   "unauthorized",
				"message": "invalid authorization header format, expected: Bearer <token>",
			})
			return
		}

		claims, err := config.JWTService.ValidateAccessToken(token)
		if err != nil {
			logger.Debug("Token validation failed",
				zap.Error(err),
				zap.String("path", path))

			// errors.Is also matches a wrapped ErrExpiredToken, which a plain
			// == comparison would miss.
			message := "invalid token"
			if errors.Is(err, jwt.ErrExpiredToken) {
				message = "token expired"
			}
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   "unauthorized",
				"message": message,
			})
			return
		}

		userInfo := &UserInfo{
			UserID:   claims.Subject,
			Username: claims.Username,
		}
		c.Set(string(UserContextKey), userInfo)
		c.Set(string(ClaimsContextKey), claims)

		logger.Debug("Request authenticated",
			zap.String("user_id", userInfo.UserID),
			zap.String("username", userInfo.Username),
			zap.String("path", path))
		c.Next()
	}
}
// GetUser returns the *UserInfo stored in the gin context by BearerAuth.
// It returns nil when nothing is stored under UserContextKey, which is the
// case for skipped paths and anonymous requests. It also returns nil when the
// stored value is not a *UserInfo.
func GetUser(c *gin.Context) *UserInfo {
	raw, found := c.Get(string(UserContextKey))
	if !found {
		return nil
	}
	info, _ := raw.(*UserInfo) // a failed assertion leaves info as nil
	return info
}
// RequireUser returns a middleware that aborts with 401 Unauthorized unless a
// user is already present in the context (as reported by GetUser).
// Place it after BearerAuth configured with AllowAnonymous=true. Anonymous
// access can then be allowed globally and refused on selected routes.
func RequireUser() gin.HandlerFunc {
	return func(c *gin.Context) {
		if GetUser(c) != nil {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"error":   "unauthorized",
			"message": "authentication required",
		})
	}
}
// RequireOwnership returns a middleware that only lets a request through when
// the authenticated user owns the resource named by the path parameter
// paramName. Ownership holds when the parameter equals the user's UserID or
// their Username.
//
// Responses:
//   - 401 Unauthorized if no user is present in the context.
//   - 403 Forbidden if the parameter is non-empty and matches neither
//     identifier.
//
// NOTE(review): an empty parameter value is treated as "nothing to check" and
// the request is allowed. This includes a paramName that does not exist on the
// route. Confirm this fail-open behavior is intended. Accepting Username as
// well as UserID also assumes the two value spaces cannot collide; verify.
func RequireOwnership(paramName string) gin.HandlerFunc {
	return func(c *gin.Context) {
		current := GetUser(c)
		if current == nil {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   "unauthorized",
				"message": "authentication required",
			})
			return
		}

		owner := c.Param(paramName)
		allowed := owner == "" || owner == current.UserID || owner == current.Username
		if !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":   "forbidden",
				"message": "access denied to this resource",
			})
			return
		}

		c.Next()
	}
}
// matchPath reports whether path is covered by pattern.
//
// Supported pattern forms:
//   - an exact path, e.g. "/health";
//   - a "/*" suffix, which matches the base path itself and anything below it,
//     e.g. "/auth/*" matches "/auth", "/auth/login" and "/auth/refresh", but
//     not "/authx";
//   - a bare "*" suffix, a plain string-prefix match kept for backward
//     compatibility, e.g. "/api*" matches both "/api/v1" and "/apix".
func matchPath(pattern, path string) bool {
	switch {
	case pattern == path:
		return true
	case strings.HasSuffix(pattern, "/*"):
		base := pattern[:len(pattern)-len("/*")]
		return path == base || strings.HasPrefix(path, base+"/")
	case strings.HasSuffix(pattern, "*"):
		return strings.HasPrefix(path, pattern[:len(pattern)-len("*")])
	default:
		return false
	}
}