rwadurian/backend/mpc-system/pkg/jwt/jwt.go

235 lines
6.6 KiB
Go

package jwt
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Sentinel errors returned by the validation methods of JWTService.
// Callers can compare against them with errors.Is.
var (
// ErrInvalidToken is returned when a token fails parsing or signature
// verification, or when its token_type is not the one the caller expects.
ErrInvalidToken = errors.New("invalid token")
// ErrExpiredToken is returned by ValidateToken when the underlying
// library reports the token as expired.
ErrExpiredToken = errors.New("token expired")
// ErrInvalidClaims is returned when the decoded claims have an unexpected
// type, or when a join token's session or party ID does not match.
ErrInvalidClaims = errors.New("invalid claims")
// ErrTokenNotYetValid reports a token that is not yet valid.
ErrTokenNotYetValid = errors.New("token not yet valid")
)
// Claims is the JWT payload shared by every token type issued by this
// package. It adds package-specific fields to the standard registered claims.
type Claims struct {
// SessionID is the MPC session ID. Only GenerateJoinToken populates it.
SessionID string `json:"session_id"`
// PartyID is the party ID for join tokens; ValidateJoinToken treats the
// value "*" as a wildcard. GenerateAccessToken stores the username here.
PartyID string `json:"party_id"`
// TokenType distinguishes the kind of token and is checked by the
// type-specific validators.
TokenType string `json:"token_type"` // "join", "access", "refresh"
// RegisteredClaims carries the standard fields (ID, issuer, subject,
// issued-at, not-before, expiry) set by the Generate* methods.
jwt.RegisteredClaims
}
// JWTService issues and validates HMAC-signed JWTs (join, access and refresh
// tokens). Construct it with NewJWTService.
type JWTService struct {
// secretKey is the HMAC key used both to sign and to verify tokens.
secretKey []byte
// issuer is written into the Issuer claim of every generated token.
issuer string
// tokenExpiry is the lifetime applied to access tokens.
tokenExpiry time.Duration
// refreshExpiry is the lifetime applied to refresh tokens.
refreshExpiry time.Duration
}
// NewJWTService returns a JWTService that signs and verifies tokens with
// secretKey, stamps issuer into every generated token, and uses tokenExpiry
// and refreshExpiry as the lifetimes of access and refresh tokens.
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
	svc := new(JWTService)
	svc.secretKey = []byte(secretKey)
	svc.issuer = issuer
	svc.tokenExpiry = tokenExpiry
	svc.refreshExpiry = refreshExpiry
	return svc
}
// GenerateJoinToken issues an HS256-signed "join" token that binds partyID to
// the MPC session sessionID. The token becomes valid immediately and expires
// expiresIn after the time of issue. partyID is stored both as the PartyID
// claim and as the token subject.
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
	issuedAt := time.Now()

	registered := jwt.RegisteredClaims{
		ID:        uuid.New().String(), // fresh random token ID per token
		Issuer:    s.issuer,
		Subject:   partyID,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(expiresIn)),
	}
	payload := Claims{
		SessionID:        sessionID.String(),
		PartyID:          partyID,
		TokenType:        "join",
		RegisteredClaims: registered,
	}

	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(s.secretKey)
}
// AccessTokenClaims is the simplified view of an access token returned by
// ValidateAccessToken.
type AccessTokenClaims struct {
// Subject is the token's subject claim (the userID given to GenerateAccessToken).
Subject string
// Username is read from the token's PartyID claim.
Username string
// Issuer is the token's issuer claim.
Issuer string
}
// GenerateAccessToken issues an HS256-signed "access" token for userID. The
// token is valid immediately and expires after the service's tokenExpiry.
// userID becomes the subject; username is carried in the PartyID claim.
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
	issuedAt := time.Now()

	payload := Claims{
		// Access tokens have no party, so the PartyID field is reused to
		// carry the username.
		PartyID:   username,
		TokenType: "access",
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        uuid.New().String(),
			Issuer:    s.issuer,
			Subject:   userID,
			IssuedAt:  jwt.NewNumericDate(issuedAt),
			NotBefore: jwt.NewNumericDate(issuedAt),
			ExpiresAt: jwt.NewNumericDate(issuedAt.Add(s.tokenExpiry)),
		},
	}

	signer := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	return signer.SignedString(s.secretKey)
}
// GenerateRefreshToken issues an HS256-signed "refresh" token whose subject
// is userID. The token is valid immediately and expires after the service's
// refreshExpiry. SessionID and PartyID are left empty.
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(s.refreshExpiry)

	var payload Claims
	payload.TokenType = "refresh"
	payload.RegisteredClaims = jwt.RegisteredClaims{
		ID:        uuid.New().String(),
		Issuer:    s.issuer,
		Subject:   userID,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(expiresAt),
	}

	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(s.secretKey)
}
// ValidateToken parses tokenString, verifies its HMAC signature against the
// service secret, and returns the decoded claims. It does not check the
// token_type; use the type-specific validators for that.
//
// Errors:
//   - ErrExpiredToken if the library reports the token as expired.
//   - ErrTokenNotYetValid if the library reports the token's not-before time
//     has not been reached.
//   - ErrInvalidToken for any other parse or verification failure, including
//     a token signed with a non-HMAC method.
//   - ErrInvalidClaims if the parsed claims are not *Claims or the token is
//     not marked valid.
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens: the secret must never be used as
		// key material for a different algorithm chosen by the token header.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, ErrInvalidToken
		}
		return s.secretKey, nil
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, keyFunc)
	if err != nil {
		// Translate the library's time-based failures into this package's
		// sentinels; everything else is reported as a generic invalid token.
		switch {
		case errors.Is(err, jwt.ErrTokenExpired):
			return nil, ErrExpiredToken
		case errors.Is(err, jwt.ErrTokenNotValidYet):
			return nil, ErrTokenNotYetValid
		default:
			return nil, ErrInvalidToken
		}
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}
// ParseJoinTokenClaims validates tokenString and returns its claims, provided
// the token is of type "join". Unlike ValidateJoinToken it does not compare
// the session or party ID, so it can be used when the session ID is only
// known from the token itself (e.g. join by token).
func (s *JWTService) ParseJoinTokenClaims(tokenString string) (*Claims, error) {
	parsed, err := s.ValidateToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case parsed.TokenType != "join":
		return nil, ErrInvalidToken
	}
	return parsed, nil
}
// ValidateJoinToken validates tokenString as a "join" token for the given
// session and party. It returns ErrInvalidToken if the token is of another
// type, and ErrInvalidClaims if the session ID or party ID does not match.
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
	joinClaims, err := s.ValidateToken(tokenString)
	if err != nil {
		return nil, err
	}
	if joinClaims.TokenType != "join" {
		return nil, ErrInvalidToken
	}

	sessionOK := joinClaims.SessionID == sessionID.String()
	// A token issued for party "*" may be used by any party (dynamic
	// joining); otherwise the party ID must match exactly.
	partyOK := joinClaims.PartyID == "*" || joinClaims.PartyID == partyID
	if !sessionOK || !partyOK {
		return nil, ErrInvalidClaims
	}
	return joinClaims, nil
}
// RefreshAccessToken validates refreshToken as a "refresh" token and, on
// success, issues a new access token for the same subject. Validation errors
// are returned unchanged.
//
// NOTE(review): the username for the new access token is taken from the
// refresh token's PartyID claim, but GenerateRefreshToken does not set
// PartyID, so access tokens refreshed from those tokens carry an empty
// username. Confirm whether that is intended.
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
	refreshClaims, err := s.ValidateRefreshToken(refreshToken)
	if err != nil {
		return "", err
	}
	userID, username := refreshClaims.Subject, refreshClaims.PartyID
	return s.GenerateAccessToken(userID, username)
}
// ValidateAccessToken validates tokenString as an "access" token and returns
// its subject, username and issuer. It returns ErrInvalidToken if the token
// is of another type.
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
	raw, err := s.ValidateToken(tokenString)
	if err == nil && raw.TokenType != "access" {
		err = ErrInvalidToken
	}
	if err != nil {
		return nil, err
	}

	view := new(AccessTokenClaims)
	view.Subject = raw.Subject
	view.Username = raw.PartyID // access tokens carry the username in PartyID
	view.Issuer = raw.Issuer
	return view, nil
}
// ValidateRefreshToken validates tokenString and returns its claims if the
// token is of type "refresh"; any other type yields ErrInvalidToken.
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
	refreshClaims, err := s.ValidateToken(tokenString)
	if err != nil {
		return nil, err
	}
	if isRefresh := refreshClaims.TokenType == "refresh"; !isRefresh {
		return nil, ErrInvalidToken
	}
	return refreshClaims, nil
}
// TokenGenerator is the token-issuing side of JWTService, exposed as an
// interface so consumers can depend on it and substitute another
// implementation (dependency injection).
type TokenGenerator interface {
// GenerateJoinToken returns a signed join token for partyID in the session
// sessionID, expiring expiresIn after issue.
GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
}
// TokenValidator is the join-token validation side of JWTService, exposed as
// an interface for dependency injection.
type TokenValidator interface {
// ParseJoinTokenClaims validates a join token and returns its claims
// without comparing the session or party ID.
ParseJoinTokenClaims(tokenString string) (*Claims, error)
// ValidateJoinToken validates a join token and checks that it matches
// sessionID and partyID.
ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
}
// Compile-time checks that *JWTService satisfies the interfaces declared in
// this file.
var (
	_ TokenGenerator = (*JWTService)(nil)
	_ TokenValidator = (*JWTService)(nil)
)