267 lines
7.0 KiB
Go
267 lines
7.0 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/services/account/domain/entities"
|
|
"github.com/rwadurian/mpc-system/services/account/domain/repositories"
|
|
"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
|
|
)
|
|
|
|
// RecoverySessionPostgresRepo implements RecoverySessionRepository using PostgreSQL
|
|
type RecoverySessionPostgresRepo struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewRecoverySessionPostgresRepo creates a new RecoverySessionPostgresRepo
|
|
func NewRecoverySessionPostgresRepo(db *sql.DB) repositories.RecoverySessionRepository {
|
|
return &RecoverySessionPostgresRepo{db: db}
|
|
}
|
|
|
|
// Create creates a new recovery session
|
|
func (r *RecoverySessionPostgresRepo) Create(ctx context.Context, session *entities.RecoverySession) error {
|
|
query := `
|
|
INSERT INTO account_recovery_sessions (id, account_id, recovery_type, old_share_type,
|
|
new_keygen_session_id, status, requested_at, completed_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
`
|
|
|
|
var oldShareType *string
|
|
if session.OldShareType != nil {
|
|
s := session.OldShareType.String()
|
|
oldShareType = &s
|
|
}
|
|
|
|
_, err := r.db.ExecContext(ctx, query,
|
|
session.ID,
|
|
session.AccountID.UUID(),
|
|
session.RecoveryType.String(),
|
|
oldShareType,
|
|
session.NewKeygenSessionID,
|
|
session.Status.String(),
|
|
session.RequestedAt,
|
|
session.CompletedAt,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// GetByID retrieves a recovery session by ID
|
|
func (r *RecoverySessionPostgresRepo) GetByID(ctx context.Context, id string) (*entities.RecoverySession, error) {
|
|
sessionID, err := uuid.Parse(id)
|
|
if err != nil {
|
|
return nil, entities.ErrRecoveryNotFound
|
|
}
|
|
|
|
query := `
|
|
SELECT id, account_id, recovery_type, old_share_type,
|
|
new_keygen_session_id, status, requested_at, completed_at
|
|
FROM account_recovery_sessions
|
|
WHERE id = $1
|
|
`
|
|
|
|
return r.scanSession(r.db.QueryRowContext(ctx, query, sessionID))
|
|
}
|
|
|
|
// GetByAccountID retrieves recovery sessions for an account
|
|
func (r *RecoverySessionPostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.RecoverySession, error) {
|
|
query := `
|
|
SELECT id, account_id, recovery_type, old_share_type,
|
|
new_keygen_session_id, status, requested_at, completed_at
|
|
FROM account_recovery_sessions
|
|
WHERE account_id = $1
|
|
ORDER BY requested_at DESC
|
|
`
|
|
|
|
return r.querySessions(ctx, query, accountID.UUID())
|
|
}
|
|
|
|
// GetActiveByAccountID retrieves active recovery sessions for an account
|
|
func (r *RecoverySessionPostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) (*entities.RecoverySession, error) {
|
|
query := `
|
|
SELECT id, account_id, recovery_type, old_share_type,
|
|
new_keygen_session_id, status, requested_at, completed_at
|
|
FROM account_recovery_sessions
|
|
WHERE account_id = $1 AND status IN ('requested', 'in_progress')
|
|
ORDER BY requested_at DESC
|
|
LIMIT 1
|
|
`
|
|
|
|
return r.scanSession(r.db.QueryRowContext(ctx, query, accountID.UUID()))
|
|
}
|
|
|
|
// Update updates a recovery session
|
|
func (r *RecoverySessionPostgresRepo) Update(ctx context.Context, session *entities.RecoverySession) error {
|
|
query := `
|
|
UPDATE account_recovery_sessions
|
|
SET recovery_type = $2, old_share_type = $3, new_keygen_session_id = $4,
|
|
status = $5, completed_at = $6
|
|
WHERE id = $1
|
|
`
|
|
|
|
var oldShareType *string
|
|
if session.OldShareType != nil {
|
|
s := session.OldShareType.String()
|
|
oldShareType = &s
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, query,
|
|
session.ID,
|
|
session.RecoveryType.String(),
|
|
oldShareType,
|
|
session.NewKeygenSessionID,
|
|
session.Status.String(),
|
|
session.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return entities.ErrRecoveryNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete deletes a recovery session
|
|
func (r *RecoverySessionPostgresRepo) Delete(ctx context.Context, id string) error {
|
|
sessionID, err := uuid.Parse(id)
|
|
if err != nil {
|
|
return entities.ErrRecoveryNotFound
|
|
}
|
|
|
|
query := `DELETE FROM account_recovery_sessions WHERE id = $1`
|
|
|
|
result, err := r.db.ExecContext(ctx, query, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return entities.ErrRecoveryNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// scanSession scans a single recovery session row
|
|
func (r *RecoverySessionPostgresRepo) scanSession(row *sql.Row) (*entities.RecoverySession, error) {
|
|
var (
|
|
id uuid.UUID
|
|
accountID uuid.UUID
|
|
recoveryType string
|
|
oldShareType sql.NullString
|
|
newKeygenSessionID sql.NullString
|
|
status string
|
|
session entities.RecoverySession
|
|
)
|
|
|
|
err := row.Scan(
|
|
&id,
|
|
&accountID,
|
|
&recoveryType,
|
|
&oldShareType,
|
|
&newKeygenSessionID,
|
|
&status,
|
|
&session.RequestedAt,
|
|
&session.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, entities.ErrRecoveryNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
session.ID = id
|
|
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
|
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
|
session.Status = value_objects.RecoveryStatus(status)
|
|
|
|
if oldShareType.Valid {
|
|
st := value_objects.ShareType(oldShareType.String)
|
|
session.OldShareType = &st
|
|
}
|
|
|
|
if newKeygenSessionID.Valid {
|
|
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
|
session.NewKeygenSessionID = &keygenID
|
|
}
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// querySessions queries multiple recovery sessions
|
|
func (r *RecoverySessionPostgresRepo) querySessions(ctx context.Context, query string, args ...interface{}) ([]*entities.RecoverySession, error) {
|
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessions []*entities.RecoverySession
|
|
for rows.Next() {
|
|
var (
|
|
id uuid.UUID
|
|
accountID uuid.UUID
|
|
recoveryType string
|
|
oldShareType sql.NullString
|
|
newKeygenSessionID sql.NullString
|
|
status string
|
|
session entities.RecoverySession
|
|
)
|
|
|
|
err := rows.Scan(
|
|
&id,
|
|
&accountID,
|
|
&recoveryType,
|
|
&oldShareType,
|
|
&newKeygenSessionID,
|
|
&status,
|
|
&session.RequestedAt,
|
|
&session.CompletedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session.ID = id
|
|
session.AccountID = value_objects.AccountIDFromUUID(accountID)
|
|
session.RecoveryType = value_objects.RecoveryType(recoveryType)
|
|
session.Status = value_objects.RecoveryStatus(status)
|
|
|
|
if oldShareType.Valid {
|
|
st := value_objects.ShareType(oldShareType.String)
|
|
session.OldShareType = &st
|
|
}
|
|
|
|
if newKeygenSessionID.Valid {
|
|
if keygenID, err := uuid.Parse(newKeygenSessionID.String); err == nil {
|
|
session.NewKeygenSessionID = &keygenID
|
|
}
|
|
}
|
|
|
|
sessions = append(sessions, &session)
|
|
}
|
|
|
|
return sessions, rows.Err()
|
|
}
|