package postgres

import (
	"context"
	"database/sql"
	"errors"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/services/account/domain/entities"
	"github.com/rwadurian/mpc-system/services/account/domain/repositories"
	"github.com/rwadurian/mpc-system/services/account/domain/value_objects"
)

const (
	// accountShareColumns is the single source of truth for the account_shares
	// column list. The order is load-bearing: Create binds $1..$10 in this
	// order and scanShare scans positionally in this order.
	accountShareColumns = `id, account_id, share_type, party_id, party_index,
		device_type, device_id, created_at, last_used_at, is_active`

	// selectAccountShares is the shared SELECT prefix; readers append their
	// own WHERE / ORDER BY clause.
	selectAccountShares = `SELECT ` + accountShareColumns + ` FROM account_shares`
)

// AccountSharePostgresRepo implements AccountShareRepository using PostgreSQL.
type AccountSharePostgresRepo struct {
	db *sql.DB
}

// NewAccountSharePostgresRepo creates a new AccountSharePostgresRepo backed by db.
func NewAccountSharePostgresRepo(db *sql.DB) repositories.AccountShareRepository {
	return &AccountSharePostgresRepo{db: db}
}

// Create inserts a new account share row.
// Any driver error (e.g. a constraint violation) is returned unchanged.
func (r *AccountSharePostgresRepo) Create(ctx context.Context, share *entities.AccountShare) error {
	query := `INSERT INTO account_shares (` + accountShareColumns + `)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`

	_, err := r.db.ExecContext(ctx, query,
		share.ID,
		share.AccountID.UUID(),
		share.ShareType.String(),
		share.PartyID,
		share.PartyIndex,
		share.DeviceType,
		share.DeviceID,
		share.CreatedAt,
		share.LastUsedAt,
		share.IsActive,
	)
	return err
}

// GetByID retrieves a share by its ID.
// A malformed UUID cannot match any row, so it is reported as
// entities.ErrShareNotFound rather than as a parse error.
func (r *AccountSharePostgresRepo) GetByID(ctx context.Context, id string) (*entities.AccountShare, error) {
	shareID, err := uuid.Parse(id)
	if err != nil {
		return nil, entities.ErrShareNotFound
	}

	query := selectAccountShares + ` WHERE id = $1`
	return r.scanShare(r.db.QueryRowContext(ctx, query, shareID))
}

// GetByAccountID retrieves all shares (active or not) for an account,
// ordered by party_index. It returns a nil slice when none exist.
func (r *AccountSharePostgresRepo) GetByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
	query := selectAccountShares + ` WHERE account_id = $1 ORDER BY party_index`
	return r.queryShares(ctx, query, accountID.UUID())
}

// GetActiveByAccountID retrieves only the shares with is_active = TRUE for an
// account, ordered by party_index. It returns a nil slice when none exist.
func (r *AccountSharePostgresRepo) GetActiveByAccountID(ctx context.Context, accountID value_objects.AccountID) ([]*entities.AccountShare, error) {
	query := selectAccountShares + ` WHERE account_id = $1 AND is_active = TRUE ORDER BY party_index`
	return r.queryShares(ctx, query, accountID.UUID())
}

// GetByPartyID retrieves all shares held by a party, newest first
// (created_at DESC). It returns a nil slice when none exist.
func (r *AccountSharePostgresRepo) GetByPartyID(ctx context.Context, partyID string) ([]*entities.AccountShare, error) {
	query := selectAccountShares + ` WHERE party_id = $1 ORDER BY created_at DESC`
	return r.queryShares(ctx, query, partyID)
}

