package jwt

import (
	"errors"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
)

// Sentinel errors returned by the JWTService validation methods.
var (
	// ErrInvalidToken is returned when a token cannot be parsed or verified,
	// or when its token type is not the one the caller asked for.
	ErrInvalidToken = errors.New("invalid token")
	// ErrExpiredToken is returned when the token's expiry has passed.
	ErrExpiredToken = errors.New("token expired")
	// ErrInvalidClaims is returned when the claims do not match what the
	// caller expects (wrong session, wrong party, unexpected claim type).
	ErrInvalidClaims = errors.New("invalid claims")
	// ErrTokenNotYetValid is returned when the token's not-before time is
	// still in the future.
	ErrTokenNotYetValid = errors.New("token not yet valid")
)

// Values written to and compared against Claims.TokenType.
const (
	tokenTypeJoin    = "join"
	tokenTypeAccess  = "access"
	tokenTypeRefresh = "refresh"
)

// Claims is the custom claim set carried by every token this service issues.
type Claims struct {
	// SessionID is the MPC session a join token is bound to. It is left
	// empty on access and refresh tokens.
	SessionID string `json:"session_id"`

	// PartyID is the party identifier on join tokens ("*" acts as a
	// wildcard in ValidateJoinToken). Access tokens reuse this field to
	// carry the username.
	PartyID string `json:"party_id"`

	// TokenType is one of "join", "access" or "refresh".
	TokenType string `json:"token_type"`

	jwt.RegisteredClaims
}

// JWTService signs and validates HS256 tokens with a shared secret.
type JWTService struct {
	secretKey     []byte        // HMAC signing and verification key
	issuer        string        // value written to the issuer claim
	tokenExpiry   time.Duration // lifetime of access tokens
	refreshExpiry time.Duration // lifetime of refresh tokens
}

// NewJWTService returns a JWTService that signs tokens with secretKey and
// stamps them with issuer. tokenExpiry is the lifetime given to access tokens
// and refreshExpiry the lifetime given to refresh tokens.
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
	return &JWTService{
		secretKey:     []byte(secretKey),
		issuer:        issuer,
		tokenExpiry:   tokenExpiry,
		refreshExpiry: refreshExpiry,
	}
}

// registeredClaims builds the standard claims shared by every token: a fresh
// random ID, the configured issuer, the given subject, and a validity window
// that starts now and lasts ttl.
func (s *JWTService) registeredClaims(subject string, ttl time.Duration) jwt.RegisteredClaims {
	now := time.Now()
	return jwt.RegisteredClaims{
		ID:        uuid.New().String(),
		Issuer:    s.issuer,
		Subject:   subject,
		IssuedAt:  jwt.NewNumericDate(now),
		NotBefore: jwt.NewNumericDate(now),
		ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
	}
}

// sign serializes claims as a compact JWT signed with HS256 and the service
// secret.
func (s *JWTService) sign(claims Claims) (string, error) {
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.secretKey)
}

// GenerateJoinToken issues a "join" token that binds partyID to the MPC
// session sessionID. The token is valid from now for expiresIn.
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
	return s.sign(Claims{
		SessionID:        sessionID.String(),
		PartyID:          partyID,
		TokenType:        tokenTypeJoin,
		RegisteredClaims: s.registeredClaims(partyID, expiresIn),
	})
}

// AccessTokenClaims is the structured view of an access token returned by
// ValidateAccessToken.
type AccessTokenClaims struct {
	Subject  string // subject claim; the userID passed to GenerateAccessToken
	Username string // username, carried in Claims.PartyID
	Issuer   string // issuer claim
}

// GenerateAccessToken issues an "access" token whose subject is userID and
// whose PartyID field carries username. The token lives for the service's
// configured token expiry.
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
	return s.sign(Claims{
		PartyID:          username, // access tokens store the username in PartyID
		TokenType:        tokenTypeAccess,
		RegisteredClaims: s.registeredClaims(userID, s.tokenExpiry),
	})
}

