feat(mpc-system): implement optimistic locking for session updates

Implement version-based optimistic locking to prevent concurrent update conflicts
when multiple parties simultaneously report completion during keygen operations.

Changes:
- Add version column to mpc_sessions table (migration 004)
- Add Version field to MPCSession entity
- Define ErrOptimisticLockConflict error
- Update SessionPostgresRepo.Update() to check version and increment on success
- Add automatic retry logic (max 3 attempts) to ReportCompletionUseCase
- Update Save and all query methods (FindByStatus, FindExpired, etc.) to handle version field

This replaces pessimistic locking (FOR UPDATE) with optimistic locking using
the industry-standard pattern: WHERE version = $n and checking rowsAffected.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
hailin 2025-12-06 04:16:32 -08:00
parent 63e00a64f5
commit b72268c1ce
5 changed files with 108 additions and 31 deletions

View File

@ -0,0 +1,8 @@
-- Remove optimistic locking support
-- Drop the index first
DROP INDEX IF EXISTS idx_mpc_sessions_version;
-- Remove version column from mpc_sessions table
ALTER TABLE mpc_sessions
DROP COLUMN IF EXISTS version;

View File

@ -0,0 +1,12 @@
-- Add version field for optimistic locking
-- This enables concurrent update detection to prevent lost updates
-- Add version column to mpc_sessions table
ALTER TABLE mpc_sessions
ADD COLUMN version BIGINT NOT NULL DEFAULT 1;
-- Add comment explaining the version field
COMMENT ON COLUMN mpc_sessions.version IS 'Version number for optimistic locking - increments on each update to detect concurrent modifications';
-- Create index on version for better query performance (optional but recommended)
CREATE INDEX idx_mpc_sessions_version ON mpc_sessions(id, version);

View File

