186 lines
5.1 KiB
Go
186 lines
5.1 KiB
Go
package postgres
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"
	"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
	"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
)
|
|
|
|
// MessagePostgresRepo is the PostgreSQL-backed implementation of
// repositories.MessageRepository. Every method issues its SQL against the
// mpc_messages table through the single database handle held in db.
type MessagePostgresRepo struct {
	db *sql.DB // handle used for every query; supplied by NewMessagePostgresRepo
}
|
|
|
|
// NewMessagePostgresRepo creates a new PostgreSQL message repository
|
|
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
|
|
return &MessagePostgresRepo{db: db}
|
|
}
|
|
|
|
// Save persists a new message
|
|
func (r *MessagePostgresRepo) Save(ctx context.Context, msg *entities.MPCMessage) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
INSERT INTO mpc_messages (
|
|
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
`,
|
|
msg.ID,
|
|
msg.SessionID,
|
|
msg.FromParty,
|
|
pq.Array(msg.ToParties),
|
|
msg.RoundNumber,
|
|
msg.MessageType,
|
|
msg.Payload,
|
|
msg.CreatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// GetByID retrieves a message by ID
|
|
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error) {
|
|
var msg entities.MPCMessage
|
|
var toParties []string
|
|
|
|
err := r.db.QueryRowContext(ctx, `
|
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
|
FROM mpc_messages WHERE id = $1
|
|
`, id).Scan(
|
|
&msg.ID,
|
|
&msg.SessionID,
|
|
&msg.FromParty,
|
|
pq.Array(&toParties),
|
|
&msg.RoundNumber,
|
|
&msg.MessageType,
|
|
&msg.Payload,
|
|
&msg.CreatedAt,
|
|
&msg.DeliveredAt,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
msg.ToParties = toParties
|
|
return &msg, nil
|
|
}
|
|
|
|
// GetPendingMessages retrieves pending messages for a party
|
|
func (r *MessagePostgresRepo) GetPendingMessages(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
afterTime time.Time,
|
|
) ([]*entities.MPCMessage, error) {
|
|
rows, err := r.db.QueryContext(ctx, `
|
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
|
FROM mpc_messages
|
|
WHERE session_id = $1
|
|
AND created_at > $2
|
|
AND from_party != $3
|
|
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $3 = ANY(to_parties))
|
|
ORDER BY round_number ASC, created_at ASC
|
|
`, sessionID, afterTime, partyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return r.scanMessages(rows)
|
|
}
|
|
|
|
// CountPendingByParty counts all pending messages for a party across all sessions
|
|
func (r *MessagePostgresRepo) CountPendingByParty(ctx context.Context, partyID string) (int64, error) {
|
|
var count int64
|
|
err := r.db.QueryRowContext(ctx, `
|
|
SELECT COUNT(*)
|
|
FROM mpc_messages
|
|
WHERE delivered_at IS NULL
|
|
AND from_party != $1
|
|
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $1 = ANY(to_parties))
|
|
`, partyID).Scan(&count)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// GetMessagesByRound retrieves messages for a specific round
|
|
func (r *MessagePostgresRepo) GetMessagesByRound(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
roundNumber int,
|
|
) ([]*entities.MPCMessage, error) {
|
|
rows, err := r.db.QueryContext(ctx, `
|
|
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
|
|
FROM mpc_messages
|
|
WHERE session_id = $1 AND round_number = $2
|
|
ORDER BY created_at ASC
|
|
`, sessionID, roundNumber)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return r.scanMessages(rows)
|
|
}
|
|
|
|
// MarkDelivered marks a message as delivered
|
|
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
|
|
`, messageID)
|
|
return err
|
|
}
|
|
|
|
// DeleteBySession deletes all messages for a session
|
|
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID uuid.UUID) error {
|
|
_, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE session_id = $1`, sessionID)
|
|
return err
|
|
}
|
|
|
|
// DeleteOlderThan deletes messages older than a specific time
|
|
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
|
|
result, err := r.db.ExecContext(ctx, `DELETE FROM mpc_messages WHERE created_at < $1`, before)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.MPCMessage, error) {
|
|
var messages []*entities.MPCMessage
|
|
for rows.Next() {
|
|
var msg entities.MPCMessage
|
|
var toParties []string
|
|
|
|
err := rows.Scan(
|
|
&msg.ID,
|
|
&msg.SessionID,
|
|
&msg.FromParty,
|
|
pq.Array(&toParties),
|
|
&msg.RoundNumber,
|
|
&msg.MessageType,
|
|
&msg.Payload,
|
|
&msg.CreatedAt,
|
|
&msg.DeliveredAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msg.ToParties = toParties
|
|
messages = append(messages, &msg)
|
|
}
|
|
|
|
return messages, rows.Err()
|
|
}
|
|
|
|
// Ensure interface compliance
|
|
var _ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
|