// GenerateRefreshToken issues a "refresh" token whose subject is userID. The
// token lives for the service's configured refresh expiry.
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
	return s.sign(Claims{
		TokenType:        tokenTypeRefresh,
		RegisteredClaims: s.registeredClaims(userID, s.refreshExpiry),
	})
}

// ValidateToken parses tokenString, verifies its HMAC signature against the
// service secret, and returns the decoded claims.
//
// It returns ErrExpiredToken when the token has expired, ErrTokenNotYetValid
// when its not-before time has not been reached, ErrInvalidToken for any other
// parse or verification failure, and ErrInvalidClaims when the parsed token
// does not yield valid *Claims.
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
	keyFunc := func(token *jwt.Token) (any, error) {
		// Refuse anything not signed with an HMAC method so a token cannot
		// choose a different algorithm family for its own verification.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, ErrInvalidToken
		}
		return s.secretKey, nil
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, keyFunc)
	if err != nil {
		switch {
		case errors.Is(err, jwt.ErrTokenExpired):
			return nil, ErrExpiredToken
		case errors.Is(err, jwt.ErrTokenNotValidYet):
			return nil, ErrTokenNotYetValid
		default:
			return nil, ErrInvalidToken
		}
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}

// validateTokenOfType validates tokenString and additionally requires its
// TokenType claim to equal tokenType, returning ErrInvalidToken otherwise.
func (s *JWTService) validateTokenOfType(tokenString, tokenType string) (*Claims, error) {
	claims, err := s.ValidateToken(tokenString)
	if err != nil {
		return nil, err
	}
	if claims.TokenType != tokenType {
		return nil, ErrInvalidToken
	}
	return claims, nil
}

// ParseJoinTokenClaims validates a join token and returns its claims without
// checking the session ID or party ID. It is used when the session ID is not
// known beforehand (e.g., join by token).
func (s *JWTService) ParseJoinTokenClaims(tokenString string) (*Claims, error) {
	return s.validateTokenOfType(tokenString, tokenTypeJoin)
}

// ValidateJoinToken validates a join token and checks that it was issued for
// sessionID and partyID. It returns ErrInvalidClaims when either does not
// match.
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
	claims, err := s.validateTokenOfType(tokenString, tokenTypeJoin)
	if err != nil {
		return nil, err
	}
	if claims.SessionID != sessionID.String() {
		return nil, ErrInvalidClaims
	}
	// A token issued with the wildcard party ID "*" allows dynamic joining by
	// any party; otherwise the party ID must match exactly.
	if claims.PartyID != "*" && claims.PartyID != partyID {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}

// RefreshAccessToken validates refreshToken as a refresh token and issues a
// new access token for the same subject.
//
// NOTE(review): the username is read from the refresh token's PartyID, but
// GenerateRefreshToken never sets that field, so refresh tokens minted by this
// service yield access tokens with an empty username. Fixing this requires
// GenerateRefreshToken to accept the username, or a lookup by subject.
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
	claims, err := s.validateTokenOfType(refreshToken, tokenTypeRefresh)
	if err != nil {
		return "", err
	}
	return s.GenerateAccessToken(claims.Subject, claims.PartyID)
}

// ValidateAccessToken validates an access token and returns its subject,
// username and issuer.
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
	claims, err := s.validateTokenOfType(tokenString, tokenTypeAccess)
	if err != nil {
		return nil, err
	}
	return &AccessTokenClaims{
		Subject:  claims.Subject,
		Username: claims.PartyID, // access tokens store the username in PartyID
		Issuer:   claims.Issuer,
	}, nil
}

// ValidateRefreshToken validates a refresh token and returns its claims.
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
	return s.validateTokenOfType(tokenString, tokenTypeRefresh)
}

// TokenGenerator is the join-token issuing interface, for dependency
// injection.
type TokenGenerator interface {
	GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
}

// TokenValidator is the join-token validation interface, for dependency
// injection.
type TokenValidator interface {
	ParseJoinTokenClaims(tokenString string) (*Claims, error)
	ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
}

// Compile-time checks that JWTService implements the interfaces above.
var (
	_ TokenGenerator = (*JWTService)(nil)
	_ TokenValidator = (*JWTService)(nil)
)