package crypto
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"io"
	"math/big"

	"golang.org/x/crypto/hkdf"
)
// Sentinel errors for the helpers in this package; compare with errors.Is.

// ErrInvalidKeySize reports a key whose length is not exactly 32 bytes.
var ErrInvalidKeySize = errors.New("invalid key size")

// ErrInvalidCipherText reports input too short to contain a GCM nonce.
var ErrInvalidCipherText = errors.New("invalid ciphertext")

// ErrEncryptionFailed is a general encryption failure; the encryption helpers
// in this file currently return the underlying error instead.
var ErrEncryptionFailed = errors.New("encryption failed")

// ErrDecryptionFailed reports that AES-GCM authentication/opening failed.
var ErrDecryptionFailed = errors.New("decryption failed")

// ErrInvalidPublicKey reports bytes that do not decode to a P-256 point.
var ErrInvalidPublicKey = errors.New("invalid public key")

// ErrInvalidSignature reports a signature that is not 64 bytes (R || S).
var ErrInvalidSignature = errors.New("invalid signature")
// CryptoService provides symmetric encryption and HKDF key derivation rooted
// in a single 32-byte master key (length enforced by NewCryptoService).
type CryptoService struct {
// masterKey is used directly as the AES-256 key by EncryptMessage/DecryptMessage
// and as the HKDF input keying material by DeriveKey.
masterKey []byte
}
// NewCryptoService creates a CryptoService rooted at masterKey.
//
// masterKey must be exactly 32 bytes (AES-256); any other length yields
// ErrInvalidKeySize. The key is copied, so a caller that later mutates or
// zeroes its own slice cannot silently change the service's key material.
func NewCryptoService(masterKey []byte) (*CryptoService, error) {
	if len(masterKey) != 32 {
		return nil, ErrInvalidKeySize
	}
	keyCopy := make([]byte, len(masterKey))
	copy(keyCopy, masterKey)
	return &CryptoService{masterKey: keyCopy}, nil
}
// GenerateRandomBytes returns n bytes read from crypto/rand.
//
// A negative n returns an error instead of panicking inside make. n == 0
// returns an empty, non-nil slice.
func GenerateRandomBytes(n int) ([]byte, error) {
	if n < 0 {
		return nil, errors.New("crypto: negative byte count")
	}
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}
	return b, nil
}
// GenerateRandomHex returns n random bytes encoded as lowercase hex, so the
// resulting string has 2*n characters. On failure it returns "" and the error.
func GenerateRandomHex(n int) (string, error) {
	var encoded string
	raw, err := GenerateRandomBytes(n)
	if err == nil {
		encoded = hex.EncodeToString(raw)
	}
	return encoded, err
}
// DeriveKey expands the master key into length bytes using HKDF-SHA256 with
// no salt. context is passed as the HKDF info parameter, so different context
// strings produce different keys.
func (c *CryptoService) DeriveKey(context string, length int) ([]byte, error) {
	out := make([]byte, length)
	reader := hkdf.New(sha256.New, c.masterKey, nil, []byte(context))
	if _, err := io.ReadFull(reader, out); err != nil {
		return nil, err
	}
	return out, nil
}
// EncryptShare seals shareData with AES-256-GCM under a per-party key derived
// from the master key with context "share_encryption:"+partyID. partyID is
// also bound as additional authenticated data. Output layout: nonce || sealed.
func (c *CryptoService) EncryptShare(shareData []byte, partyID string) ([]byte, error) {
	partyKey, err := c.DeriveKey("share_encryption:"+partyID, 32)
	if err != nil {
		return nil, err
	}
	blockCipher, err := aes.NewCipher(partyKey)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	// Seal appends to its dst argument, so the nonce becomes the output prefix.
	return gcm.Seal(nonce, nonce, shareData, []byte(partyID)), nil
}
// DecryptShare reverses EncryptShare: it re-derives the per-party key, splits
// off the leading nonce and opens the rest with partyID as AAD.
// Input shorter than the nonce yields ErrInvalidCipherText; any authentication
// failure (wrong party, tampering) yields ErrDecryptionFailed.
func (c *CryptoService) DecryptShare(encryptedData []byte, partyID string) ([]byte, error) {
	partyKey, err := c.DeriveKey("share_encryption:"+partyID, 32)
	if err != nil {
		return nil, err
	}
	blockCipher, err := aes.NewCipher(partyKey)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}
	n := gcm.NonceSize()
	if len(encryptedData) < n {
		return nil, ErrInvalidCipherText
	}
	plaintext, err := gcm.Open(nil, encryptedData[:n], encryptedData[n:], []byte(partyID))
	if err != nil {
		return nil, ErrDecryptionFailed
	}
	return plaintext, nil
}
// EncryptMessage seals plaintext with AES-256-GCM using the master key
// directly (no derivation, no additional authenticated data). The result is
// nonce || sealed, with a fresh random nonce per call.
//
// NOTE(review): the master key is also the HKDF input for DeriveKey, so one
// key serves two purposes. Deriving a dedicated key would be cleaner but would
// change the ciphertext format for existing data.
func (c *CryptoService) EncryptMessage(plaintext []byte) ([]byte, error) {
	blk, err := aes.NewCipher(c.masterKey)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	sealed := gcm.Seal(nonce, nonce, plaintext, nil)
	return sealed, nil
}
// DecryptMessage reverses EncryptMessage using the master key. Input shorter
// than the nonce yields ErrInvalidCipherText; authentication failure yields
// ErrDecryptionFailed.
func (c *CryptoService) DecryptMessage(ciphertext []byte) ([]byte, error) {
	blk, err := aes.NewCipher(c.masterKey)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}
	ns := gcm.NonceSize()
	if len(ciphertext) < ns {
		return nil, ErrInvalidCipherText
	}
	opened, err := gcm.Open(nil, ciphertext[:ns], ciphertext[ns:], nil)
	if err != nil {
		return nil, ErrDecryptionFailed
	}
	return opened, nil
}
// Hash256 returns the 32-byte SHA-256 digest of data.
func Hash256(data []byte) []byte {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return digest.Sum(nil)
}
// VerifyECDSASignature checks an ECDSA signature over a pre-computed hash
// against a NIST P-256 public key.
//
// publicKey must be an uncompressed P-256 point (0x04 || X || Y, 65 bytes);
// anything else, including secp256k1 keys, fails to parse and returns
// ErrInvalidPublicKey. signature must be exactly 64 bytes, R || S, each
// big-endian; otherwise ErrInvalidSignature is returned. A well-formed but
// non-matching signature returns (false, nil).
func VerifyECDSASignature(messageHash, signature, publicKey []byte) (bool, error) {
	p256 := elliptic.P256()
	px, py := elliptic.Unmarshal(p256, publicKey)
	if px == nil {
		return false, ErrInvalidPublicKey
	}
	pub := &ecdsa.PublicKey{Curve: p256, X: px, Y: py}

	if len(signature) != 64 {
		return false, ErrInvalidSignature
	}
	r, s := new(big.Int).SetBytes(signature[:32]), new(big.Int).SetBytes(signature[32:])
	return ecdsa.Verify(pub, messageHash, r, s), nil
}
// GenerateNonce returns 32 fresh bytes from crypto/rand.
func GenerateNonce() ([]byte, error) {
	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	return nonce, nil
}
// SecureCompare reports whether a and b are equal. When the lengths match,
// every byte is inspected regardless of where a difference occurs. A length
// mismatch returns false immediately, so lengths are not hidden.
func SecureCompare(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	var diff byte
	for i := range a {
		diff |= a[i] ^ b[i]
	}
	return diff == 0
}
// ParsePublicKey decodes an uncompressed P-256 point (0x04 || X || Y) into an
// *ecdsa.PublicKey. Bytes that do not decode to a point on P-256 yield
// ErrInvalidPublicKey.
func ParsePublicKey(publicKeyBytes []byte) (*ecdsa.PublicKey, error) {
	p256 := elliptic.P256()
	px, py := elliptic.Unmarshal(p256, publicKeyBytes)
	if px == nil {
		return nil, ErrInvalidPublicKey
	}
	key := new(ecdsa.PublicKey)
	key.Curve, key.X, key.Y = p256, px, py
	return key, nil
}
// VerifySignature reports whether signature (64 bytes, R || S, big-endian)
// is a valid ECDSA signature of messageHash under pubKey.
//
// It returns false, rather than panicking inside ecdsa.Verify, when pubKey
// is nil or missing its curve or coordinates, and when the signature is not
// exactly 64 bytes.
func VerifySignature(pubKey *ecdsa.PublicKey, messageHash, signature []byte) bool {
	if pubKey == nil || pubKey.Curve == nil || pubKey.X == nil || pubKey.Y == nil {
		return false
	}
	if len(signature) != 64 {
		return false
	}
	r := new(big.Int).SetBytes(signature[:32])
	s := new(big.Int).SetBytes(signature[32:])
	return ecdsa.Verify(pubKey, messageHash, r, s)
}
// HashMessage returns the 32-byte SHA-256 digest of message. It produces
// exactly the same output as Hash256.
func HashMessage(message []byte) []byte {
	sum := sha256.Sum256(message)
	return sum[:]
}
// Encrypt seals plaintext with AES-256-GCM under key (which must be 32 bytes,
// else ErrInvalidKeySize). There is no additional authenticated data. The
// output is nonce || sealed, with a random nonce per call.
func Encrypt(key, plaintext []byte) ([]byte, error) {
	if len(key) != 32 {
		return nil, ErrInvalidKeySize
	}
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
// Decrypt reverses Encrypt. Errors:
//   - ErrInvalidKeySize when key is not 32 bytes;
//   - ErrInvalidCipherText when the input is shorter than the nonce;
//   - ErrDecryptionFailed when authentication fails.
func Decrypt(key, ciphertext []byte) ([]byte, error) {
	if len(key) != 32 {
		return nil, ErrInvalidKeySize
	}
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}
	ns := gcm.NonceSize()
	if len(ciphertext) < ns {
		return nil, ErrInvalidCipherText
	}
	opened, err := gcm.Open(nil, ciphertext[:ns], ciphertext[ns:], nil)
	if err != nil {
		return nil, ErrDecryptionFailed
	}
	return opened, nil
}
// DeriveKey runs HKDF-SHA256 over secret with the given salt (and no info
// parameter) and returns length bytes of output key material.
func DeriveKey(secret, salt []byte, length int) ([]byte, error) {
	out := make([]byte, length)
	if _, err := io.ReadFull(hkdf.New(sha256.New, secret, salt, nil), out); err != nil {
		return nil, err
	}
	return out, nil
}
// SignMessage hashes message with SHA-256, signs the digest with privateKey,
// and returns the fixed-width 64-byte encoding R || S. Each half is
// big-endian and left-padded with zeros to 32 bytes.
//
// Only curves whose group order fits in 256 bits (e.g. P-224, P-256) fit this
// encoding. Larger curves (P-384, P-521) return an error rather than
// panicking on a negative slice index. A nil key, or a key without a curve,
// also returns an error.
func SignMessage(privateKey *ecdsa.PrivateKey, message []byte) ([]byte, error) {
	if privateKey == nil || privateKey.Curve == nil {
		return nil, errors.New("crypto: nil ECDSA private key")
	}
	if privateKey.Curve.Params().N.BitLen() > 256 {
		return nil, errors.New("crypto: curve too large for 64-byte R||S signature")
	}
	digest := sha256.Sum256(message)
	r, s, err := ecdsa.Sign(rand.Reader, privateKey, digest[:])
	if err != nil {
		return nil, err
	}
	signature := make([]byte, 64)
	// FillBytes zero-pads on the left; it cannot overflow because r, s < N < 2^256.
	r.FillBytes(signature[:32])
	s.FillBytes(signature[32:])
	return signature, nil
}
// EncodeToHex returns the lowercase hexadecimal encoding of data.
func EncodeToHex(data []byte) string {
	buf := make([]byte, hex.EncodedLen(len(data)))
	hex.Encode(buf, data)
	return string(buf)
}
// DecodeFromHex decodes a hexadecimal string (either case) into bytes. On
// malformed input it returns the error from encoding/hex.
func DecodeFromHex(s string) ([]byte, error) {
	decoded, err := hex.DecodeString(s)
	return decoded, err
}
// EncodeToBase64 encodes data as standard, padded base64 (RFC 4648).
func EncodeToBase64(data []byte) string {
	return base64.StdEncoding.EncodeToString(data)
}

