package postgres

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"

	"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
	"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
)

// mpcMessageSelectColumns is the column list used by every query that loads
// complete messages. Its order must match the Scan targets in scanMessage;
// keeping both in one place prevents the copies from drifting apart.
const mpcMessageSelectColumns = `id, session_id, from_party, to_parties, round_number,
		message_type, payload, created_at, delivered_at`

// MessagePostgresRepo implements MessageRepository for PostgreSQL.
type MessagePostgresRepo struct {
	db *sql.DB
}

// NewMessagePostgresRepo creates a new PostgreSQL message repository backed by db.
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
	return &MessagePostgresRepo{db: db}
}

// Save inserts msg as a new row in mpc_messages.
//
// delivered_at is not written here; it is set later by MarkDelivered.
// ToParties is stored as a PostgreSQL array via pq.Array.
func (r *MessagePostgresRepo) Save(ctx context.Context, msg *entities.MPCMessage) error {
	_, err := r.db.ExecContext(ctx, `
		INSERT INTO mpc_messages (
			id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
	`,
		msg.ID,
		msg.SessionID,
		msg.FromParty,
		pq.Array(msg.ToParties),
		msg.RoundNumber,
		msg.MessageType,
		msg.Payload,
		msg.CreatedAt,
	)
	return err
}

// GetByID loads the message with the given id.
//
// It returns (nil, nil) when no row matches, so callers must check the
// returned message for nil in addition to checking the error.
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error) {
	row := r.db.QueryRowContext(ctx,
		"SELECT "+mpcMessageSelectColumns+" FROM mpc_messages WHERE id = $1", id)

	msg, err := r.scanMessage(row)
	if err != nil {
		// errors.Is (rather than ==) also matches a wrapped sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return msg, nil
}

// GetPendingMessages returns the messages in sessionID that partyID should
// receive and that were created strictly after afterTime.
//
// A message qualifies when it was not sent by partyID and is either a
// broadcast (to_parties is NULL or empty) or lists partyID in to_parties.
// Results are ordered by round_number, then created_at. delivered_at is not
// consulted; afterTime alone acts as the cursor.
//
// NOTE(review): because the comparison is strict (>), a caller that passes the
// created_at of the last message it received will skip any other message that
// shares exactly that timestamp — confirm timestamps are unique enough.
func (r *MessagePostgresRepo) GetPendingMessages(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	afterTime time.Time,
) ([]*entities.MPCMessage, error) {
	query := "SELECT " + mpcMessageSelectColumns + `
		FROM mpc_messages
		WHERE session_id = $1
		  AND created_at > $2
		  AND from_party != $3
		  AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $3 = ANY(to_parties))
		ORDER BY round_number ASC, created_at ASC`

	rows, err := r.db.QueryContext(ctx, query, sessionID, afterTime, partyID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return r.scanMessages(rows)
}

// CountPendingByParty counts, across all sessions, the undelivered messages
// (delivered_at IS NULL) that are addressed to partyID or broadcast, excluding
// messages partyID sent itself.
//
// NOTE(review): delivered_at is a single column per message row, so once any
// recipient marks a broadcast message delivered it stops being counted for
// every other recipient too — confirm this is the intended semantics.
func (r *MessagePostgresRepo) CountPendingByParty(ctx context.Context, partyID string) (int64, error) {
	var count int64
	err := r.db.QueryRowContext(ctx, `
		SELECT COUNT(*)
		FROM mpc_messages
		WHERE delivered_at IS NULL
		  AND from_party != $1
		  AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $1 = ANY(to_parties))
	`, partyID).Scan(&count)
	if err != nil {
		return 0, err
	}
	return count, nil
}

// GetMessagesByRound returns every message in sessionID with the given
// round number, regardless of sender or recipients, ordered by created_at.
func (r *MessagePostgresRepo) GetMessagesByRound(
	ctx context.Context,
	sessionID uuid.UUID,
	roundNumber int,
) ([]*entities.MPCMessage, error) {
	query := "SELECT " + mpcMessageSelectColumns + `
		FROM mpc_messages
		WHERE session_id = $1 AND round_number = $2
		ORDER BY created_at ASC`

	rows, err := r.db.QueryContext(ctx, query, sessionID, roundNumber)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return r.scanMessages(rows)
}

// MarkDelivered sets delivered_at to the database server's NOW() for the
// message with the given ID.
//
// It does not report an error when no row matches, and calling it again
// overwrites the previous delivered_at value.
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
	_, err := r.db.ExecContext(ctx, `
		UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
	`, messageID)
	return err
}

// DeleteBySession deletes all messages belonging to sessionID.
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID uuid.UUID) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID)
	return err
}

// DeleteOlderThan deletes messages whose created_at is strictly before
// `before` and returns the number of rows removed.
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
	result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

// scanMessage decodes one row selected with mpcMessageSelectColumns.
// row may be a *sql.Row or a *sql.Rows positioned on a row; both provide Scan.
func (r *MessagePostgresRepo) scanMessage(row interface{ Scan(dest ...interface{}) error }) (*entities.MPCMessage, error) {
	var msg entities.MPCMessage
	var toParties []string

	if err := row.Scan(
		&msg.ID,
		&msg.SessionID,
		&msg.FromParty,
		pq.Array(&toParties),
		&msg.RoundNumber,
		&msg.MessageType,
		&msg.Payload,
		&msg.CreatedAt,
		&msg.DeliveredAt,
	); err != nil {
		return nil, err
	}

	msg.ToParties = toParties
	return &msg, nil
}

// scanMessages decodes every remaining row of rows. It does not close rows;
// the caller owns that. It returns a nil slice when there are no rows and a
// nil slice (never a partial one) when iteration fails.
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.MPCMessage, error) {
	var messages []*entities.MPCMessage
	for rows.Next() {
		msg, err := r.scanMessage(rows)
		if err != nil {
			return nil, err
		}
		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}

// Ensure interface compliance
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)