package postgres

import (
	"context"
	"database/sql"
	"errors"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/services/account/domain/entities"
	"github.com/rwadurian/mpc-system/services/account/domain/repositories"
	"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
)

// Compile-time check that RecoverySessionPostgresRepo satisfies the
// repository interface returned by NewRecoverySessionPostgresRepo.
var _ repositories.RecoverySessionRepository = (*RecoverySessionPostgresRepo)(nil)

// RecoverySessionPostgresRepo implements RecoverySessionRepository using
// PostgreSQL. All queries target the account_recovery_sessions table.
type RecoverySessionPostgresRepo struct {
	db *sql.DB
}

// NewRecoverySessionPostgresRepo creates a new RecoverySessionPostgresRepo
// backed by db.
func NewRecoverySessionPostgresRepo(db *sql.DB) repositories.RecoverySessionRepository {
	return &RecoverySessionPostgresRepo{db: db}
}

// Create inserts a new recovery session row.
//
// old_share_type is written as NULL when session.OldShareType is nil.
// Errors from the driver are returned as-is (unwrapped) so callers can
// inspect driver-specific error types.
func (r *RecoverySessionPostgresRepo) Create(ctx context.Context, session *entities.RecoverySession) error {
	const query = `
		INSERT INTO account_recovery_sessions
			(id, account_id, recovery_type, old_share_type, new_keygen_session_id,
			 status, requested_at, completed_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
	`
	_, err := r.db.ExecContext(ctx, query,
		session.ID,
		session.AccountID.UUID(),
		session.RecoveryType.String(),
		r.oldShareTypeArg(session),
		session.NewKeygenSessionID,
		session.Status.String(),
		session.RequestedAt,
		session.CompletedAt,
	)
	return err
}

// GetByID returns the recovery session with the given ID.
//
// An id that is not a valid UUID, or that matches no row, yields
// entities.ErrRecoveryNotFound.
func (r *RecoverySessionPostgresRepo) GetByID(ctx context.Context, id string) (*entities.RecoverySession, error) {
	sessionID, err := uuid.Parse(id)
	if err != nil {
		// A malformed ID can never match a row, so report it as "not found".
		return nil, entities.ErrRecoveryNotFound
	}

	const query = `
		SELECT id, account_id, recovery_type, old_share_type, new_keygen_session_id,
		       status, requested_at, completed_at
		FROM account_recovery_sessions
		WHERE id = $1
	`
	return r.scanSession(r.db.QueryRowContext(ctx, query, sessionID))
}

// GetByAccountID returns all recovery sessions for accountID, newest
// (by requested_at) first. It returns a nil slice when none exist.
func (r *RecoverySessionPostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error) {
	const query = `
		SELECT id, account_id, recovery_type, old_share_type, new_keygen_session_id,
		       status, requested_at, completed_at
		FROM account_recovery_sessions
		WHERE account_id = $1
		ORDER BY requested_at DESC
	`
	return r.querySessions(ctx, query, accountID.UUID())
}

// GetActiveByAccountID returns the most recently requested session for
// accountID whose status is 'requested' or 'in_progress'. It returns
// entities.ErrRecoveryNotFound when there is no such session.
func (r *RecoverySessionPostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error) {
	// NOTE(review): the status literals below must stay in sync with the
	// string values of value_objects.RecoveryStatus — confirm if those change.
	const query = `
		SELECT id, account_id, recovery_type, old_share_type, new_keygen_session_id,
		       status, requested_at, completed_at
		FROM account_recovery_sessions
		WHERE account_id = $1 AND status IN ('requested', 'in_progress')
		ORDER BY requested_at DESC
		LIMIT 1
	`
	return r.scanSession(r.db.QueryRowContext(ctx, query, accountID.UUID()))
}