// DecodeFromBase64 decodes a standard, padded base64 string (RFC 4648).
//
// NOTE: earlier builds of EncodeToBase64 emitted hex. Any such persisted
// values must be decoded with DecodeFromHex.
func DecodeFromBase64(s string) ([]byte, error) {
	return base64.StdEncoding.DecodeString(s)
}
// MarshalPublicKey returns the uncompressed encoding (0x04 || X || Y) of
// pubKey, as produced by elliptic.Marshal.
//
// It returns nil, instead of panicking, when pubKey is nil or is missing its
// curve or coordinates.
func MarshalPublicKey(pubKey *ecdsa.PublicKey) []byte {
	if pubKey == nil || pubKey.Curve == nil || pubKey.X == nil || pubKey.Y == nil {
		return nil
	}
	return elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y)
}
// CompareBytes reports whether a and b are equal, using the same
// accumulate-XOR comparison as SecureCompare. Length mismatches return false
// immediately.
func CompareBytes(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	var acc byte
	for i, x := range a {
		acc |= x ^ b[i]
	}
	return acc == 0
}
package jwt
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// ErrInvalidToken reports a malformed, badly signed, or wrong-type token.
var ErrInvalidToken = errors.New("invalid token")

// ErrExpiredToken reports a token whose exp claim has passed.
var ErrExpiredToken = errors.New("token expired")

// ErrInvalidClaims reports a token whose claims do not match what the caller
// expected (for example a different session or party).
var ErrInvalidClaims = errors.New("invalid claims")