@ -37,14 +37,15 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes
_, err = tx.ExecContext(ctx, `
INSERT INTO mpc_sessions (
id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (id) DO UPDATE SET
status = EXCLUDED.status,
message_hash = EXCLUDED.message_hash,
public_key = EXCLUDED.public_key,
updated_at = EXCLUDED.updated_at,
completed_at = EXCLUDED.completed_at
completed_at = EXCLUDED.completed_at,
version = EXCLUDED.version
`,
session.ID.UUID(),
string(session.SessionType),
@ -58,6 +59,7 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes
session.UpdatedAt,
session.ExpiresAt,
session.CompletedAt,
session.Version,
)
if err != nil {
return err
@ -114,7 +116,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
var session sessionRow
err := r.db.QueryRowContext(ctx, `
SELECT id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
FROM mpc_sessions WHERE id = $1
`, id).Scan(
&session.ID,
@ -129,6 +131,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
&session.UpdatedAt,
&session.ExpiresAt,
&session.CompletedAt,
&session.Version,
)
if err != nil {
if err == sql.ErrNoRows {
@ -158,6 +161,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
session.ExpiresAt,
session.CompletedAt,
participants,
session.Version,
)
}
@ -165,7 +169,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_objects.SessionStatus) ([]*entities.MPCSession, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
FROM mpc_sessions WHERE status = $1
`, status.String())
if err != nil {
@ -180,7 +184,7 @@ func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_obj
func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCSession, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
FROM mpc_sessions
WHERE expires_at < NOW() AND status IN ('created', 'in_progress')
`)
@ -196,7 +200,7 @@ func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCS
func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSession, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
FROM mpc_sessions
WHERE status IN ('created', 'in_progress')
ORDER BY created_at ASC
@ -213,7 +217,7 @@ func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSe
func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID string) ([]*entities.MPCSession, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, session_type, threshold_n, threshold_t, status,
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
FROM mpc_sessions WHERE created_by = $1
ORDER BY created_at DESC
`, creatorID)
@ -229,7 +233,7 @@ func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID strin
func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, partyID value_objects.PartyID) ([]*entities.MPCSession, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT s.id, s.session_type, s.threshold_n, s.threshold_t, s.status,
s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at
s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at, s.version
FROM mpc_sessions s
JOIN participants p ON s.id = p.session_id
WHERE p.party_id = $1 AND s.status IN ('created', 'in_progress')
@ -243,7 +247,7 @@ func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, party
return r.scanSessions(ctx, rows)
}
// Update updates an existing session
// Update updates an existing session using optimistic locking
func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCSession) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
@ -251,40 +255,45 @@ func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCS
}
defer tx.Rollback()
// Lock the session row first to prevent concurrent modifications
// This ensures serializable isolation for the entire session update
_, err = tx.ExecContext(ctx, `
SELECT id FROM mpc_sessions WHERE id = $1 FOR UPDATE
`, session.ID.UUID())
if err != nil {
return err
}
// Update session
_, err = tx.ExecContext(ctx, `
// Update session with optimistic locking (check version)
// Version is incremented automatically to prevent concurrent modifications
result, err := tx.ExecContext(ctx, `
UPDATE mpc_sessions SET
status = $1, public_key = $2, updated_at = $3, completed_at = $4
WHERE id = $5
status = $1,
public_key = $2,
updated_at = $3,
completed_at = $4,
version = version + 1
WHERE id = $5 AND version = $6
`,
session.Status.String(),
session.PublicKey,
session.UpdatedAt,
session.CompletedAt,
session.ID.UUID(),
session.Version,
)
if err != nil {
return err
}
// Lock all participant rows for this session to prevent concurrent modifications
// This prevents lost updates when multiple parties report completion simultaneously
_, err = tx.ExecContext(ctx, `
SELECT id FROM participants WHERE session_id = $1 FOR UPDATE
`, session.ID.UUID())
// Check if the update affected any rows
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
// If no rows were affected, version mismatch occurred (concurrent modification)
if rowsAffected == 0 {
logger.Warn("optimistic lock conflict detected",
zap.String("session_id", session.ID.UUID().String()),
zap.Int64("expected_version", session.Version))
return entities.ErrOptimisticLockConflict
}
// Update session version in memory to reflect the database state
session.Version++
// DEBUG: Log all participant statuses before update
var participantStatuses []string
for _, p := range session.Participants {
@ -479,6 +488,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows)
&s.UpdatedAt,
&s.ExpiresAt,
&s.CompletedAt,
&s.Version,
)
if err != nil {
return nil, err
@ -504,6 +514,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows)
s.ExpiresAt,
s.CompletedAt,
participants,
s.Version,
)
if err != nil {
return nil, err
@ -528,6 +539,7 @@ type sessionRow struct {
UpdatedAt time.Time
ExpiresAt time.Time
CompletedAt *time.Time
Version int64
}
type participantRow struct {

View File

@ -2,6 +2,7 @@ package use_cases
import (
"context"
"errors"
"fmt"
"time"
@ -37,14 +38,43 @@ func NewReportCompletionUseCase(
}
}
// Execute executes the report completion use case
const (
maxRetries = 3 // Maximum number of retry attempts for optimistic lock conflicts
)
// Execute executes the report completion use case with retry logic for optimistic lock conflicts
func (uc *ReportCompletionUseCase) Execute(
ctx context.Context,
inputData input.ReportCompletionInput,
) (*input.ReportCompletionOutput, error) {
return uc.executeWithRetry(ctx, inputData, 0)
}
// executeWithRetry executes the report completion with retry logic
func (uc *ReportCompletionUseCase) executeWithRetry(
ctx context.Context,
inputData input.ReportCompletionInput,
retry int,
) (*input.ReportCompletionOutput, error) {
if retry >= maxRetries {
logger.Error("max retries exceeded for optimistic lock",
zap.String("session_id", inputData.SessionID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry_count", retry))
return nil, fmt.Errorf("max retries exceeded: %w", entities.ErrOptimisticLockConflict)
}
if retry > 0 {
logger.Info("retrying report completion due to optimistic lock conflict",
zap.String("session_id", inputData.SessionID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry_attempt", retry))
}
logger.Debug("ReportCompletion.Execute: START",
zap.String("session_id", inputData.SessionID.String()),
zap.String("party_id", inputData.PartyID))
zap.String("party_id", inputData.PartyID),
zap.Int("retry", retry))
// 1. Load session
session, err := uc.sessionRepo.FindByUUID(ctx, inputData.SessionID)
@ -169,12 +199,22 @@ func (uc *ReportCompletionUseCase) Execute(
zap.Strings("all_participant_statuses", beforeUpdateStatuses))
if err := uc.sessionRepo.Update(ctx, session); err != nil {
// Check if this is an optimistic lock conflict
if errors.Is(err, entities.ErrOptimisticLockConflict) {
logger.Warn("optimistic lock conflict, retrying",
zap.String("session_id", session.ID.String()),
zap.String("party_id", inputData.PartyID),
zap.Int("retry", retry))
// Retry the entire operation with fresh data
return uc.executeWithRetry(ctx, inputData, retry+1)
}
return nil, err
}
logger.Debug("ReportCompletion.Execute: AFTER sessionRepo.Update",
zap.String("session_id", session.ID.String()),
zap.String("party_id", inputData.PartyID))
zap.String("party_id", inputData.PartyID),
zap.Int("retry", retry))
// 7. Publish participant completed event
event := output.ParticipantCompletedEvent{

View File

@ -17,6 +17,7 @@ var (
ErrInvalidSessionType = errors.New("invalid session type")
ErrInvalidStatusTransition = errors.New("invalid status transition")
ErrParticipantTimedOut = errors.New("participant timed out")
ErrOptimisticLockConflict = errors.New("optimistic lock conflict: session was modified by another transaction")
)
// SessionType represents the type of MPC session
@ -48,6 +49,7 @@ type MPCSession struct {
UpdatedAt time.Time
ExpiresAt time.Time
CompletedAt *time.Time
Version int64 // Optimistic locking version number
}
// NewMPCSession creates a new MPC session
@ -78,6 +80,7 @@ func NewMPCSession(
CreatedAt: now,
UpdatedAt: now,
ExpiresAt: now.Add(expiresIn),
Version: 1,
}, nil
}
@ -384,6 +387,7 @@ func ReconstructSession(
createdAt, updatedAt, expiresAt time.Time,
completedAt *time.Time,
participants []*Participant,
version int64,
) (*MPCSession, error) {
sessionStatus, err := value_objects.NewSessionStatus(status)
if err != nil {
@ -409,5 +413,6 @@ func ReconstructSession(
UpdatedAt: updatedAt,
ExpiresAt: expiresAt,
CompletedAt: completedAt,
Version: version,
}, nil
}