diff --git a/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go b/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go index 64b45644..7406bbf2 100644 --- a/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go +++ b/backend/mpc-system/services/session-coordinator/adapters/output/postgres/session_postgres_repo.go @@ -37,12 +37,13 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes _, err = tx.ExecContext(ctx, ` INSERT INTO mpc_sessions ( id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + message_hash, public_key, delegate_party_id, created_by, created_at, updated_at, expires_at, completed_at, version + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (id) DO UPDATE SET status = EXCLUDED.status, message_hash = EXCLUDED.message_hash, public_key = EXCLUDED.public_key, + delegate_party_id = EXCLUDED.delegate_party_id, updated_at = EXCLUDED.updated_at, completed_at = EXCLUDED.completed_at, version = EXCLUDED.version @@ -54,6 +55,7 @@ func (r *SessionPostgresRepo) Save(ctx context.Context, session *entities.MPCSes session.Status.String(), session.MessageHash, session.PublicKey, + session.DelegatePartyID, session.CreatedBy, session.CreatedAt, session.UpdatedAt, @@ -116,7 +118,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en var session sessionRow err := r.db.QueryRowContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version + message_hash, public_key, delegate_party_id, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE id = $1 `, id).Scan( &session.ID, @@ -126,6 +128,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en &session.Status, &session.MessageHash, &session.PublicKey, + &session.DelegatePartyID, &session.CreatedBy, &session.CreatedAt, &session.UpdatedAt, @@ -154,7 +157,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en session.Status, session.MessageHash, session.PublicKey, - "", // delegatePartyID - not stored in DB yet, will be empty for old sessions + session.DelegatePartyID, session.CreatedBy, session.CreatedAt, session.UpdatedAt, @@ -169,7 +172,7 @@ func (r *SessionPostgresRepo) FindByUUID(ctx context.Context, id uuid.UUID) (*en func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_objects.SessionStatus) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version + message_hash, public_key, delegate_party_id, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE status = $1 `, status.String()) if err != nil { @@ -184,7 +187,7 @@ func (r *SessionPostgresRepo) FindByStatus(ctx context.Context, status value_obj func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version + message_hash, public_key, delegate_party_id, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE expires_at < NOW() AND status IN ('created', 'in_progress') `) @@ -200,7 +203,7 @@ func (r *SessionPostgresRepo) FindExpired(ctx context.Context) ([]*entities.MPCS func (r *SessionPostgresRepo) FindActive(ctx context.Context) ([]*entities.MPCSession, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, session_type, threshold_n, threshold_t, status, - message_hash, public_key, created_by, created_at, updated_at, expires_at, completed_at, version + message_hash, public_key, delegate_party_id, created_by, created_at, updated_at, expires_at, completed_at, version FROM mpc_sessions WHERE status IN ('created', 'in_progress') ORDER BY created_at ASC @@ -483,6 +486,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows) &s.Status, &s.MessageHash, &s.PublicKey, + &s.DelegatePartyID, &s.CreatedBy, &s.CreatedAt, &s.UpdatedAt, @@ -507,7 +511,7 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows) s.Status, s.MessageHash, s.PublicKey, - "", // delegatePartyID - not stored in DB yet + s.DelegatePartyID, s.CreatedBy, s.CreatedAt, s.UpdatedAt, @@ -527,19 +531,20 @@ func (r *SessionPostgresRepo) scanSessions(ctx context.Context, rows *sql.Rows) // Row types for scanning type sessionRow struct { - ID uuid.UUID - SessionType string - ThresholdN int - ThresholdT int - Status string - MessageHash []byte - PublicKey []byte - CreatedBy string - CreatedAt time.Time - UpdatedAt time.Time - ExpiresAt time.Time - CompletedAt *time.Time - Version int64 + ID uuid.UUID + SessionType string + ThresholdN int + ThresholdT int + Status string + MessageHash []byte + PublicKey []byte + DelegatePartyID string + CreatedBy string + CreatedAt time.Time + UpdatedAt time.Time + ExpiresAt time.Time + CompletedAt *time.Time + Version int64 } type participantRow struct {