// Update overwrites the mutable columns of a share identified by share.ID.
// account_id and created_at are never modified. It returns
// entities.ErrShareNotFound when no row has that ID.
func (r *AccountSharePostgresRepo) Update(ctx context.Context, share *entities.AccountShare) error {
	query := `
		UPDATE account_shares SET
			share_type = $2,
			party_id = $3,
			party_index = $4,
			device_type = $5,
			device_id = $6,
			last_used_at = $7,
			is_active = $8
		WHERE id = $1`

	result, err := r.db.ExecContext(ctx, query,
		share.ID,
		share.ShareType.String(),
		share.PartyID,
		share.PartyIndex,
		share.DeviceType,
		share.DeviceID,
		share.LastUsedAt,
		share.IsActive,
	)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return entities.ErrShareNotFound
	}
	return nil
}

// Delete permanently removes a share by ID. A malformed UUID or a missing
// row both yield entities.ErrShareNotFound.
func (r *AccountSharePostgresRepo) Delete(ctx context.Context, id string) error {
	shareID, err := uuid.Parse(id)
	if err != nil {
		return entities.ErrShareNotFound
	}

	result, err := r.db.ExecContext(ctx, `DELETE FROM account_shares WHERE id = $1`, shareID)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return entities.ErrShareNotFound
	}
	return nil
}

// DeactivateByAccountID sets is_active = FALSE on every share of an account.
// Matching zero rows is not treated as an error.
func (r *AccountSharePostgresRepo) DeactivateByAccountID(ctx context.Context, accountID value_objects.AccountID) error {
	query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1`
	_, err := r.db.ExecContext(ctx, query, accountID.UUID())
	return err
}

// DeactivateByShareType sets is_active = FALSE on the account's shares of the
// given type. Matching zero rows is not treated as an error.
func (r *AccountSharePostgresRepo) DeactivateByShareType(ctx context.Context, accountID value_objects.AccountID, shareType value_objects.ShareType) error {
	query := `UPDATE account_shares SET is_active = FALSE WHERE account_id = $1 AND share_type = $2`
	_, err := r.db.ExecContext(ctx, query, accountID.UUID(), shareType.String())
	return err
}

// scanShare decodes one row laid out in accountShareColumns order into an
// AccountShare. It accepts both *sql.Row and *sql.Rows, so single-row and
// multi-row readers share one decoding path. sql.ErrNoRows (produced only by
// *sql.Row) is translated to entities.ErrShareNotFound.
func (r *AccountSharePostgresRepo) scanShare(row interface{ Scan(dest ...interface{}) error }) (*entities.AccountShare, error) {
	var (
		id, accountID        uuid.UUID
		shareType, partyID   string
		partyIndex           int
		deviceType, deviceID sql.NullString
		share                entities.AccountShare
	)

	err := row.Scan(
		&id,
		&accountID,
		&shareType,
		&partyID,
		&partyIndex,
		&deviceType,
		&deviceID,
		&share.CreatedAt,
		&share.LastUsedAt,
		&share.IsActive,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, entities.ErrShareNotFound
		}
		return nil, err
	}

	share.ID = id
	share.AccountID = value_objects.AccountIDFromUUID(accountID)
	share.ShareType = value_objects.ShareType(shareType)
	share.PartyID = partyID
	share.PartyIndex = partyIndex
	// device_type / device_id are nullable; SQL NULL maps to a nil pointer.
	if deviceType.Valid {
		share.DeviceType = &deviceType.String
	}
	if deviceID.Valid {
		share.DeviceID = &deviceID.String
	}

	return &share, nil
}

// queryShares runs a multi-row SELECT (which must project accountShareColumns
// in order) and decodes every row. On any error — including one surfaced by
// rows.Err after iteration — it returns a nil slice, never a partial result.
func (r *AccountSharePostgresRepo) queryShares(ctx context.Context, query string, args ...interface{}) ([]*entities.AccountShare, error) {
	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var shares []*entities.AccountShare
	for rows.Next() {
		share, err := r.scanShare(rows)
		if err != nil {
			return nil, err
		}
		shares = append(shares, share)
	}
	// rows.Next returning false may hide an iteration error (e.g. a dropped
	// connection); a truncated list must not be mistaken for the full one.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return shares, nil
}