feat(mpc-system): implement optimistic locking for session updates
Implement version-based optimistic locking to prevent concurrent update conflicts when multiple parties simultaneously report completion during keygen operations. Changes: - Add version column to mpc_sessions table (migration 004) - Add Version field to MPCSession entity - Define ErrOptimisticLockConflict error - Update SessionPostgresRepo.Update() to check version and increment on success - Add automatic retry logic (max 3 attempts) to ReportCompletionUseCase - Update Save and all query methods (FindByStatus, FindExpired, etc.) to handle version field This replaces pessimistic locking (FOR UPDATE) with optimistic locking using the industry-standard pattern: WHERE version = $n and checking rowsAffected. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
63e00a64f5
commit
b72268c1ce
|
|
@ -0,0 +1,8 @@
|
|||
-- Remove optimistic locking support
|
||||
|
||||
-- Drop the index first
|
||||
DROP INDEX IF EXISTS idx_mpc_sessions_version;
|
||||
|
||||
-- Remove version column from mpc_sessions table
|
||||
ALTER TABLE mpc_sessions
|
||||
DROP COLUMN IF EXISTS version;
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
-- Add version field for optimistic locking
|
||||
-- This enables concurrent update detection to prevent lost updates
|
||||
|
||||
-- Add version column to mpc_sessions table
|
||||
ALTER TABLE mpc_sessions
|
||||
ADD COLUMN version BIGINT NOT NULL DEFAULT 1;
|
||||
|
||||
-- Add comment explaining the version field
|
||||
COMMENT ON COLUMN mpc_sessions.version IS 'Version number for optimistic locking - increments on each update to detect concurrent modifications';
|
||||
|
||||
-- Create index on version for better query performance (optional but recommended)
|
||||
CREATE INDEX idx_mpc_sessions_version ON mpc_sessions(id, version);
|
||||
|
|
@ -37,14 +37,15 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes
|
|||
_, err = tx.ExecContext(ctx, `
|
||||
INSERT INTO mpc_sessions (
|
||||
id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
status = EXCLUDED.status,
|
||||
message_hash = EXCLUDED.message_hash,
|
||||
public_key = EXCLUDED.public_key,
|
||||
updated_at = EXCLUDED.updated_at,
|
||||
completed_at = EXCLUDED.completed_at
|
||||
completed_at = EXCLUDED.completed_at,
|
||||
version = EXCLUDED.version
|
||||
`,
|
||||
session.ID.UUID(),
|
||||
string(session.SessionType),
|
||||
|
|
@ -58,6 +59,7 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes
|
|||
session.UpdatedAt,
|
||||
session.ExpiresAt,
|
||||
session.CompletedAt,
|
||||
session.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -114,7 +116,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
|
|||
var session sessionRow
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
FROM mpc_sessions WHERE id = $1
|
||||
`, id).Scan(
|
||||
&session.ID,
|
||||
|
|
@ -129,6 +131,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
|
|||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.CompletedAt,
|
||||
&session.Version,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
@ -158,6 +161,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
|
|||
session.ExpiresAt,
|
||||
session.CompletedAt,
|
||||
participants,
|
||||
session.Version,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -165,7 +169,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en
|
|||
func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_objects.SessionStatus) ([]*entities.MPCSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
FROM mpc_sessions WHERE status = $1
|
||||
`, status.String())
|
||||
if err != nil {
|
||||
|
|
@ -180,7 +184,7 @@ func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_obj
|
|||
func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
FROM mpc_sessions
|
||||
WHERE expires_at < NOW() AND status IN ('created', 'in_progress')
|
||||
`)
|
||||
|
|
@ -196,7 +200,7 @@ func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCS
|
|||
func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
FROM mpc_sessions
|
||||
WHERE status IN ('created', 'in_progress')
|
||||
ORDER BY created_at ASC
|
||||
|
|
@ -213,7 +217,7 @@ func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSe
|
|||
func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID string) ([]*entities.MPCSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, session_type, threshold_n, threshold_t, status,
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at
|
||||
message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version
|
||||
FROM mpc_sessions WHERE created_by = $1
|
||||
ORDER BY created_at DESC
|
||||
`, creatorID)
|
||||
|
|
@ -229,7 +233,7 @@ func (r *SessionPostgresRepo) FindByCreator(ctx context.Context, creatorID strin
|
|||
func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, partyID value_objects.PartyID) ([]*entities.MPCSession, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT s.id, s.session_type, s.threshold_n, s.threshold_t, s.status,
|
||||
s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at
|
||||
s.message_hash, s.public_key, s.created_by, s.created_at, s.updated_at, s.expires_at, s.completed_at, s.version
|
||||
FROM mpc_sessions s
|
||||
JOIN participants p ON s.id = p.session_id
|
||||
WHERE p.party_id = $1 AND s.status IN ('created', 'in_progress')
|
||||
|
|
@ -243,7 +247,7 @@ func (r *SessionPostgresRepo) FindActiveByParticipant(ctx context.Context, party
|
|||
return r.scanSessions(ctx, rows)
|
||||
}
|
||||
|
||||
// Update updates an existing session
|
||||
// Update updates an existing session using optimistic locking
|
||||
func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCSession) error {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
|
|
@ -251,40 +255,45 @@ func (r *SessionPostgresRepo) Update(ctx context.Context, session *entities.MPCS
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Lock the session row first to prevent concurrent modifications
|
||||
// This ensures serializable isolation for the entire session update
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
SELECT id FROM mpc_sessions WHERE id = $1 FOR UPDATE
|
||||
`, session.ID.UUID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update session
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
// Update session with optimistic locking (check version)
|
||||
// Version is incremented automatically to prevent concurrent modifications
|
||||
result, err := tx.ExecContext(ctx, `
|
||||
UPDATE mpc_sessions SET
|
||||
status = $1, public_key = $2, updated_at = $3, completed_at = $4
|
||||
WHERE id = $5
|
||||
status = $1,
|
||||
public_key = $2,
|
||||
updated_at = $3,
|
||||
completed_at = $4,
|
||||
version = version + 1
|
||||
WHERE id = $5 AND version = $6
|
||||
`,
|
||||
session.Status.String(),
|
||||
session.PublicKey,
|
||||
session.UpdatedAt,
|
||||
session.CompletedAt,
|
||||
session.ID.UUID(),
|
||||
session.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Lock all participant rows for this session to prevent concurrent modifications
|
||||
// This prevents lost updates when multiple parties report completion simultaneously
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
SELECT id FROM participants WHERE session_id = $1 FOR UPDATE
|
||||
`, session.ID.UUID())
|
||||
// Check if the update affected any rows
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no rows were affected, version mismatch occurred (concurrent modification)
|
||||
if rowsAffected == 0 {
|
||||
logger.Warn("optimistic lock conflict detected",
|
||||
zap.String("session_id", session.ID.UUID().String()),
|
||||
zap.Int64("expected_version", session.Version))
|
||||
return entities.ErrOptimisticLockConflict
|
||||
}
|
||||
|
||||
// Update session version in memory to reflect the database state
|
||||
session.Version++
|
||||
|
||||
// DEBUG: Log all participant statuses before update
|
||||
var participantStatuses []string
|
||||
for _, p := range session.Participants {
|
||||
|
|
@ -479,6 +488,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows)
|
|||
&s.UpdatedAt,
|
||||
&s.ExpiresAt,
|
||||
&s.CompletedAt,
|
||||
&s.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -504,6 +514,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows)
|
|||
s.ExpiresAt,
|
||||
s.CompletedAt,
|
||||
participants,
|
||||
s.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -528,6 +539,7 @@ type sessionRow struct {
|
|||
UpdatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
CompletedAt *time.Time
|
||||
Version int64
|
||||
}
|
||||
|
||||
type participantRow struct {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package use_cases
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
|
@ -37,14 +38,43 @@ func NewReportCompletionUseCase(
|
|||
}
|
||||
}
|
||||
|
||||
// Execute executes the report completion use case
|
||||
const (
|
||||
maxRetries = 3 // Maximum number of retry attempts for optimistic lock conflicts
|
||||
)
|
||||
|
||||
// Execute executes the report completion use case with retry logic for optimistic lock conflicts
|
||||
func (uc *ReportCompletionUseCase) Execute(
|
||||
ctx context.Context,
|
||||
inputData input.ReportCompletionInput,
|
||||
) (*input.ReportCompletionOutput, error) {
|
||||
return uc.executeWithRetry(ctx, inputData, 0)
|
||||
}
|
||||
|
||||
// executeWithRetry executes the report completion with retry logic
|
||||
func (uc *ReportCompletionUseCase) executeWithRetry(
|
||||
ctx context.Context,
|
||||
inputData input.ReportCompletionInput,
|
||||
retry int,
|
||||
) (*input.ReportCompletionOutput, error) {
|
||||
if retry >= maxRetries {
|
||||
logger.Error("max retries exceeded for optimistic lock",
|
||||
zap.String("session_id", inputData.SessionID.String()),
|
||||
zap.String("party_id", inputData.PartyID),
|
||||
zap.Int("retry_count", retry))
|
||||
return nil, fmt.Errorf("max retries exceeded: %w", entities.ErrOptimisticLockConflict)
|
||||
}
|
||||
|
||||
if retry > 0 {
|
||||
logger.Info("retrying report completion due to optimistic lock conflict",
|
||||
zap.String("session_id", inputData.SessionID.String()),
|
||||
zap.String("party_id", inputData.PartyID),
|
||||
zap.Int("retry_attempt", retry))
|
||||
}
|
||||
|
||||
logger.Debug("ReportCompletion.Execute: START",
|
||||
zap.String("session_id", inputData.SessionID.String()),
|
||||
zap.String("party_id", inputData.PartyID))
|
||||
zap.String("party_id", inputData.PartyID),
|
||||
zap.Int("retry", retry))
|
||||
|
||||
// 1. Load session
|
||||
session, err := uc.sessionRepo.FindByUUID(ctx, inputData.SessionID)
|
||||
|
|
@ -169,12 +199,22 @@ func (uc *ReportCompletionUseCase) Execute(
|
|||
zap.Strings("all_participant_statuses", beforeUpdateStatuses))
|
||||
|
||||
if err := uc.sessionRepo.Update(ctx, session); err != nil {
|
||||
// Check if this is an optimistic lock conflict
|
||||
if errors.Is(err, entities.ErrOptimisticLockConflict) {
|
||||
logger.Warn("optimistic lock conflict, retrying",
|
||||
zap.String("session_id", session.ID.String()),
|
||||
zap.String("party_id", inputData.PartyID),
|
||||
zap.Int("retry", retry))
|
||||
// Retry the entire operation with fresh data
|
||||
return uc.executeWithRetry(ctx, inputData, retry+1)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Debug("ReportCompletion.Execute: AFTER sessionRepo.Update",
|
||||
zap.String("session_id", session.ID.String()),
|
||||
zap.String("party_id", inputData.PartyID))
|
||||
zap.String("party_id", inputData.PartyID),
|
||||
zap.Int("retry", retry))
|
||||
|
||||
// 7. Publish participant completed event
|
||||
event := output.ParticipantCompletedEvent{
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ var (
|
|||
ErrInvalidSessionType = errors.New("invalid session type")
|
||||
ErrInvalidStatusTransition = errors.New("invalid status transition")
|
||||
ErrParticipantTimedOut = errors.New("participant timed out")
|
||||
ErrOptimisticLockConflict = errors.New("optimistic lock conflict: session was modified by another transaction")
|
||||
)
|
||||
|
||||
// SessionType represents the type of MPC session
|
||||
|
|
@ -48,6 +49,7 @@ type MPCSession struct {
|
|||
UpdatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
CompletedAt *time.Time
|
||||
Version int64 // Optimistic locking version number
|
||||
}
|
||||
|
||||
// NewMPCSession creates a new MPC session
|
||||
|
|
@ -78,6 +80,7 @@ func NewMPCSession(
|
|||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ExpiresAt: now.Add(expiresIn),
|
||||
Version: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -384,6 +387,7 @@ func ReconstructSession(
|
|||
createdAt, updatedAt, expiresAt time.Time,
|
||||
completedAt *time.Time,
|
||||
participants []*Participant,
|
||||
version int64,
|
||||
) (*MPCSession, error) {
|
||||
sessionStatus, err := value_objects.NewSessionStatus(status)
|
||||
if err != nil {
|
||||
|
|
@ -409,5 +413,6 @@ func ReconstructSession(
|
|||
UpdatedAt: updatedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
CompletedAt: completedAt,
|
||||
Version: version,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue