187 lines
4.6 KiB
Go
187 lines
4.6 KiB
Go
package middleware
|
|
|
|
import (
	"errors"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/rwadurian/mpc-system/pkg/jwt"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"go.uber.org/zap"
)
|
|
|
|
// AuthConfig holds configuration for authentication middleware
|
|
type AuthConfig struct {
|
|
JWTService *jwt.JWTService
|
|
SkipPaths []string // Paths to skip authentication (e.g., /health, /auth/*)
|
|
AllowAnonymous bool // If true, allow requests without token (user info will be nil)
|
|
}
|
|
|
|
// ContextKey is the type of the keys under which the auth middleware
// publishes values on a request context. A dedicated type keeps these keys
// from being mixed up with arbitrary strings at compile time. Note that the
// middleware converts them with string(...) before calling gin's Set/Get, so
// they do not, by themselves, prevent clashes with other code that uses the
// plain strings "user" or "claims".
type ContextKey string

// Keys used by BearerAuth when storing authentication results.
const (
	// UserContextKey stores the *UserInfo of the authenticated caller.
	UserContextKey ContextKey = "user"
	// ClaimsContextKey stores the claims returned by the JWT service after
	// successful token validation.
	ClaimsContextKey ContextKey = "claims"
)
|
|
|
|
// UserInfo describes the caller authenticated by BearerAuth. It is built
// from the validated token claims and stored under UserContextKey.
type UserInfo struct {
	// UserID is taken from the token's subject claim.
	UserID string

	// Username is taken from the token's username claim.
	Username string
}
|
|
|
|
// BearerAuth creates a middleware that validates Bearer tokens
|
|
// Extracts token from Authorization header: "Bearer <token>"
|
|
func BearerAuth(config AuthConfig) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Check if path should be skipped
|
|
path := c.Request.URL.Path
|
|
for _, skipPath := range config.SkipPaths {
|
|
if matchPath(skipPath, path) {
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
// Extract token from Authorization header
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
if config.AllowAnonymous {
|
|
c.Next()
|
|
return
|
|
}
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "unauthorized",
|
|
"message": "missing authorization header",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Check Bearer prefix
|
|
parts := strings.SplitN(authHeader, " ", 2)
|
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "unauthorized",
|
|
"message": "invalid authorization header format, expected: Bearer <token>",
|
|
})
|
|
return
|
|
}
|
|
|
|
token := parts[1]
|
|
|
|
// Validate access token
|
|
claims, err := config.JWTService.ValidateAccessToken(token)
|
|
if err != nil {
|
|
logger.Debug("Token validation failed",
|
|
zap.Error(err),
|
|
zap.String("path", path))
|
|
|
|
statusCode := http.StatusUnauthorized
|
|
message := "invalid token"
|
|
|
|
if err == jwt.ErrExpiredToken {
|
|
message = "token expired"
|
|
}
|
|
|
|
c.AbortWithStatusJSON(statusCode, gin.H{
|
|
"error": "unauthorized",
|
|
"message": message,
|
|
})
|
|
return
|
|
}
|
|
|
|
// Store user info in context
|
|
userInfo := &UserInfo{
|
|
UserID: claims.Subject,
|
|
Username: claims.Username,
|
|
}
|
|
c.Set(string(UserContextKey), userInfo)
|
|
c.Set(string(ClaimsContextKey), claims)
|
|
|
|
logger.Debug("Request authenticated",
|
|
zap.String("user_id", userInfo.UserID),
|
|
zap.String("username", userInfo.Username),
|
|
zap.String("path", path))
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// GetUser extracts UserInfo from gin context
|
|
func GetUser(c *gin.Context) *UserInfo {
|
|
if user, exists := c.Get(string(UserContextKey)); exists {
|
|
if userInfo, ok := user.(*UserInfo); ok {
|
|
return userInfo
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RequireUser is a middleware that ensures user is authenticated
|
|
// Use this after BearerAuth with AllowAnonymous=true
|
|
func RequireUser() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
user := GetUser(c)
|
|
if user == nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "unauthorized",
|
|
"message": "authentication required",
|
|
})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireOwnership is a middleware that ensures the authenticated user
|
|
// matches the resource owner (identified by a path parameter)
|
|
func RequireOwnership(paramName string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
user := GetUser(c)
|
|
if user == nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "unauthorized",
|
|
"message": "authentication required",
|
|
})
|
|
return
|
|
}
|
|
|
|
resourceOwner := c.Param(paramName)
|
|
if resourceOwner != "" && resourceOwner != user.UserID && resourceOwner != user.Username {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": "forbidden",
|
|
"message": "access denied to this resource",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// matchPath reports whether path is covered by pattern.
//
// Pattern forms are tried in this order:
//   - any pattern matches a path identical to it ("/health" matches "/health");
//   - a "/*" suffix matches the bare prefix and everything beneath it, so
//     "/auth/*" matches "/auth", "/auth/login" and "/auth/refresh", but not
//     "/authz";
//   - any other "*" suffix is a plain string-prefix match, kept for backward
//     compatibility, so "/api*" also matches "/api-internal".
//
// Patterns of any other shape match nothing else.
func matchPath(pattern, path string) bool {
	switch {
	case path == pattern:
		return true
	case strings.HasSuffix(pattern, "/*"):
		base := pattern[:len(pattern)-len("/*")]
		return path == base || strings.HasPrefix(path, base+"/")
	case strings.HasSuffix(pattern, "*"):
		return strings.HasPrefix(path, pattern[:len(pattern)-1])
	default:
		return false
	}
}
|