// ErrTokenNotYetValid is intended for tokens presented before their nbf time.
var ErrTokenNotYetValid = errors.New("token not yet valid")
// Claims is the JWT payload used for every token this service issues. One
// struct carries all three token kinds, distinguished by TokenType.
type Claims struct {
// SessionID is the MPC session a "join" token is scoped to; access and
// refresh tokens leave it empty.
SessionID string `json:"session_id"`
// PartyID is the MPC party for "join" tokens. GenerateAccessToken reuses
// this field to carry the username.
PartyID string `json:"party_id"`
TokenType string `json:"token_type"` // "join", "access", "refresh"
// Standard registered claims (jti, iss, sub, iat, nbf, exp).
jwt.RegisteredClaims
}
// JWTService issues HS256-signed JWTs and validates them with the same shared
// secret.
type JWTService struct {
// secretKey is the HMAC key used both to sign and to verify tokens.
secretKey []byte
// issuer is written into the "iss" claim of every issued token.
issuer string
// tokenExpiry is the lifetime given to access tokens.
tokenExpiry time.Duration
// refreshExpiry is the lifetime given to refresh tokens.
refreshExpiry time.Duration
}
// NewJWTService builds a JWTService that signs with secretKey, stamps issuer
// into "iss", and gives access/refresh tokens the provided lifetimes.
//
// NOTE(review): secretKey is not validated, so an empty secret is accepted.
func NewJWTService(secretKey string, issuer string, tokenExpiry, refreshExpiry time.Duration) *JWTService {
	svc := new(JWTService)
	svc.secretKey = []byte(secretKey)
	svc.issuer = issuer
	svc.tokenExpiry = tokenExpiry
	svc.refreshExpiry = refreshExpiry
	return svc
}
// GenerateJoinToken mints an HS256 "join" token that lets partyID join the
// MPC session sessionID. sub is partyID, jti is a random UUID, iat and nbf
// are now, and exp is now+expiresIn.
func (s *JWTService) GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error) {
	issuedAt := time.Now()
	registered := jwt.RegisteredClaims{
		ID:        uuid.New().String(),
		Issuer:    s.issuer,
		Subject:   partyID,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(expiresIn)),
	}
	payload := Claims{
		SessionID:        sessionID.String(),
		PartyID:          partyID,
		TokenType:        "join",
		RegisteredClaims: registered,
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(s.secretKey)
}
// AccessTokenClaims is the flattened view of an access token returned by
// ValidateAccessToken.
type AccessTokenClaims struct {
// Subject is the user ID (the token's "sub" claim).
Subject string
// Username is read from the token's PartyID claim.
Username string
// Issuer is the token's "iss" claim.
Issuer string
}
// GenerateAccessToken mints an HS256 "access" token for userID (sub) that
// lives for tokenExpiry. The username travels in the PartyID claim.
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
	issuedAt := time.Now()
	payload := Claims{
		PartyID:   username, // PartyID doubles as the username for access tokens
		TokenType: "access",
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        uuid.New().String(),
			Issuer:    s.issuer,
			Subject:   userID,
			IssuedAt:  jwt.NewNumericDate(issuedAt),
			NotBefore: jwt.NewNumericDate(issuedAt),
			ExpiresAt: jwt.NewNumericDate(issuedAt.Add(s.tokenExpiry)),
		},
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(s.secretKey)
}
// GenerateRefreshToken mints an HS256 "refresh" token for userID (sub) that
// lives for refreshExpiry. It carries no username: the PartyID claim is left
// empty.
func (s *JWTService) GenerateRefreshToken(userID string) (string, error) {
	issuedAt := time.Now()
	registered := jwt.RegisteredClaims{
		ID:        uuid.New().String(),
		Issuer:    s.issuer,
		Subject:   userID,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(s.refreshExpiry)),
	}
	payload := Claims{TokenType: "refresh", RegisteredClaims: registered}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(s.secretKey)
}
// ValidateToken parses tokenString, verifies its signature with the service
// secret, and checks its time-based and issuer claims. On success it returns
// the decoded Claims.
//
// Hardening:
//   - Only HS256 is accepted, since every token this service mints uses it.
//   - When s.issuer is non-empty, the "iss" claim must equal it.
//
// Errors:
//   - ErrExpiredToken when exp has passed.
//   - ErrTokenNotYetValid when nbf is in the future.
//   - ErrInvalidToken for any other parse, signature, algorithm or issuer
//     failure.
//   - ErrInvalidClaims if the decoded claims are unusable.
func (s *JWTService) ValidateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(
		tokenString,
		&Claims{},
		func(token *jwt.Token) (interface{}, error) {
			// Refuse non-HMAC algorithms so the secret can never be treated as a
			// public key (algorithm-confusion attacks).
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, ErrInvalidToken
			}
			return s.secretKey, nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// jwt v5 only enforces this option when the expected issuer is non-empty.
		jwt.WithIssuer(s.issuer),
	)
	if err != nil {
		switch {
		case errors.Is(err, jwt.ErrTokenExpired):
			return nil, ErrExpiredToken
		case errors.Is(err, jwt.ErrTokenNotValidYet):
			return nil, ErrTokenNotYetValid
		default:
			return nil, ErrInvalidToken
		}
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}
// ValidateJoinToken validates tokenString and additionally requires:
//   - TokenType "join", else ErrInvalidToken;
//   - the token's session and party to match sessionID and partyID, else
//     ErrInvalidClaims.
func (s *JWTService) ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error) {
	claims, err := s.ValidateToken(tokenString)
	if err != nil {
		return nil, err
	}
	switch {
	case claims.TokenType != "join":
		return nil, ErrInvalidToken
	case claims.SessionID != sessionID.String(), claims.PartyID != partyID:
		return nil, ErrInvalidClaims
	}
	return claims, nil
}
// RefreshAccessToken exchanges a valid refresh token for a new access token
// for the same subject (user ID). The username is copied from the refresh
// token's PartyID claim.
//
// NOTE(review): GenerateRefreshToken never sets PartyID, so refresh tokens it
// mints carry no username. Access tokens issued here will therefore have an
// empty username; confirm whether callers rely on that claim.
func (s *JWTService) RefreshAccessToken(refreshToken string) (string, error) {
	claims, err := s.ValidateRefreshToken(refreshToken)
	if err != nil {
		return "", err
	}
	return s.GenerateAccessToken(claims.Subject, claims.PartyID)
}
// ValidateAccessToken validates tokenString and requires TokenType "access"
// (else ErrInvalidToken). It returns the subject, username (taken from
// PartyID) and issuer.
func (s *JWTService) ValidateAccessToken(tokenString string) (*AccessTokenClaims, error) {
	claims, err := s.ValidateToken(tokenString)
	switch {
	case err != nil:
		return nil, err
	case claims.TokenType != "access":
		return nil, ErrInvalidToken
	}
	return &AccessTokenClaims{
		Subject:  claims.Subject,
		Username: claims.PartyID,
		Issuer:   claims.Issuer,
	}, nil
}
// ValidateRefreshToken validates tokenString and requires TokenType "refresh"
// (else ErrInvalidToken).
func (s *JWTService) ValidateRefreshToken(tokenString string) (*Claims, error) {
	claims, err := s.ValidateToken(tokenString)
	if err == nil && claims.TokenType != "refresh" {
		err = ErrInvalidToken
	}
	if err != nil {
		return nil, err
	}
	return claims, nil
}
// TokenGenerator is the minimal dependency needed to mint MPC session join
// tokens; *JWTService satisfies it.
type TokenGenerator interface {
GenerateJoinToken(sessionID uuid.UUID, partyID string, expiresIn time.Duration) (string, error)
}
// TokenValidator is the minimal dependency needed to check MPC session join
// tokens; *JWTService satisfies it.
type TokenValidator interface {
ValidateJoinToken(tokenString string, sessionID uuid.UUID, partyID string) (*Claims, error)
}
// Compile-time proof that *JWTService satisfies both DI interfaces.
var (
	_ TokenGenerator = (*JWTService)(nil)
	_ TokenValidator = (*JWTService)(nil)
)
package utils
import (
"context"
"encoding/json"
"math/big"
"reflect"
"strings"
"time"
"github.com/google/uuid"
)
// GenerateID returns a freshly generated random UUID.
func GenerateID() uuid.UUID {
	id := uuid.New()
	return id
}
// ParseUUID parses s into a UUID, returning uuid.Parse's error on failure.
func ParseUUID(s string) (uuid.UUID, error) {
	id, err := uuid.Parse(s)
	return id, err
}
// MustParseUUID parses s into a UUID and panics with the parse error if s is
// not a valid UUID. Intended for constants and tests.
func MustParseUUID(s string) uuid.UUID {
	id, err := uuid.Parse(s)
	if err == nil {
		return id
	}
	panic(err)
}
// IsValidUUID reports whether uuid.Parse accepts s.
func IsValidUUID(s string) bool {
	if _, err := uuid.Parse(s); err != nil {
		return false
	}
	return true
}
// ToJSON marshals v with encoding/json.
func ToJSON(v interface{}) ([]byte, error) {
	data, err := json.Marshal(v)
	return data, err
}
// FromJSON unmarshals data into v, which must be a non-nil pointer.
func FromJSON(data []byte, v interface{}) error {
	if err := json.Unmarshal(data, v); err != nil {
		return err
	}
	return nil
}
// NowUTC returns the current wall-clock time in the UTC location.
func NowUTC() time.Time {
	now := time.Now()
	return now.UTC()
}
// TimePtr returns a pointer to a copy of t.
func TimePtr(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}
// NowPtr returns a pointer to the current time in UTC.
func NowPtr() *time.Time {
	t := time.Now().UTC()
	return &t
}
// BigIntToBytes returns |n| as a 32-byte big-endian value, left-padded with
// zeros. A nil n yields 32 zero bytes. The sign is dropped, as with
// big.Int.Bytes.
//
// Values wider than 256 bits are reduced to their least-significant 32 bytes
// (n mod 2^256). This keeps the encoding right-aligned, consistent with the
// left-padding, instead of keeping an arbitrary high-order prefix. The result
// is always a freshly allocated slice.
func BigIntToBytes(n *big.Int) []byte {
	out := make([]byte, 32)
	if n == nil {
		return out
	}
	b := n.Bytes()
	if len(b) > 32 {
		b = b[len(b)-32:]
	}
	copy(out[32-len(b):], b)
	return out
}
// BytesToBigInt interprets b as an unsigned big-endian integer.
func BytesToBigInt(b []byte) *big.Int {
	n := new(big.Int)
	n.SetBytes(b)
	return n
}
// StringSliceContains reports whether value occurs in slice.
func StringSliceContains(slice []string, value string) bool {
	for i := range slice {
		if slice[i] == value {
			return true
		}
	}
	return false
}
// StringSliceRemove returns a new slice containing every element of slice
// except those equal to value. The input is not modified, and the result is
// never nil.
func StringSliceRemove(slice []string, value string) []string {
	kept := make([]string, 0, len(slice))
	for i := range slice {
		if slice[i] == value {
			continue
		}
		kept = append(kept, slice[i])
	}
	return kept
}
// UniqueStrings returns the distinct elements of slice in order of first
// appearance. The result is never nil.
func UniqueStrings(slice []string) []string {
	seen := make(map[string]bool)
	out := make([]string, 0, len(slice))
	for _, s := range slice {
		if seen[s] {
			continue
		}
		seen[s] = true
		out = append(out, s)
	}
	return out
}
// TruncateString shortens s to at most maxLen bytes.
//
// The cut is moved back to the nearest UTF-8 rune boundary so a multi-byte
// character is never split (which would produce invalid UTF-8). The result
// may therefore be a few bytes shorter than maxLen. A negative maxLen yields
// "" instead of panicking.
func TruncateString(s string, maxLen int) string {
	if maxLen < 0 {
		return ""
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; cutting before
	// one would split the rune it belongs to.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}
// SafeString dereferences s, treating a nil pointer as "".
func SafeString(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
// StringPtr returns a pointer to a copy of s.
func StringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// IntPtr returns a pointer to a copy of i.
func IntPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// BoolPtr returns a pointer to a copy of b.
func BoolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
// IsZero reports whether v holds the zero value of its dynamic type, as
// defined by reflect.Value.IsZero. An untyped nil (a nil interface) counts as
// zero; reflect.Value.IsZero would panic on the invalid Value it produces.
func IsZero(v interface{}) bool {
	rv := reflect.ValueOf(v)
	if !rv.IsValid() {
		return true
	}
	return rv.IsZero()
}
// Coalesce returns the first argument that differs from T's zero value, or
// the zero value if all arguments are zero (or none are given).
func Coalesce[T comparable](values ...T) T {
	var zero T
	for i := range values {
		if values[i] != zero {
			return values[i]
		}
	}
	return zero
}
// MapKeys returns the keys of m in unspecified (map iteration) order. The
// result is never nil.
func MapKeys[K comparable, V any](m map[K]V) []K {
	keys := make([]K, len(m))
	i := 0
	for k := range m {
		keys[i] = k
		i++
	}
	return keys
}
// MapValues returns the values of m in unspecified (map iteration) order.
// The result is never nil.
func MapValues[K comparable, V any](m map[K]V) []V {
	vals := make([]V, len(m))
	i := 0
	for _, v := range m {
		vals[i] = v
		i++
	}
	return vals
}
// Min returns a if a < b, otherwise b.
func Min[T ~int | ~int64 | ~float64](a, b T) T {
	result := b
	if a < b {
		result = a
	}
	return result
}
// Max returns a if a > b, otherwise b.
func Max[T ~int | ~int64 | ~float64](a, b T) T {
	result := b
	if a > b {
		result = a
	}
	return result
}
// Clamp limits value to the range [min, max]. The lower bound is checked
// first, so if min > max a value below min yields min.
func Clamp[T ~int | ~int64 | ~float64](value, min, max T) T {
	switch {
	case value < min:
		return min
	case value > max:
		return max
	default:
		return value
	}
}
// ContextWithTimeout returns a context derived from context.Background that
// is cancelled after timeout. It is not tied to any caller's context. The
// returned CancelFunc must be called to release resources.
func ContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	return ctx, cancel
}
// MaskString hides the middle of s behind '*', leaving the first and last
// showChars characters visible. If s has at most 2*showChars characters, it
// is masked entirely.
//
// Characters are counted as runes, so multi-byte UTF-8 text is never split
// into invalid bytes. A negative showChars is treated as 0, which masks
// everything, instead of panicking.
func MaskString(s string, showChars int) string {
	if showChars < 0 {
		showChars = 0
	}
	runes := []rune(s)
	n := len(runes)
	if n <= showChars*2 {
		return strings.Repeat("*", n)
	}
	return string(runes[:showChars]) + strings.Repeat("*", n-showChars*2) + string(runes[n-showChars:])
}
// Retry calls f until it returns nil or it has been called attempts times,
// and returns the last error. Between attempts it sleeps, starting at sleep
// and doubling each time (exponential backoff). There is no sleep after the
// final attempt.
//
// attempts < 1 is treated as 1, so f always runs at least once. This avoids
// reporting success (nil) for an operation that was never executed.
func Retry(attempts int, sleep time.Duration, f func() error) error {
	if attempts < 1 {
		attempts = 1
	}
	var err error
	for i := 0; i < attempts; i++ {
		if err = f(); err == nil {
			return nil
		}
		if i < attempts-1 {
			time.Sleep(sleep)
			sleep *= 2
		}
	}
	return err
}
package entities
import (
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
)
// Account represents a user account whose authentication key is an MPC group
// public key produced by a threshold keygen session.
type Account struct {
ID value_objects.AccountID
Username string
Email string
// Phone is optional; nil until SetPhone is called.
Phone *string
PublicKey []byte // MPC group public key
// KeygenSessionID identifies the keygen session that produced PublicKey.
KeygenSessionID uuid.UUID
// ThresholdN / ThresholdT: Validate requires 0 < ThresholdT <= ThresholdN.
// Presumably N = total parties and T = signing threshold — confirm against
// the MPC coordinator's definition.
ThresholdN int
ThresholdT int
Status value_objects.AccountStatus
CreatedAt time.Time
UpdatedAt time.Time
// LastLoginAt is nil until UpdateLastLogin is called.
LastLoginAt *time.Time
}
// NewAccount builds an Active account with a fresh AccountID and
// CreatedAt == UpdatedAt == now (UTC). It does not call Validate; callers
// should do so before persisting.
func NewAccount(
	username string,
	email string,
	publicKey []byte,
	keygenSessionID uuid.UUID,
	thresholdN int,
	thresholdT int,
) *Account {
	now := time.Now().UTC()
	account := &Account{
		ID:        value_objects.NewAccountID(),
		Status:    value_objects.AccountStatusActive,
		CreatedAt: now,
		UpdatedAt: now,
	}
	account.Username, account.Email = username, email
	account.PublicKey, account.KeygenSessionID = publicKey, keygenSessionID
	account.ThresholdN, account.ThresholdT = thresholdN, thresholdT
	return account
}
// SetPhone records phone as the account's phone number and bumps UpdatedAt.
func (a *Account) SetPhone(phone string) {
	a.UpdatedAt = time.Now().UTC()
	a.Phone = &phone
}
// UpdateLastLogin stamps LastLoginAt and UpdatedAt with the same UTC instant.
func (a *Account) UpdateLastLogin() {
	loginAt := time.Now().UTC()
	a.UpdatedAt = loginAt
	a.LastLoginAt = &loginAt
}
// Suspend moves the account to Suspended. It is refused with
// ErrAccountInRecovery while recovery is in progress; every other state may
// be suspended.
func (a *Account) Suspend() error {
	if a.Status != value_objects.AccountStatusRecovering {
		a.Status = value_objects.AccountStatusSuspended
		a.UpdatedAt = time.Now().UTC()
		return nil
	}
	return ErrAccountInRecovery
}
// Lock moves the account to Locked. It is refused with ErrAccountInRecovery
// while recovery is in progress.
func (a *Account) Lock() error {
	if a.Status != value_objects.AccountStatusRecovering {
		a.Status = value_objects.AccountStatusLocked
		a.UpdatedAt = time.Now().UTC()
		return nil
	}
	return ErrAccountInRecovery
}
// Activate sets the account to Active and bumps UpdatedAt.
//
// NOTE(review): no state check is performed, so this also forces an account
// out of Recovering. Confirm that callers intend that.
func (a *Account) Activate() {
	a.UpdatedAt = time.Now().UTC()
	a.Status = value_objects.AccountStatusActive
}
// StartRecovery moves the account to Recovering. This is allowed only when
// Status.CanInitiateRecovery() is true; otherwise it returns
// ErrCannotInitiateRecovery.
func (a *Account) StartRecovery() error {
	if a.Status.CanInitiateRecovery() {
		a.Status = value_objects.AccountStatusRecovering
		a.UpdatedAt = time.Now().UTC()
		return nil
	}
	return ErrCannotInitiateRecovery
}
// CompleteRecovery installs the public key from the recovery keygen session,
// records that session, and returns the account to Active.
//
// NOTE(review): it does not check that the account is currently Recovering,
// so calling it on a Suspended or Locked account reactivates it. Adding a
// guard would require changing the signature to return an error.
func (a *Account) CompleteRecovery(newPublicKey []byte, newKeygenSessionID uuid.UUID) {
	a.PublicKey, a.KeygenSessionID = newPublicKey, newKeygenSessionID
	a.Status = value_objects.AccountStatusActive
	a.UpdatedAt = time.Now().UTC()
}
// CanLogin delegates to the status value object's login policy.
func (a *Account) CanLogin() bool {
	status := a.Status
	return status.CanLogin()
}
// IsActive reports whether the account's status is exactly Active.
func (a *Account) IsActive() bool {
	return value_objects.AccountStatusActive == a.Status
}
// Validate checks required fields in order and returns the first failure:
// username, email, public key, then the threshold (0 < ThresholdT <= ThresholdN).
func (a *Account) Validate() error {
	switch {
	case a.Username == "":
		return ErrInvalidUsername
	case a.Email == "":
		return ErrInvalidEmail
	case len(a.PublicKey) == 0:
		return ErrInvalidPublicKey
	case a.ThresholdT <= 0 || a.ThresholdT > a.ThresholdN:
		return ErrInvalidThreshold
	default:
		return nil
	}
}
// Account domain errors. Each is an *AccountError pairing a machine-readable
// Code with a human-readable Message.

// Field validation failures returned by (*Account).Validate.
var (
	ErrInvalidUsername  = &AccountError{Code: "INVALID_USERNAME", Message: "username is required"}
	ErrInvalidEmail     = &AccountError{Code: "INVALID_EMAIL", Message: "email is required"}
	ErrInvalidPublicKey = &AccountError{Code: "INVALID_PUBLIC_KEY", Message: "public key is required"}
	ErrInvalidThreshold = &AccountError{Code: "INVALID_THRESHOLD", Message: "invalid threshold configuration"}
)

// Lifecycle / state-machine violations.
var (
	ErrAccountInRecovery      = &AccountError{Code: "ACCOUNT_IN_RECOVERY", Message: "account is in recovery mode"}
	ErrCannotInitiateRecovery = &AccountError{Code: "CANNOT_INITIATE_RECOVERY", Message: "cannot initiate recovery in current state"}
	ErrAccountNotActive       = &AccountError{Code: "ACCOUNT_NOT_ACTIVE", Message: "account is not active"}
)

// Lookup and uniqueness failures (presumably raised by repositories or
// services outside this file).
var (
	ErrAccountNotFound   = &AccountError{Code: "ACCOUNT_NOT_FOUND", Message: "account not found"}
	ErrDuplicateUsername = &AccountError{Code: "DUPLICATE_USERNAME", Message: "username already exists"}
	ErrDuplicateEmail    = &AccountError{Code: "DUPLICATE_EMAIL", Message: "email already exists"}
)
// AccountError is a domain error with a machine-readable Code and a
// human-readable Message. Only Message is exposed through Error().
type AccountError struct {
	Code    string
	Message string
}

// Error implements the error interface by returning the Message.
func (e *AccountError) Error() string {
	msg := e.Message
	return msg
}
package entities
import (
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
)
// AccountShare represents a mapping of key share to account
// Note: This records share location, not share content
type AccountShare struct {
ID uuid.UUID
AccountID value_objects.AccountID
ShareType value_objects.ShareType
// PartyID is the MPC party that holds this share (must be non-empty).
PartyID string
// PartyIndex is the party's index in the MPC group (must be >= 0).
PartyIndex int
// DeviceType/DeviceID are optional; set via SetDeviceInfo.
DeviceType *string
DeviceID *string
CreatedAt time.Time
// LastUsedAt is nil until UpdateLastUsed is called.
LastUsedAt *time.Time
// IsActive is false once the share has been deactivated (e.g. lost device).
IsActive bool
}
// NewAccountShare records a new, active share location with a fresh ID and
// CreatedAt set to now (UTC). Device info is left unset.
func NewAccountShare(
	accountID value_objects.AccountID,
	shareType value_objects.ShareType,
	partyID string,
	partyIndex int,
) *AccountShare {
	share := &AccountShare{
		ID:        uuid.New(),
		CreatedAt: time.Now().UTC(),
		IsActive:  true,
	}
	share.AccountID, share.ShareType = accountID, shareType
	share.PartyID, share.PartyIndex = partyID, partyIndex
	return share
}
// SetDeviceInfo records the device type and ID that hold this share.
func (s *AccountShare) SetDeviceInfo(deviceType, deviceID string) {
	s.DeviceType, s.DeviceID = &deviceType, &deviceID
}
// UpdateLastUsed stamps LastUsedAt with the current UTC time.
func (s *AccountShare) UpdateLastUsed() {
	usedAt := time.Now().UTC()
	s.LastUsedAt = &usedAt
}
// Deactivate marks the share as no longer usable, for example after its
// device is lost.
func (s *AccountShare) Deactivate() {
	s.IsActive = false
}
// Activate marks the share as usable again.
func (s *AccountShare) Activate() {
	s.IsActive = true
}
// IsUserDeviceShare reports whether the share is held on a user's device.
func (s *AccountShare) IsUserDeviceShare() bool {
	return value_objects.ShareTypeUserDevice == s.ShareType
}
// IsServerShare reports whether the share is held by the server.
func (s *AccountShare) IsServerShare() bool {
	return value_objects.ShareTypeServer == s.ShareType
}
// IsRecoveryShare reports whether the share is the recovery share.
func (s *AccountShare) IsRecoveryShare() bool {
	return value_objects.ShareTypeRecovery == s.ShareType
}
// Validate checks the share's fields in order and returns the first failure:
// account ID set, known share type, non-empty party ID, non-negative index.
func (s *AccountShare) Validate() error {
	switch {
	case s.AccountID.IsZero():
		return ErrShareInvalidAccountID
	case !s.ShareType.IsValid():
		return ErrShareInvalidType
	case s.PartyID == "":
		return ErrShareInvalidPartyID
	case s.PartyIndex < 0:
		return ErrShareInvalidPartyIndex
	default:
		return nil
	}
}
// ErrShareInvalidAccountID: the share has no (zero) account ID.
var ErrShareInvalidAccountID = &AccountError{Code: "SHARE_INVALID_ACCOUNT_ID", Message: "invalid account ID"}

// ErrShareInvalidType: the share type is not one of the known ShareType values.
var ErrShareInvalidType = &AccountError{Code: "SHARE_INVALID_TYPE", Message: "invalid share type"}

// ErrShareInvalidPartyID: the party ID is empty.
var ErrShareInvalidPartyID = &AccountError{Code: "SHARE_INVALID_PARTY_ID", Message: "invalid party ID"}

// ErrShareInvalidPartyIndex: the party index is negative.
var ErrShareInvalidPartyIndex = &AccountError{Code: "SHARE_INVALID_PARTY_INDEX", Message: "invalid party index"}

// ErrShareNotFound: no matching share exists.
var ErrShareNotFound = &AccountError{Code: "SHARE_NOT_FOUND", Message: "share not found"}
package entities
import (
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
)
// RecoverySession represents an account recovery session.
// Lifecycle: Requested -> InProgress (StartKeygen) -> Completed (Complete);
// Fail may be called from any state except Completed.
type RecoverySession struct {
ID uuid.UUID
AccountID value_objects.AccountID
RecoveryType value_objects.RecoveryType
// OldShareType is the share type being replaced; nil until SetOldShareType.
OldShareType *value_objects.ShareType
// NewKeygenSessionID is set by StartKeygen.
NewKeygenSessionID *uuid.UUID
Status value_objects.RecoveryStatus
RequestedAt time.Time
// CompletedAt is set only by Complete (not by Fail).
CompletedAt *time.Time
}
// NewRecoverySession opens a recovery session in the Requested state with a
// fresh ID and RequestedAt set to now (UTC).
func NewRecoverySession(
	accountID value_objects.AccountID,
	recoveryType value_objects.RecoveryType,
) *RecoverySession {
	session := &RecoverySession{
		ID:          uuid.New(),
		Status:      value_objects.RecoveryStatusRequested,
		RequestedAt: time.Now().UTC(),
	}
	session.AccountID = accountID
	session.RecoveryType = recoveryType
	return session
}
// SetOldShareType records which share type this recovery replaces.
func (r *RecoverySession) SetOldShareType(shareType value_objects.ShareType) {
	st := shareType
	r.OldShareType = &st
}
// StartKeygen moves a Requested session to InProgress and records the keygen
// session ID. In any other state it returns ErrRecoveryInvalidState.
func (r *RecoverySession) StartKeygen(keygenSessionID uuid.UUID) error {
	if r.Status == value_objects.RecoveryStatusRequested {
		r.NewKeygenSessionID = &keygenSessionID
		r.Status = value_objects.RecoveryStatusInProgress
		return nil
	}
	return ErrRecoveryInvalidState
}
// Complete moves an InProgress session to Completed and stamps CompletedAt
// (UTC). In any other state it returns ErrRecoveryInvalidState.
func (r *RecoverySession) Complete() error {
	if r.Status == value_objects.RecoveryStatusInProgress {
		finishedAt := time.Now().UTC()
		r.CompletedAt = &finishedAt
		r.Status = value_objects.RecoveryStatusCompleted
		return nil
	}
	return ErrRecoveryInvalidState
}
// Fail marks the session Failed from any state except Completed, which
// returns ErrRecoveryAlreadyCompleted. Failing an already-failed session is a
// no-op success. CompletedAt is not set.
func (r *RecoverySession) Fail() error {
	switch r.Status {
	case value_objects.RecoveryStatusCompleted:
		return ErrRecoveryAlreadyCompleted
	default:
		r.Status = value_objects.RecoveryStatusFailed
		return nil
	}
}
// IsCompleted reports whether the session reached Completed.
func (r *RecoverySession) IsCompleted() bool {
	return value_objects.RecoveryStatusCompleted == r.Status
}
// IsFailed reports whether the session was marked Failed.
func (r *RecoverySession) IsFailed() bool {
	return value_objects.RecoveryStatusFailed == r.Status
}
// IsInProgress reports whether keygen has started but not yet finished.
func (r *RecoverySession) IsInProgress() bool {
	return value_objects.RecoveryStatusInProgress == r.Status
}
// Validate requires a non-zero account ID and a known recovery type, checked
// in that order.
func (r *RecoverySession) Validate() error {
	switch {
	case r.AccountID.IsZero():
		return ErrRecoveryInvalidAccountID
	case !r.RecoveryType.IsValid():
		return ErrRecoveryInvalidType
	default:
		return nil
	}
}
// ErrRecoveryInvalidAccountID: the session has a zero account ID.
var ErrRecoveryInvalidAccountID = &AccountError{Code: "RECOVERY_INVALID_ACCOUNT_ID", Message: "invalid account ID for recovery"}

// ErrRecoveryInvalidType: the recovery type is unknown.
var ErrRecoveryInvalidType = &AccountError{Code: "RECOVERY_INVALID_TYPE", Message: "invalid recovery type"}

// ErrRecoveryInvalidState: a transition was attempted from the wrong state.
var ErrRecoveryInvalidState = &AccountError{Code: "RECOVERY_INVALID_STATE", Message: "invalid recovery state for this operation"}

// ErrRecoveryAlreadyCompleted: Fail was called on a completed session.
var ErrRecoveryAlreadyCompleted = &AccountError{Code: "RECOVERY_ALREADY_COMPLETED", Message: "recovery already completed"}

// ErrRecoveryNotFound: no matching recovery session exists.
var ErrRecoveryNotFound = &AccountError{Code: "RECOVERY_NOT_FOUND", Message: "recovery session not found"}
package value_objects
import (
"github.com/google/uuid"
)
// AccountID is an opaque, comparable account identifier backed by a UUID.
// The zero value wraps uuid.Nil and reports IsZero() == true.
type AccountID struct {
	value uuid.UUID
}

// NewAccountID returns an AccountID wrapping a freshly generated UUID.
func NewAccountID() AccountID {
	return AccountIDFromUUID(uuid.New())
}

// AccountIDFromString parses s as a UUID. On failure it returns the zero
// AccountID and the parse error.
func AccountIDFromString(s string) (AccountID, error) {
	parsed, err := uuid.Parse(s)
	if err != nil {
		return AccountID{}, err
	}
	return AccountIDFromUUID(parsed), nil
}

// AccountIDFromUUID wraps an existing UUID.
func AccountIDFromUUID(id uuid.UUID) AccountID {
	return AccountID{value: id}
}

// String returns the canonical UUID text form.
func (id AccountID) String() string {
	u := id.value
	return u.String()
}

// UUID returns the underlying UUID.
func (id AccountID) UUID() uuid.UUID {
	return id.value
}

// IsZero reports whether the ID wraps uuid.Nil.
func (id AccountID) IsZero() bool {
	return uuid.Nil == id.value
}

// Equals reports whether both IDs wrap the same UUID.
func (id AccountID) Equals(other AccountID) bool {
	return other.value == id.value
}
package value_objects
// AccountStatus is the lifecycle state of an account.
type AccountStatus string

// Known account states.
const (
	AccountStatusActive     AccountStatus = "active"
	AccountStatusSuspended  AccountStatus = "suspended"
	AccountStatusLocked     AccountStatus = "locked"
	AccountStatusRecovering AccountStatus = "recovering"
)

// String returns the raw status value.
func (s AccountStatus) String() string {
	return string(s)
}

// IsValid reports whether s is one of the four known states.
func (s AccountStatus) IsValid() bool {
	return s == AccountStatusActive ||
		s == AccountStatusSuspended ||
		s == AccountStatusLocked ||
		s == AccountStatusRecovering
}

// CanLogin reports whether login is allowed, which is only when Active.
func (s AccountStatus) CanLogin() bool {
	return AccountStatusActive == s
}

// CanInitiateRecovery reports whether recovery may start. It is allowed from
// Active or Locked.
func (s AccountStatus) CanInitiateRecovery() bool {
	switch s {
	case AccountStatusActive, AccountStatusLocked:
		return true
	}
	return false
}
// ShareType identifies who holds a key share.
type ShareType string

// Known share holders.
const (
	ShareTypeUserDevice ShareType = "user_device"
	ShareTypeServer     ShareType = "server"
	ShareTypeRecovery   ShareType = "recovery"
)

// String returns the raw share type value.
func (st ShareType) String() string {
	return string(st)
}

// IsValid reports whether st is one of the known share types.
func (st ShareType) IsValid() bool {
	return st == ShareTypeUserDevice || st == ShareTypeServer || st == ShareTypeRecovery
}
// RecoveryType names why an account recovery was started.
type RecoveryType string

// Known recovery reasons.
const (
	RecoveryTypeDeviceLost    RecoveryType = "device_lost"
	RecoveryTypeShareRotation RecoveryType = "share_rotation"
)

// String returns the raw recovery type value.
func (rt RecoveryType) String() string {
	return string(rt)
}

// IsValid reports whether rt is one of the known recovery types.
func (rt RecoveryType) IsValid() bool {
	return rt == RecoveryTypeDeviceLost || rt == RecoveryTypeShareRotation
}
// RecoveryStatus is the lifecycle state of a recovery session.
type RecoveryStatus string

// Known recovery states.
const (
	RecoveryStatusRequested  RecoveryStatus = "requested"
	RecoveryStatusInProgress RecoveryStatus = "in_progress"
	RecoveryStatusCompleted  RecoveryStatus = "completed"
	RecoveryStatusFailed     RecoveryStatus = "failed"
)

// String returns the raw status value.
func (rs RecoveryStatus) String() string {
	return string(rs)
}

// IsValid reports whether rs is one of the four known states.
func (rs RecoveryStatus) IsValid() bool {
	return rs == RecoveryStatusRequested ||
		rs == RecoveryStatusInProgress ||
		rs == RecoveryStatusCompleted ||
		rs == RecoveryStatusFailed
}
package entities
// DeviceType identifies the kind of device a participant uses.
type DeviceType string

// Recognised DeviceType values.
const (
	DeviceTypeAndroid  DeviceType = "android"
	DeviceTypeIOS      DeviceType = "ios"
	DeviceTypePC       DeviceType = "pc"
	DeviceTypeServer   DeviceType = "server"
	DeviceTypeRecovery DeviceType = "recovery"
)

// DeviceInfo describes the device behind a session participant.
type DeviceInfo struct {
	DeviceType DeviceType `json:"device_type"`
	DeviceID   string     `json:"device_id"`
	Platform   string     `json:"platform"`
	AppVersion string     `json:"app_version"`
}

// NewDeviceInfo assembles a DeviceInfo from its parts. No validation is
// performed here; use Validate for that.
func NewDeviceInfo(deviceType DeviceType, deviceID, platform, appVersion string) DeviceInfo {
	var info DeviceInfo
	info.DeviceType = deviceType
	info.DeviceID = deviceID
	info.Platform = platform
	info.AppVersion = appVersion
	return info
}

// IsServer reports whether the device type is DeviceTypeServer.
func (d DeviceInfo) IsServer() bool {
	return DeviceTypeServer == d.DeviceType
}

// IsMobile reports whether the device type is Android or iOS.
func (d DeviceInfo) IsMobile() bool {
	switch d.DeviceType {
	case DeviceTypeAndroid, DeviceTypeIOS:
		return true
	default:
		return false
	}
}

// IsRecovery reports whether the device type is DeviceTypeRecovery.
func (d DeviceInfo) IsRecovery() bool {
	return DeviceTypeRecovery == d.DeviceType
}
// Validate returns ErrInvalidDeviceInfo when no DeviceType is set. The other
// fields are not checked, and DeviceType is not required to match one of the
// declared DeviceType constants.
func (d DeviceInfo) Validate() error {
	if len(d.DeviceType) == 0 {
		return ErrInvalidDeviceInfo
	}
	return nil
}
package entities
import (
"errors"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
)
// ErrSessionFull is returned by AddParticipant once the session already
// holds Threshold.N() participants.
var ErrSessionFull = errors.New("session is full")

// ErrSessionExpired signals that a session is past its expiry.
var ErrSessionExpired = errors.New("session expired")

// ErrSessionNotInProgress signals an operation that requires an in-progress session.
var ErrSessionNotInProgress = errors.New("session not in progress")

// ErrParticipantNotFound is returned when no participant matches a PartyID.
var ErrParticipantNotFound = errors.New("participant not found")

// ErrInvalidSessionType is returned for a SessionType that fails IsValid.
var ErrInvalidSessionType = errors.New("invalid session type")

// ErrInvalidStatusTransition is returned when the current SessionStatus does
// not permit the requested transition.
var ErrInvalidStatusTransition = errors.New("invalid status transition")
// SessionType distinguishes key-generation sessions from signing sessions.
type SessionType string

// Supported session types.
const (
	SessionTypeKeygen SessionType = "keygen"
	SessionTypeSign   SessionType = "sign"
)

// IsValid reports whether t is SessionTypeKeygen or SessionTypeSign.
func (t SessionType) IsValid() bool {
	switch t {
	case SessionTypeKeygen, SessionTypeSign:
		return true
	}
	return false
}
// MPCSession is the coordinator's record of one MPC session.
// The coordinator manages session metadata only; it does not take part in
// the MPC computation itself.
type MPCSession struct {
	ID          value_objects.SessionID
	SessionType SessionType
	Threshold   value_objects.Threshold
	// Participants is capped at Threshold.N() entries by AddParticipant.
	Participants []*Participant
	Status       value_objects.SessionStatus

	// MessageHash is the digest to sign (used for Sign sessions); PublicKey is
	// the group public key, recorded by Complete.
	MessageHash, PublicKey []byte

	CreatedBy                       string
	CreatedAt, UpdatedAt, ExpiresAt time.Time
	CompletedAt                     *time.Time // nil until Complete succeeds
}
// NewMPCSession creates a session in the created status. It gets a fresh ID,
// an empty participant list with capacity threshold.N(), and an expiry
// expiresIn after the current UTC time.
//
// It returns ErrInvalidSessionType for an unknown sessionType, and an error
// when a sign session is requested without a messageHash. messageHash is
// stored as given (not copied).
func NewMPCSession(
	sessionType SessionType,
	threshold value_objects.Threshold,
	createdBy string,
	expiresIn time.Duration,
	messageHash []byte, // Only for Sign sessions
) (*MPCSession, error) {
	switch {
	case !sessionType.IsValid():
		return nil, ErrInvalidSessionType
	case sessionType == SessionTypeSign && len(messageHash) == 0:
		return nil, errors.New("message hash required for sign session")
	}

	created := time.Now().UTC()
	session := &MPCSession{
		ID:           value_objects.NewSessionID(),
		SessionType:  sessionType,
		Threshold:    threshold,
		Participants: make([]*Participant, 0, threshold.N()),
		Status:       value_objects.SessionStatusCreated,
		MessageHash:  messageHash,
		CreatedBy:    createdBy,
		CreatedAt:    created,
		UpdatedAt:    created,
		ExpiresAt:    created.Add(expiresIn),
	}
	return session, nil
}
// ErrDuplicateParticipant is returned by AddParticipant when a participant
// with the same PartyID is already part of the session.
var ErrDuplicateParticipant = errors.New("participant already in session")

// AddParticipant appends p to the session's participant list and updates
// UpdatedAt.
//
// Errors, checked in this order:
//   - ErrSessionFull once Threshold.N() participants are present;
//   - an error if p is nil;
//   - ErrDuplicateParticipant if p's PartyID is already present. A party may
//     hold only one slot. Duplicates would be counted twice by CanStart and
//     JoinedCount, and GetParticipant would only ever return the first copy.
func (s *MPCSession) AddParticipant(p *Participant) error {
	if len(s.Participants) >= s.Threshold.N() {
		return ErrSessionFull
	}
	if p == nil {
		return errors.New("participant is nil")
	}
	for _, existing := range s.Participants {
		if existing.PartyID.Equals(p.PartyID) {
			return ErrDuplicateParticipant
		}
	}
	s.Participants = append(s.Participants, p)
	s.UpdatedAt = time.Now().UTC()
	return nil
}
// GetParticipant returns the first participant whose PartyID equals partyID,
// or ErrParticipantNotFound if there is none.
func (s *MPCSession) GetParticipant(partyID value_objects.PartyID) (*Participant, error) {
	for i := range s.Participants {
		if candidate := s.Participants[i]; candidate.PartyID.Equals(partyID) {
			return candidate, nil
		}
	}
	return nil, ErrParticipantNotFound
}
// UpdateParticipantStatus drives the participant identified by partyID to
// status. It uses the participant's Join, MarkReady, MarkCompleted or
// MarkFailed method and returns any error they report.
//
// It returns ErrParticipantNotFound for an unknown party. Any target status
// other than joined, ready, completed or failed (e.g. invited) yields an
// "invalid status" error.
func (s *MPCSession) UpdateParticipantStatus(partyID value_objects.PartyID, status value_objects.ParticipantStatus) error {
	participant, err := s.GetParticipant(partyID)
	if err != nil {
		return err
	}

	switch status {
	case value_objects.ParticipantStatusJoined:
		return participant.Join()
	case value_objects.ParticipantStatusReady:
		return participant.MarkReady()
	case value_objects.ParticipantStatusCompleted:
		return participant.MarkCompleted()
	case value_objects.ParticipantStatusFailed:
		participant.MarkFailed()
		return nil
	}
	return errors.New("invalid status")
}
// CanStart reports whether the session can start. That requires exactly
// Threshold.N() participants, each of which has joined (per
// Participant.IsJoined, which also counts ready and completed participants).
func (s *MPCSession) CanStart() bool {
	required := s.Threshold.N()
	return len(s.Participants) == required && s.JoinedCount() == required
}
// Start moves the session to in_progress, which the status table allows only
// from created. It returns ErrInvalidStatusTransition when that move is not
// permitted. It returns a plain error when CanStart is false (not every
// participant has joined).
func (s *MPCSession) Start() error {
	const target = value_objects.SessionStatusInProgress
	switch {
	case !s.Status.CanTransitionTo(target):
		return ErrInvalidStatusTransition
	case !s.CanStart():
		return errors.New("not all participants have joined")
	}
	s.Status = target
	s.UpdatedAt = time.Now().UTC()
	return nil
}

// Complete moves the session to completed (allowed only from in_progress).
// It records publicKey (stored as given, not copied) and stamps CompletedAt
// and UpdatedAt with the same UTC instant. It returns
// ErrInvalidStatusTransition if the move is not permitted.
func (s *MPCSession) Complete(publicKey []byte) error {
	if ok := s.Status.CanTransitionTo(value_objects.SessionStatusCompleted); !ok {
		return ErrInvalidStatusTransition
	}
	finishedAt := time.Now().UTC()
	s.Status = value_objects.SessionStatusCompleted
	s.PublicKey = publicKey
	s.CompletedAt = &finishedAt
	s.UpdatedAt = finishedAt
	return nil
}

// Fail moves the session to failed, which is permitted from created or
// in_progress. Otherwise it returns ErrInvalidStatusTransition.
func (s *MPCSession) Fail() error {
	const target = value_objects.SessionStatusFailed
	if !s.Status.CanTransitionTo(target) {
		return ErrInvalidStatusTransition
	}
	s.Status, s.UpdatedAt = target, time.Now().UTC()
	return nil
}

// Expire moves the session to expired, which is permitted from created or
// in_progress. Otherwise it returns ErrInvalidStatusTransition.
func (s *MPCSession) Expire() error {
	const target = value_objects.SessionStatusExpired
	if !s.Status.CanTransitionTo(target) {
		return ErrInvalidStatusTransition
	}
	s.Status, s.UpdatedAt = target, time.Now().UTC()
	return nil
}
// IsExpired reports whether the current UTC time is past ExpiresAt.
func (s *MPCSession) IsExpired() bool {
	now := time.Now().UTC()
	return s.ExpiresAt.Before(now)
}

// IsActive reports whether the session is live: its status is created or
// in_progress and ExpiresAt has not passed. It does not modify Status; call
// Expire to record expiry.
func (s *MPCSession) IsActive() bool {
	if !s.Status.IsActive() {
		return false
	}
	return !s.IsExpired()
}

// IsParticipant reports whether some participant has the given PartyID.
func (s *MPCSession) IsParticipant(partyID value_objects.PartyID) bool {
	_, err := s.GetParticipant(partyID)
	return err == nil
}
// AllCompleted reports whether every participant is in the completed status.
// A session with no participants reports true.
func (s *MPCSession) AllCompleted() bool {
	for i := range s.Participants {
		if !s.Participants[i].IsCompleted() {
			return false
		}
	}
	return true
}

// CompletedCount returns how many participants are in the completed status.
func (s *MPCSession) CompletedCount() int {
	var completed int
	for i := range s.Participants {
		if s.Participants[i].IsCompleted() {
			completed++
		}
	}
	return completed
}

// JoinedCount returns how many participants have joined. Per
// Participant.IsJoined this includes participants that are ready or
// completed.
func (s *MPCSession) JoinedCount() int {
	var joined int
	for i := range s.Participants {
		if s.Participants[i].IsJoined() {
			joined++
		}
	}
	return joined
}
// GetPartyIDs returns the string form of each participant's PartyID, in
// participant order. The result is non-nil even when there are no
// participants.
func (s *MPCSession) GetPartyIDs() []string {
	ids := make([]string, 0, len(s.Participants))
	for _, participant := range s.Participants {
		ids = append(ids, participant.PartyID.String())
	}
	return ids
}
// GetOtherParties returns every participant whose PartyID differs from
// excludePartyID, preserving participant order. The returned slice is new,
// but the *Participant values are shared with the session.
func (s *MPCSession) GetOtherParties(excludePartyID value_objects.PartyID) []*Participant {
	// Capacity is len(s.Participants), not len-1. For a session with no
	// participants len-1 is -1, and make panics on a negative capacity.
	others := make([]*Participant, 0, len(s.Participants))
	for _, p := range s.Participants {
		if !p.PartyID.Equals(excludePartyID) {
			others = append(others, p)
		}
	}
	return others
}
// ToDTO flattens the session into a SessionDTO for API responses.
// Participants is always a non-nil slice, one entry per participant in order.
// MessageHash, PublicKey, CreatedBy, UpdatedAt and CompletedAt are not
// included.
func (s *MPCSession) ToDTO() SessionDTO {
	dto := SessionDTO{
		ID:           s.ID.String(),
		SessionType:  string(s.SessionType),
		ThresholdN:   s.Threshold.N(),
		ThresholdT:   s.Threshold.T(),
		Participants: make([]ParticipantDTO, 0, len(s.Participants)),
		Status:       s.Status.String(),
		CreatedAt:    s.CreatedAt,
		ExpiresAt:    s.ExpiresAt,
	}
	for _, participant := range s.Participants {
		dto.Participants = append(dto.Participants, ParticipantDTO{
			PartyID:    participant.PartyID.String(),
			PartyIndex: participant.PartyIndex,
			Status:     participant.Status.String(),
			DeviceType: string(participant.DeviceInfo.DeviceType),
		})
	}
	return dto
}
// SessionDTO is the JSON shape of an MPCSession in API responses, produced by
// MPCSession.ToDTO. It does not carry MessageHash, PublicKey, CreatedBy,
// UpdatedAt or CompletedAt.
type SessionDTO struct {
ID string `json:"id"`
SessionType string `json:"session_type"`
ThresholdN int `json:"threshold_n"` // total number of parties (Threshold.N)
ThresholdT int `json:"threshold_t"` // minimum number of parties required (Threshold.T)
Participants []ParticipantDTO `json:"participants"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// ParticipantDTO is the JSON shape of one session participant, used in
// SessionDTO.Participants.
type ParticipantDTO struct {
PartyID string `json:"party_id"`
PartyIndex int `json:"party_index"`
Status string `json:"status"` // the participant's ParticipantStatus as a string
DeviceType string `json:"device_type"`
}
// ReconstructSession rebuilds an MPCSession from values loaded from the
// database.
//
// The status string, threshold pair and session type are validated, in that
// order. The first failure is returned: the value_objects error for status or
// threshold, or ErrInvalidSessionType for an unknown session type.
// Participants and byte slices are stored as given, without copying.
func ReconstructSession(
	id uuid.UUID,
	sessionType string,
	thresholdT, thresholdN int,
	status string,
	messageHash, publicKey []byte,
	createdBy string,
	createdAt, updatedAt, expiresAt time.Time,
	completedAt *time.Time,
	participants []*Participant,
) (*MPCSession, error) {
	sessionStatus, err := value_objects.NewSessionStatus(status)
	if err != nil {
		return nil, err
	}
	threshold, err := value_objects.NewThreshold(thresholdT, thresholdN)
	if err != nil {
		return nil, err
	}
	// Reject unknown session types just as NewMPCSession does, so a corrupted
	// or out-of-date row cannot produce a session whose SessionType fails
	// IsValid.
	kind := SessionType(sessionType)
	if !kind.IsValid() {
		return nil, ErrInvalidSessionType
	}
	return &MPCSession{
		ID:           value_objects.SessionIDFromUUID(id),
		SessionType:  kind,
		Threshold:    threshold,
		Participants: participants,
		Status:       sessionStatus,
		MessageHash:  messageHash,
		PublicKey:    publicKey,
		CreatedBy:    createdBy,
		CreatedAt:    createdAt,
		UpdatedAt:    updatedAt,
		ExpiresAt:    expiresAt,
		CompletedAt:  completedAt,
	}, nil
}
package entities
import (
"errors"
"time"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
)
// ErrInvalidDeviceInfo is returned by DeviceInfo.Validate when no DeviceType is set.
var ErrInvalidDeviceInfo = errors.New("invalid device info")

// ErrParticipantNotInvited signals that a participant is not in the invited status.
var ErrParticipantNotInvited = errors.New("participant not in invited status")

// ErrInvalidParticipant is returned by NewParticipant for a zero PartyID or a
// negative party index.
var ErrInvalidParticipant = errors.New("invalid participant")
// Participant represents one party in an MPC session, together with the
// lifecycle status the coordinator tracks for it.
type Participant struct {
PartyID value_objects.PartyID
PartyIndex int // NewParticipant rejects negative values
Status value_objects.ParticipantStatus
DeviceInfo DeviceInfo
PublicKey []byte // Party's identity public key (for authentication)
JoinedAt time.Time // set by NewParticipant, then overwritten by Join
CompletedAt *time.Time // nil until MarkCompleted succeeds
}
// NewParticipant creates a participant in the invited status, with JoinedAt
// set to the current UTC time.
//
// It returns ErrInvalidParticipant for a zero partyID or a negative
// partyIndex. Otherwise it returns whatever error deviceInfo.Validate
// reports.
func NewParticipant(partyID value_objects.PartyID, partyIndex int, deviceInfo DeviceInfo) (*Participant, error) {
	if partyID.IsZero() || partyIndex < 0 {
		return nil, ErrInvalidParticipant
	}
	if err := deviceInfo.Validate(); err != nil {
		return nil, err
	}

	participant := &Participant{
		PartyID:    partyID,
		PartyIndex: partyIndex,
		Status:     value_objects.ParticipantStatusInvited,
		DeviceInfo: deviceInfo,
		JoinedAt:   time.Now().UTC(),
	}
	return participant, nil
}
// Join moves the participant to joined and stamps JoinedAt with the current
// UTC time. The status table allows this only from invited. Otherwise an
// error is returned and the participant is left unchanged.
func (p *Participant) Join() error {
	const next = value_objects.ParticipantStatusJoined
	if ok := p.Status.CanTransitionTo(next); !ok {
		return errors.New("cannot transition to joined status")
	}
	p.JoinedAt = time.Now().UTC()
	p.Status = next
	return nil
}

// MarkReady moves the participant to ready, which the status table allows
// only from joined. Otherwise it returns an error.
func (p *Participant) MarkReady() error {
	const next = value_objects.ParticipantStatusReady
	if ok := p.Status.CanTransitionTo(next); !ok {
		return errors.New("cannot transition to ready status")
	}
	p.Status = next
	return nil
}

// MarkCompleted moves the participant to completed (allowed only from ready)
// and stamps CompletedAt with the current UTC time. Otherwise it returns an
// error.
func (p *Participant) MarkCompleted() error {
	const next = value_objects.ParticipantStatusCompleted
	if ok := p.Status.CanTransitionTo(next); !ok {
		return errors.New("cannot transition to completed status")
	}
	finishedAt := time.Now().UTC()
	p.CompletedAt = &finishedAt
	p.Status = next
	return nil
}
// MarkFailed sets the participant's status to failed. Unlike the other
// transitions it does not consult ParticipantStatus.CanTransitionTo, so it
// succeeds from any status, including completed.
func (p *Participant) MarkFailed() {
	const failed = value_objects.ParticipantStatusFailed
	p.Status = failed
}

// IsJoined reports whether the participant has joined, including
// participants that have since become ready or completed.
func (p *Participant) IsJoined() bool {
	switch p.Status {
	case value_objects.ParticipantStatusJoined,
		value_objects.ParticipantStatusReady,
		value_objects.ParticipantStatusCompleted:
		return true
	default:
		return false
	}
}

// IsReady reports whether the participant is ready or already completed.
func (p *Participant) IsReady() bool {
	switch p.Status {
	case value_objects.ParticipantStatusReady, value_objects.ParticipantStatusCompleted:
		return true
	default:
		return false
	}
}

// IsCompleted reports whether the participant's status is completed.
func (p *Participant) IsCompleted() bool {
	return value_objects.ParticipantStatusCompleted == p.Status
}

// IsFailed reports whether the participant's status is failed.
func (p *Participant) IsFailed() bool {
	return value_objects.ParticipantStatusFailed == p.Status
}

// SetPublicKey stores publicKey as the participant's identity key. The slice
// is kept as given, not copied.
func (p *Participant) SetPublicKey(publicKey []byte) {
	p.PublicKey = publicKey
}
package entities
import (
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
)
// SessionMessage represents an MPC message relayed by the Coordinator
// (encrypted, Coordinator does not decrypt).
type SessionMessage struct {
ID uuid.UUID
SessionID value_objects.SessionID
FromParty value_objects.PartyID
ToParties []value_objects.PartyID // nil or empty means broadcast (see IsBroadcast)
RoundNumber int
MessageType string
Payload []byte // Encrypted MPC protocol message; stored and forwarded as-is
CreatedAt time.Time
DeliveredAt *time.Time // nil until MarkDelivered is called
}
// NewSessionMessage builds an undelivered message with a fresh random ID and
// CreatedAt set to the current UTC time. A nil or empty toParties makes it a
// broadcast. toParties and payload are stored as given, not copied.
func NewSessionMessage(
	sessionID value_objects.SessionID,
	fromParty value_objects.PartyID,
	toParties []value_objects.PartyID,
	roundNumber int,
	messageType string,
	payload []byte,
) *SessionMessage {
	msg := new(SessionMessage)
	msg.ID = uuid.New()
	msg.SessionID = sessionID
	msg.FromParty = fromParty
	msg.ToParties = toParties
	msg.RoundNumber = roundNumber
	msg.MessageType = messageType
	msg.Payload = payload
	msg.CreatedAt = time.Now().UTC()
	return msg
}
// IsBroadcast reports whether the message has no explicit recipients
// (ToParties is nil or empty).
func (m *SessionMessage) IsBroadcast() bool {
	return len(m.ToParties) == 0
}

// IsFor reports whether partyID should receive the message. A broadcast goes
// to every party except its sender. A directed message goes only to the
// parties listed in ToParties.
func (m *SessionMessage) IsFor(partyID value_objects.PartyID) bool {
	if !m.IsBroadcast() {
		for _, recipient := range m.ToParties {
			if recipient.Equals(partyID) {
				return true
			}
		}
		return false
	}
	return !m.FromParty.Equals(partyID)
}

// MarkDelivered records the current UTC time as the delivery time,
// overwriting any earlier value.
func (m *SessionMessage) MarkDelivered() {
	deliveredAt := time.Now().UTC()
	m.DeliveredAt = &deliveredAt
}

// IsDelivered reports whether MarkDelivered has been called.
func (m *SessionMessage) IsDelivered() bool {
	delivered := m.DeliveredAt != nil
	return delivered
}
// GetToPartyStrings returns the recipients' PartyIDs as strings. For a
// broadcast it returns nil.
func (m *SessionMessage) GetToPartyStrings() []string {
	if m.IsBroadcast() {
		return nil
	}
	names := make([]string, 0, len(m.ToParties))
	for _, party := range m.ToParties {
		names = append(names, party.String())
	}
	return names
}

// ToDTO flattens the message into a MessageDTO. DeliveredAt is not included.
func (m *SessionMessage) ToDTO() MessageDTO {
	var dto MessageDTO
	dto.ID = m.ID.String()
	dto.SessionID = m.SessionID.String()
	dto.FromParty = m.FromParty.String()
	dto.ToParties = m.GetToPartyStrings()
	dto.IsBroadcast = m.IsBroadcast()
	dto.RoundNumber = m.RoundNumber
	dto.MessageType = m.MessageType
	dto.Payload = m.Payload
	dto.CreatedAt = m.CreatedAt
	return dto
}
// MessageDTO is the JSON shape of a SessionMessage, produced by
// SessionMessage.ToDTO.
type MessageDTO struct {
ID string `json:"id"`
SessionID string `json:"session_id"`
FromParty string `json:"from_party"`
ToParties []string `json:"to_parties,omitempty"` // omitted for broadcasts (ToDTO leaves it nil)
IsBroadcast bool `json:"is_broadcast"`
RoundNumber int `json:"round_number"`
MessageType string `json:"message_type"`
Payload []byte `json:"payload"` // encoding/json encodes []byte as base64
CreatedAt time.Time `json:"created_at"`
}
package value_objects
import (
"errors"
"regexp"
)
var (
	// ErrInvalidPartyID is returned by NewPartyID for any rejected value.
	ErrInvalidPartyID = errors.New("invalid party ID")
	// partyIDRegex admits only ASCII letters, digits, '_' and '-'.
	partyIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)

// maxPartyIDLength is the longest accepted party ID, in bytes. Since
// partyIDRegex admits ASCII only, bytes and characters coincide.
const maxPartyIDLength = 255

// PartyID represents a unique party identifier.
type PartyID struct {
	value string
}

// NewPartyID validates value and wraps it in a PartyID. It returns
// ErrInvalidPartyID when value is empty, longer than 255 bytes, or contains
// characters outside [a-zA-Z0-9_-].
func NewPartyID(value string) (PartyID, error) {
	// Check the cheap length bounds first, so oversized input (possibly
	// untrusted) is rejected without being scanned by the regex.
	if value == "" || len(value) > maxPartyIDLength {
		return PartyID{}, ErrInvalidPartyID
	}
	if !partyIDRegex.MatchString(value) {
		return PartyID{}, ErrInvalidPartyID
	}
	return PartyID{value: value}, nil
}

// MustNewPartyID is like NewPartyID but panics on error. Intended for
// constants and tests.
func MustNewPartyID(value string) PartyID {
	id, err := NewPartyID(value)
	if err != nil {
		panic(err)
	}
	return id
}

// String returns the raw party identifier.
func (id PartyID) String() string {
	return id.value
}

// IsZero reports whether id is the zero PartyID (empty value).
func (id PartyID) IsZero() bool {
	return id.value == ""
}

// Equals reports whether id and other hold the same identifier.
func (id PartyID) Equals(other PartyID) bool {
	return id.value == other.value
}
package value_objects
import (
"github.com/google/uuid"
)
// SessionID identifies an MPC session by wrapping a UUID.
type SessionID struct {
	value uuid.UUID
}

// NewSessionID returns a SessionID backed by a newly generated UUID.
func NewSessionID() SessionID {
	return SessionIDFromUUID(uuid.New())
}

// SessionIDFromString parses s as a UUID. Any error from uuid.Parse is
// returned unchanged, together with the zero SessionID.
func SessionIDFromString(s string) (SessionID, error) {
	parsed, err := uuid.Parse(s)
	if err != nil {
		return SessionID{}, err
	}
	return SessionIDFromUUID(parsed), nil
}

// SessionIDFromUUID wraps an existing UUID without validation.
func SessionIDFromUUID(id uuid.UUID) SessionID {
	return SessionID{value: id}
}

// String returns the canonical string form of the underlying UUID.
func (id SessionID) String() string {
	return id.value.String()
}

// UUID returns the underlying UUID value.
func (id SessionID) UUID() uuid.UUID {
	return id.value
}

// IsZero reports whether id wraps the nil UUID.
func (id SessionID) IsZero() bool {
	return id.Equals(SessionID{})
}

// Equals reports whether both IDs wrap the same UUID.
func (id SessionID) Equals(other SessionID) bool {
	return id.UUID() == other.UUID()
}
package value_objects
import (
"errors"
)
// ErrInvalidSessionStatus is returned by NewSessionStatus for unknown values.
var ErrInvalidSessionStatus = errors.New("invalid session status")

// SessionStatus represents the status of an MPC session.
type SessionStatus string

// Recognised SessionStatus values.
const (
	SessionStatusCreated    SessionStatus = "created"
	SessionStatusInProgress SessionStatus = "in_progress"
	SessionStatusCompleted  SessionStatus = "completed"
	SessionStatusFailed     SessionStatus = "failed"
	SessionStatusExpired    SessionStatus = "expired"
)

// ValidSessionStatuses contains all valid session statuses.
var ValidSessionStatuses = []SessionStatus{
	SessionStatusCreated,
	SessionStatusInProgress,
	SessionStatusCompleted,
	SessionStatusFailed,
	SessionStatusExpired,
}

// NewSessionStatus converts s to a SessionStatus. It returns
// ErrInvalidSessionStatus if s is not one of ValidSessionStatuses.
func NewSessionStatus(s string) (SessionStatus, error) {
	status := SessionStatus(s)
	if !status.IsValid() {
		return "", ErrInvalidSessionStatus
	}
	return status, nil
}

// String returns the raw status value.
func (s SessionStatus) String() string {
	return string(s)
}

// IsValid reports whether s appears in ValidSessionStatuses.
func (s SessionStatus) IsValid() bool {
	for _, valid := range ValidSessionStatuses {
		if s == valid {
			return true
		}
	}
	return false
}

// CanTransitionTo reports whether a session may move from s to target.
//
// Allowed moves:
//
//	created     -> in_progress, failed, expired
//	in_progress -> completed, failed, expired
//
// Terminal statuses (completed, failed, expired) and unknown values allow no
// transitions. The table is a switch rather than a map literal, so a call
// costs a few comparisons instead of rebuilding and hashing a map each time.
func (s SessionStatus) CanTransitionTo(target SessionStatus) bool {
	switch s {
	case SessionStatusCreated:
		switch target {
		case SessionStatusInProgress, SessionStatusFailed, SessionStatusExpired:
			return true
		}
	case SessionStatusInProgress:
		switch target {
		case SessionStatusCompleted, SessionStatusFailed, SessionStatusExpired:
			return true
		}
	}
	return false
}

// IsTerminal reports whether s is completed, failed or expired, none of which
// allow further transitions.
func (s SessionStatus) IsTerminal() bool {
	return s == SessionStatusCompleted || s == SessionStatusFailed || s == SessionStatusExpired
}

// IsActive reports whether s is created or in_progress.
func (s SessionStatus) IsActive() bool {
	return s == SessionStatusCreated || s == SessionStatusInProgress
}
// ParticipantStatus represents the status of a participant.
type ParticipantStatus string

// Recognised ParticipantStatus values.
const (
	ParticipantStatusInvited   ParticipantStatus = "invited"
	ParticipantStatusJoined    ParticipantStatus = "joined"
	ParticipantStatusReady     ParticipantStatus = "ready"
	ParticipantStatusCompleted ParticipantStatus = "completed"
	ParticipantStatusFailed    ParticipantStatus = "failed"
)

// ValidParticipantStatuses contains all valid participant statuses.
var ValidParticipantStatuses = []ParticipantStatus{
	ParticipantStatusInvited,
	ParticipantStatusJoined,
	ParticipantStatusReady,
	ParticipantStatusCompleted,
	ParticipantStatusFailed,
}

// String returns the raw status value.
func (s ParticipantStatus) String() string {
	return string(s)
}

// IsValid reports whether s appears in ValidParticipantStatuses.
func (s ParticipantStatus) IsValid() bool {
	for _, valid := range ValidParticipantStatuses {
		if s == valid {
			return true
		}
	}
	return false
}

// CanTransitionTo reports whether a participant may move from s to target.
//
// The lifecycle is linear, and every non-terminal status may also fail:
//
//	invited -> joined | failed
//	joined  -> ready  | failed
//	ready   -> completed | failed
//
// completed, failed and unknown values allow no transitions. A switch is
// used instead of a map literal so each call does not rebuild the table.
func (s ParticipantStatus) CanTransitionTo(target ParticipantStatus) bool {
	var next ParticipantStatus
	switch s {
	case ParticipantStatusInvited:
		next = ParticipantStatusJoined
	case ParticipantStatusJoined:
		next = ParticipantStatusReady
	case ParticipantStatusReady:
		next = ParticipantStatusCompleted
	default:
		return false
	}
	return target == next || target == ParticipantStatusFailed
}
package value_objects
import (
"errors"
"fmt"
)
// Threshold validation errors. NewThreshold returns all of these except
// ErrInvalidThreshold.
var (
	ErrInvalidThreshold  = errors.New("invalid threshold")
	ErrThresholdTooLarge = errors.New("threshold t cannot exceed n")
	ErrThresholdTooSmall = errors.New("threshold t must be at least 1")
	ErrNTooSmall         = errors.New("n must be at least 2")
	ErrNTooLarge         = errors.New("n cannot exceed maximum allowed")
)

// Bounds enforced by NewThreshold.
const (
	MinN = 2  // fewest total parties
	MaxN = 10 // most total parties
	MinT = 1  // smallest required-party count
)

// Threshold is an immutable t-of-n configuration: any t of the n parties
// suffice.
type Threshold struct {
	t int // Minimum number of parties required
	n int // Total number of parties
}

// NewThreshold validates and builds a t-of-n Threshold. Checks run in this
// order and the first failure wins: n < MinN, n > MaxN, t < MinT, t > n.
func NewThreshold(t, n int) (Threshold, error) {
	switch {
	case n < MinN:
		return Threshold{}, ErrNTooSmall
	case n > MaxN:
		return Threshold{}, ErrNTooLarge
	case t < MinT:
		return Threshold{}, ErrThresholdTooSmall
	case t > n:
		return Threshold{}, ErrThresholdTooLarge
	}
	return Threshold{t: t, n: n}, nil
}

// MustNewThreshold is like NewThreshold but panics on invalid input.
func MustNewThreshold(t, n int) Threshold {
	th, err := NewThreshold(t, n)
	if err != nil {
		panic(err)
	}
	return th
}

// T returns the minimum number of parties required.
func (th Threshold) T() int { return th.t }

// N returns the total number of parties.
func (th Threshold) N() int { return th.n }

// IsZero reports whether th is the zero Threshold (0-of-0).
func (th Threshold) IsZero() bool {
	return th == Threshold{}
}

// Equals reports whether both thresholds have the same t and n.
func (th Threshold) Equals(other Threshold) bool {
	return th == other
}

// String formats the threshold as "t-of-n".
func (th Threshold) String() string {
	return fmt.Sprintf("%d-of-%d", th.t, th.n)
}

// CanSign reports whether availableParties meets the t requirement.
func (th Threshold) CanSign(availableParties int) bool {
	return th.t <= availableParties
}

// RequiresAllParties reports whether t equals n.
func (th Threshold) RequiresAllParties() bool {
	return th.n == th.t
}