// Update overwrites the mutable columns of an existing session:
// recovery_type, old_share_type, new_keygen_session_id, status and
// completed_at. account_id and requested_at are not changed.
//
// It returns entities.ErrRecoveryNotFound when no row has session.ID.
func (r *RecoverySessionPostgresRepo) Update(ctx context.Context, session *entities.RecoverySession) error {
	const query = `
		UPDATE account_recovery_sessions
		SET recovery_type = $2,
		    old_share_type = $3,
		    new_keygen_session_id = $4,
		    status = $5,
		    completed_at = $6
		WHERE id = $1
	`
	result, err := r.db.ExecContext(ctx, query,
		session.ID,
		session.RecoveryType.String(),
		r.oldShareTypeArg(session),
		session.NewKeygenSessionID,
		session.Status.String(),
		session.CompletedAt,
	)
	if err != nil {
		return err
	}
	return r.requireRowAffected(result)
}

// Delete removes the recovery session with the given ID.
//
// An id that is not a valid UUID, or that matches no row, yields
// entities.ErrRecoveryNotFound.
func (r *RecoverySessionPostgresRepo) Delete(ctx context.Context, id string) error {
	sessionID, err := uuid.Parse(id)
	if err != nil {
		return entities.ErrRecoveryNotFound
	}

	const query = `DELETE FROM account_recovery_sessions WHERE id = $1`
	result, err := r.db.ExecContext(ctx, query, sessionID)
	if err != nil {
		return err
	}
	return r.requireRowAffected(result)
}

// requireRowAffected maps a zero RowsAffected count to
// entities.ErrRecoveryNotFound. Shared by Update and Delete.
func (r *RecoverySessionPostgresRepo) requireRowAffected(result sql.Result) error {
	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return entities.ErrRecoveryNotFound
	}
	return nil
}

// oldShareTypeArg converts session.OldShareType into a nullable query
// argument: nil (SQL NULL) when unset, otherwise a pointer to its string form.
func (r *RecoverySessionPostgresRepo) oldShareTypeArg(session *entities.RecoverySession) *string {
	if session.OldShareType == nil {
		return nil
	}
	s := session.OldShareType.String()
	return &s
}

// scanSession scans a single recovery session row, translating
// sql.ErrNoRows into entities.ErrRecoveryNotFound.
func (r *RecoverySessionPostgresRepo) scanSession(row *sql.Row) (*entities.RecoverySession, error) {
	session, err := r.scanRow(row)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, entities.ErrRecoveryNotFound
		}
		return nil, err
	}
	return session, nil
}

// querySessions runs query with args and maps every returned row to a
// RecoverySession. Rows are returned in the order the query produces them.
// Any iteration error reported by rows.Err is returned alongside the
// sessions collected so far.
func (r *RecoverySessionPostgresRepo) querySessions(ctx context.Context, query string, args ...interface{}) ([]*entities.RecoverySession, error) {
	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var sessions []*entities.RecoverySession
	for rows.Next() {
		session, err := r.scanRow(rows)
		if err != nil {
			return nil, err
		}
		sessions = append(sessions, session)
	}
	return sessions, rows.Err()
}

// scanRow reads one row, with columns in the order selected by every query
// in this file, from src (a *sql.Row or *sql.Rows) and builds a
// RecoverySession from it. Errors from src.Scan are returned unchanged.
func (r *RecoverySessionPostgresRepo) scanRow(src interface{ Scan(dest ...interface{}) error }) (*entities.RecoverySession, error) {
	var (
		id                 uuid.UUID
		accountID          uuid.UUID
		recoveryType       string
		oldShareType       sql.NullString
		newKeygenSessionID sql.NullString
		status             string
		session            entities.RecoverySession
	)
	if err := src.Scan(
		&id,
		&accountID,
		&recoveryType,
		&oldShareType,
		&newKeygenSessionID,
		&status,
		&session.RequestedAt,
		&session.CompletedAt,
	); err != nil {
		return nil, err
	}

	session.ID = id
	session.AccountID = value_objects.AccountIDFromUUID(accountID)
	session.RecoveryType = value_objects.RecoveryType(recoveryType)
	session.Status = value_objects.RecoveryStatus(status)

	if oldShareType.Valid {
		st := value_objects.ShareType(oldShareType.String)
		session.OldShareType = &st
	}
	if newKeygenSessionID.Valid {
		// NOTE(review): a stored value that is not a valid UUID is silently
		// dropped (NewKeygenSessionID stays nil) rather than reported —
		// confirm this best-effort behavior is intended.
		if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
			session.NewKeygenSessionID = &keygenID
		}
	}
	return &session, nil
}