235 lines
6.6 KiB
Go
235 lines
6.6 KiB
Go
package jwt
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Sentinel errors reported by the token helpers in this package.
// Compare against them with errors.Is.
var (
	// ErrInvalidToken signals a token that could not be parsed or verified,
	// or whose token type is not the one the caller expected.
	ErrInvalidToken = errors.New("invalid token")

	// ErrExpiredToken signals a token whose expiry time has passed.
	ErrExpiredToken = errors.New("token expired")

	// ErrInvalidClaims signals claims that could not be read or that do not
	// match the values the caller expected.
	ErrInvalidClaims = errors.New("invalid claims")

	// ErrTokenNotYetValid signals a token whose not-before time has not been
	// reached yet.
	ErrTokenNotYetValid = errors.New("token not yet valid")
)
|
|
|
|
// Claims represents custom JWT claims
|
|
type Claims struct {
|
|
SessionID string `json:"session_id"`
|
|
PartyID string `json:"party_id"`
|
|
TokenType string `json:"token_type"` // "join", "access", "refresh"
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// JWTService issues and validates the HMAC-signed tokens of this package.
// Build one with NewJWTService.
type JWTService struct {
	// secretKey is the shared HMAC key used for signing and verification.
	secretKey []byte

	// issuer is written into the "iss" claim of every generated token.
	issuer string

	// tokenExpiry is the lifetime given to access tokens.
	tokenExpiry time.Duration

	// refreshExpiry is the lifetime given to refresh tokens.
	refreshExpiry time.Duration
}
|
|
|
|
// NewJWTService creates a new JWT service
|
|
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
|
|
return &JWTService{
|
|
secretKey: []byte(secretKey),
|
|
issuer: issuer,
|
|
tokenExpiry: tokenExpiry,
|
|
refreshExpiry: refreshExpiry,
|
|
}
|
|
}
|
|
|
|
// GenerateJoinToken generates a token for joining an MPC session
|
|
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
|
|
now := time.Now()
|
|
claims := Claims{
|
|
SessionID: sessionID.String(),
|
|
PartyID: partyID,
|
|
TokenType: "join",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: uuid.New().String(),
|
|
Issuer: s.issuer,
|
|
Subject: partyID,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiresIn)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString(s.secretKey)
|
|
}
|
|
|
|
// AccessTokenClaims is the simplified view of an access token's claims that
// ValidateAccessToken hands back to callers.
type AccessTokenClaims struct {
	// Subject is the token's "sub" claim (the user ID for access tokens).
	Subject string

	// Username is the username carried by the access token.
	Username string

	// Issuer is the token's "iss" claim.
	Issuer string
}
|
|
|
|
// GenerateAccessToken generates an access token with username
|
|
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
|
now := time.Now()
|
|
claims := Claims{
|
|
TokenType: "access",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: uuid.New().String(),
|
|
Issuer: s.issuer,
|
|
Subject: userID,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.tokenExpiry)),
|
|
},
|
|
}
|
|
// Store username in PartyID field for access tokens
|
|
claims.PartyID = username
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString(s.secretKey)
|
|
}
|
|
|
|
// GenerateRefreshToken generates a refresh token
|
|
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
|
|
now := time.Now()
|
|
claims := Claims{
|
|
TokenType: "refresh",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: uuid.New().String(),
|
|
Issuer: s.issuer,
|
|
Subject: userID,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.refreshExpiry)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString(s.secretKey)
|
|
}
|
|
|
|
// ValidateToken validates a JWT token and returns the claims
|
|
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
return s.secretKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
return nil, ErrExpiredToken
|
|
}
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok || !token.Valid {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// ParseJoinTokenClaims parses a join token and extracts claims without validating session ID
|
|
// This is used when the session ID is not known beforehand (e.g., join by token)
|
|
func (s *JWTService) ParseJoinTokenClaims(tokenString string) (*Claims, error) {
|
|
claims, err := s.ValidateToken(tokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims.TokenType != "join" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// ValidateJoinToken validates a join token for MPC sessions
|
|
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
|
|
claims, err := s.ValidateToken(tokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims.TokenType != "join" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
if claims.SessionID != sessionID.String() {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
// Allow wildcard party ID "*" for dynamic joining, otherwise must match exactly
|
|
if claims.PartyID != "*" && claims.PartyID != partyID {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RefreshAccessToken creates a new access token from a valid refresh token
|
|
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
|
|
claims, err := s.ValidateToken(refreshToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if claims.TokenType != "refresh" {
|
|
return "", ErrInvalidToken
|
|
}
|
|
|
|
// PartyID stores the username for access tokens
|
|
return s.GenerateAccessToken(claims.Subject, claims.PartyID)
|
|
}
|
|
|
|
// ValidateAccessToken validates an access token and returns structured claims
|
|
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
|
|
claims, err := s.ValidateToken(tokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims.TokenType != "access" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
return &AccessTokenClaims{
|
|
Subject: claims.Subject,
|
|
Username: claims.PartyID, // Username stored in PartyID for access tokens
|
|
Issuer: claims.Issuer,
|
|
}, nil
|
|
}
|
|
|
|
// ValidateRefreshToken validates a refresh token and returns claims
|
|
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
|
|
claims, err := s.ValidateToken(tokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims.TokenType != "refresh" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// TokenGenerator interface for dependency injection
|
|
type TokenGenerator interface {
|
|
GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
|
|
}
|
|
|
|
// TokenValidator interface for dependency injection
|
|
type TokenValidator interface {
|
|
ParseJoinTokenClaims(tokenString string) (*Claims, error)
|
|
ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
|
|
}
|
|
|
|
// Ensure JWTService implements interfaces
|
|
var _ TokenGenerator = (*JWTService)(nil)
|
|
var _ TokenValidator = (*JWTService)(nil)
|