diff --git a/backend/mpc-system/migrations/004_add_optimistic_locking.down.sql b/backend/mpc-system/migrations/004_add_optimistic_locking.down.sql new file mode 100644 index 00000000..9dc611c3 --- /dev/null +++ b/backend/mpc-system/migrations/004_add_optimistic_locking.down.sql @@ -0,0 +1,8 @@ +-- Remove optimistic locking support + +-- Drop the index first +DROP INDEX IF EXISTS idx_mpc_sessions_version; + +-- Remove version column from mpc_sessions table +ALTER TABLE mpc_sessions +DROP COLUMN IF EXISTS version; diff --git a/backend/mpc-system/migrations/004_add_optimistic_locking.up.sql b/backend/mpc-system/migrations/004_add_optimistic_locking.up.sql new file mode 100644 index 00000000..a3c0278c --- /dev/null +++ b/backend/mpc-system/migrations/004_add_optimistic_locking.up.sql @@ -0,0 +1,12 @@ +-- Add version field for optimistic locking +-- This enables concurrent update detection to prevent lost updates + +-- Add version column to mpc_sessions table (IF NOT EXISTS mirrors the IF EXISTS guards of the down migration so a partial run can be re-applied) +ALTER TABLE mpc_sessions +ADD COLUMN IF NOT EXISTS version BIGINT NOT NULL DEFAULT 1; + +-- Add comment explaining the version field +COMMENT ON COLUMN mpc_sessions.version IS 'Version number for optimistic locking - increments on each update to detect concurrent modifications'; + +-- Create composite index on (id, version), the columns filtered by the versioned UPDATE (NOTE(review): id is already unique, so this index may be redundant - confirm) +CREATE INDEX IF NOT EXISTS idx_mpc_sessions_version ON mpc_sessions(id, version); diff --git a/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go b/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go index f4c99671..64b45644 100644 --- a/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go +++ b/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go @@ -37,14 +37,15 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes _, err = tx.ExecContext(ctx, ` INSERT INTO mpc_sessions 
( id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (id) DO UPDATE SET status = EXCLUDED.status, message_hash = EXCLUDED.message_hash, public_key = EXCLUDED.public_key, updated_at = EXCLUDED.updated_at, - completed_at = EXCLUDED.completed_at + completed_at = EXCLUDED.completed_at, + version = EXCLUDED.version `, session.ID.UUID(), string(session.SessionType), @@ -58,6 +59,7 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes session.UpdatedAt, session.ExpiresAt, session.CompletedAt, + session.Version, ) if err != nil { return err @@ -114,7 +116,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en var session sessionRow err := r.db.QueryRowContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE id = $1 `, id).Scan( &session.ID, @@ -129,6 +131,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en &session.UpdatedAt, &session.ExpiresAt, &session.CompletedAt, + &session.Version, ) if err != nil { if err == sql.ErrNoRows { @@ -158,6 +161,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en session.ExpiresAt, session.CompletedAt, participants, + session.Version, ) } @@ -165,7 +169,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_objects.SessionStatus) ([]*entities.MPCSession, 
error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE status = $1 `, status.String()) if err != nil { @@ -180,7 +184,7 @@ func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_obj func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE expires_at < NOW() AND status IN ('created', 'in_progress') `) @@ -196,7 +200,7 @@ func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCS func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE status IN ('created', 'in_progress') ORDER BY created_at ASC @@ -213,7 +217,7 @@ func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSe func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID string) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at + message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, 
version FROM mpc_sessions WHERE created_by = $1 ORDER BY created_at DESC `, creatorID) @@ -229,7 +233,7 @@ func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID strin func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, partyID value_objects.PartyID) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT s.id, s.session_type, s.threshold_n, s.threshold_t, s.status, - s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at + s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at, s.version FROM mpc_sessions s JOIN participants p ON s.id = p.session_id WHERE p.party_id = $1 AND s.status IN ('created', 'in_progress') @@ -243,7 +247,7 @@ func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, party return r.scanSessions(ctx, rows) } -// Update updates an existing session +// Update updates an existing session using optimistic locking func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCSession) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { @@ -251,40 +255,45 @@ func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCS } defer tx.Rollback() - // Lock the session row first to prevent concurrent modifications - // This ensures serializable isolation for the entire session update - _, err = tx.ExecContext(ctx, ` - SELECT id FROM mpc_sessions WHERE id = $1 FOR UPDATE - `, session.ID.UUID()) - if err != nil { - return err - } - - // Update session - _, err = tx.ExecContext(ctx, ` + // Update session with optimistic locking (check version) + // Version is incremented automatically to prevent concurrent modifications + result, err := tx.ExecContext(ctx, ` UPDATE mpc_sessions SET - status = $1, public_key = $2, updated_at = $3, completed_at = $4 - WHERE id = $5 + status = $1, + public_key = $2, + updated_at = $3, + completed_at = $4, + version 
= version + 1 + WHERE id = $5 AND version = $6 `, session.Status.String(), session.PublicKey, session.UpdatedAt, session.CompletedAt, session.ID.UUID(), + session.Version, ) if err != nil { return err } - // Lock all participant rows for this session to prevent concurrent modifications - // This prevents lost updates when multiple parties report completion simultaneously - _, err = tx.ExecContext(ctx, ` - SELECT id FROM participants WHERE session_id = $1 FOR UPDATE - `, session.ID.UUID()) + // Check if the update affected any rows + rowsAffected, err := result.RowsAffected() if err != nil { return err } + // If no rows were affected, version mismatch occurred (concurrent modification) + if rowsAffected == 0 { + logger.Warn("optimistic lock conflict detected", + zap.String("session_id", session.ID.UUID().String()), + zap.Int64("expected_version", session.Version)) + return entities.ErrOptimisticLockConflict + } + + // Update session version in memory to reflect the database state + session.Version++ + // DEBUG: Log all participant statuses before update var participantStatuses []string for _, p := range session.Participants { @@ -479,6 +488,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows) &s.UpdatedAt, &s.ExpiresAt, &s.CompletedAt, + &s.Version, ) if err != nil { return nil, err @@ -504,6 +514,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows) s.ExpiresAt, s.CompletedAt, participants, + s.Version, ) if err != nil { return nil, err @@ -528,6 +539,7 @@ type sessionRow struct { UpdatedAt time.Time ExpiresAt time.Time CompletedAt *time.Time + Version int64 } type participantRow struct { diff --git a/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go b/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go index 22ef7f7d..356cfd2b 100644 --- 
a/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go +++ b/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go @@ -2,6 +2,7 @@ package use_cases import ( "context" + "errors" "fmt" "time" @@ -37,14 +38,43 @@ func NewReportCompletionUseCase( } } -// Execute executes the report completion use case +const ( + maxRetries = 3 // Maximum number of retries after the initial attempt when an optimistic lock conflict occurs +) + +// Execute executes the report completion use case, starting at attempt 0 and retrying up to maxRetries times on optimistic lock conflicts func (uc *ReportCompletionUseCase) Execute( ctx context.Context, inputData input.ReportCompletionInput, ) (*input.ReportCompletionOutput, error) { + return uc.executeWithRetry(ctx, inputData, 0) +} + +// executeWithRetry runs a single attempt; retry is the index of this attempt (0 = initial attempt, k = k-th retry) and the call fails with a wrapped ErrOptimisticLockConflict once retry exceeds maxRetries +func (uc *ReportCompletionUseCase) executeWithRetry( + ctx context.Context, + inputData input.ReportCompletionInput, + retry int, +) (*input.ReportCompletionOutput, error) { + if retry > maxRetries { + logger.Error("max retries exceeded for optimistic lock", + zap.String("session_id", inputData.SessionID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry_count", maxRetries)) + return nil, fmt.Errorf("max retries exceeded: %w", entities.ErrOptimisticLockConflict) + } + + if retry > 0 { + logger.Info("retrying report completion due to optimistic lock conflict", + zap.String("session_id", inputData.SessionID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry_attempt", retry)) + } + logger.Debug("ReportCompletion.Execute: START", zap.String("session_id", inputData.SessionID.String()), - zap.String("party_id", inputData.PartyID)) + zap.String("party_id", inputData.PartyID), + zap.Int("retry", retry)) // 1. 
Load session session, err := uc.sessionRepo.FindByUUID(ctx, inputData.SessionID) @@ -169,12 +199,22 @@ func (uc *ReportCompletionUseCase) Execute( zap.Strings("all_participant_statuses", beforeUpdateStatuses)) if err := uc.sessionRepo.Update(ctx, session); err != nil { + // Check if this is an optimistic lock conflict + if errors.Is(err, entities.ErrOptimisticLockConflict) { + logger.Warn("optimistic lock conflict, retrying", + zap.String("session_id", session.ID.String()), + zap.String("party_id", inputData.PartyID), + zap.Int("retry", retry)) + // Retry the entire operation with fresh data + return uc.executeWithRetry(ctx, inputData, retry+1) + } return nil, err } logger.Debug("ReportCompletion.Execute: AFTER sessionRepo.Update", zap.String("session_id", session.ID.String()), - zap.String("party_id", inputData.PartyID)) + zap.String("party_id", inputData.PartyID), + zap.Int("retry", retry)) // 7. Publish participant completed event event := output.ParticipantCompletedEvent{ diff --git a/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go b/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go index 5f74bc4e..a486375c 100644 --- a/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go +++ b/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go @@ -17,6 +17,7 @@ var ( ErrInvalidSessionType = errors.New("invalid session type") ErrInvalidStatusTransition = errors.New("invalid status transition") ErrParticipantTimedOut = errors.New("participant timed out") + ErrOptimisticLockConflict = errors.New("optimistic lock conflict: session was modified by another transaction") ) // SessionType represents the type of MPC session @@ -48,6 +49,7 @@ type MPCSession struct { UpdatedAt time.Time ExpiresAt time.Time CompletedAt *time.Time + Version int64 // Optimistic locking version number } // NewMPCSession creates a new MPC session @@ -78,6 +80,7 @@ func NewMPCSession( CreatedAt: 
now, UpdatedAt: now, ExpiresAt: now.Add(expiresIn), + Version: 1, }, nil } @@ -384,6 +387,7 @@ func ReconstructSession( createdAt, updatedAt, expiresAt time.Time, completedAt *time.Time, participants []*Participant, + version int64, ) (*MPCSession, error) { sessionStatus, err := value_objects.NewSessionStatus(status) if err != nil { @@ -409,5 +413,6 @@ func ReconstructSession( UpdatedAt: updatedAt, ExpiresAt: expiresAt, CompletedAt: completedAt, + Version: version, }, nil }