rwadurian/backend/mpc-system/services/message-router/adapters/output/postgres/message_repo.go

186 lines
5.1 KiB
Go

package postgres
import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"
	"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
	"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
)
// MessagePostgresRepo is the PostgreSQL-backed MessageRepository. Every
// method runs its SQL against the mpc_messages table through db.
type MessagePostgresRepo struct {
	// db is the handle all queries go through; it is stored as supplied
	// and nothing visible in this file closes it.
	db *sql.DB
}

// NewMessagePostgresRepo returns a repository that issues its queries on db.
// The handle is stored as-is (not validated or pinged), so a nil db is
// accepted here and would only surface once a method uses it.
func NewMessagePostgresRepo(db *sql.DB) *MessagePostgresRepo {
	repo := new(MessagePostgresRepo)
	repo.db = db
	return repo
}
// Save inserts msg as a new row in mpc_messages.
//
// Every column except delivered_at is written; ToParties is encoded as a
// PostgreSQL array via pq.Array. Errors from the driver are returned
// unchanged.
func (r *MessagePostgresRepo) Save(ctx context.Context, msg *entities.MPCMessage) error {
	const insertSQL = `
INSERT INTO mpc_messages (
id, session_id, from_party, to_parties, round_number, message_type, payload, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
	if _, err := r.db.ExecContext(ctx, insertSQL,
		msg.ID, msg.SessionID, msg.FromParty, pq.Array(msg.ToParties),
		msg.RoundNumber, msg.MessageType, msg.Payload, msg.CreatedAt,
	); err != nil {
		return err
	}
	return nil
}
// GetByID loads the message whose id column equals id.
//
// When no row matches it returns (nil, nil), so callers must nil-check the
// message as well as the error. Any other query or scan error is returned
// unchanged.
func (r *MessagePostgresRepo) GetByID(ctx context.Context, id uuid.UUID) (*entities.MPCMessage, error) {
	var msg entities.MPCMessage
	// to_parties is a PostgreSQL array; scan it through pq.Array into a
	// plain slice, then attach it to the message.
	var toParties []string
	err := r.db.QueryRowContext(ctx, `
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
FROM mpc_messages WHERE id = $1
`, id).Scan(
		&msg.ID,
		&msg.SessionID,
		&msg.FromParty,
		pq.Array(&toParties),
		&msg.RoundNumber,
		&msg.MessageType,
		&msg.Payload,
		&msg.CreatedAt,
		// NOTE(review): delivered_at can be NULL (it is tested with IS NULL
		// elsewhere in this repo), so DeliveredAt is presumably a nullable
		// type such as *time.Time — verify against the entity.
		&msg.DeliveredAt,
	)
	if err != nil {
		// errors.Is also matches ErrNoRows when a driver or wrapper layer
		// wraps it, unlike a direct == comparison.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	msg.ToParties = toParties
	return &msg, nil
}
// GetPendingMessages lists messages in sessionID that partyID should receive
// and that were created strictly after afterTime.
//
// A message qualifies when it was not sent by partyID and is either a
// broadcast (to_parties NULL or empty) or explicitly addressed to partyID.
// Results are ordered by round_number, then created_at.
//
// Note: unlike CountPendingByParty, this query does not look at delivered_at;
// "pending" here means "newer than afterTime".
func (r *MessagePostgresRepo) GetPendingMessages(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	afterTime time.Time,
) ([]*entities.MPCMessage, error) {
	const pendingSQL = `
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
FROM mpc_messages
WHERE session_id = $1
AND created_at > $2
AND from_party != $3
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $3 = ANY(to_parties))
ORDER BY round_number ASC, created_at ASC
`
	rs, queryErr := r.db.QueryContext(ctx, pendingSQL, sessionID, afterTime, partyID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rs.Close()

	return r.scanMessages(rs)
}
// CountPendingByParty returns how many undelivered messages (delivered_at IS
// NULL) are waiting for partyID across every session.
//
// Messages sent by partyID itself are excluded; broadcasts (to_parties NULL
// or empty) and messages addressed to partyID are counted. On error the
// count is 0.
func (r *MessagePostgresRepo) CountPendingByParty(ctx context.Context, partyID string) (int64, error) {
	const countSQL = `
SELECT COUNT(*)
FROM mpc_messages
WHERE delivered_at IS NULL
AND from_party != $1
AND (to_parties IS NULL OR cardinality(to_parties) = 0 OR $1 = ANY(to_parties))
`
	var n int64
	if err := r.db.QueryRowContext(ctx, countSQL, partyID).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// GetMessagesByRound returns every message in sessionID whose round_number
// equals roundNumber, oldest first (ordered by created_at). No sender or
// recipient filtering is applied.
func (r *MessagePostgresRepo) GetMessagesByRound(
	ctx context.Context,
	sessionID uuid.UUID,
	roundNumber int,
) ([]*entities.MPCMessage, error) {
	const roundSQL = `
SELECT id, session_id, from_party, to_parties, round_number, message_type, payload, created_at, delivered_at
FROM mpc_messages
WHERE session_id = $1 AND round_number = $2
ORDER BY created_at ASC
`
	rs, queryErr := r.db.QueryContext(ctx, roundSQL, sessionID, roundNumber)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rs.Close()

	return r.scanMessages(rs)
}
// MarkDelivered stamps delivered_at with the database's NOW() for the message
// identified by messageID.
//
// NOTE(review): the affected-row count is not checked, so an unknown id is
// not reported as an error, and calling this again overwrites an existing
// delivered_at.
func (r *MessagePostgresRepo) MarkDelivered(ctx context.Context, messageID uuid.UUID) error {
	const markSQL = `
UPDATE mpc_messages SET delivered_at = NOW() WHERE id = $1
`
	if _, err := r.db.ExecContext(ctx, markSQL, messageID); err != nil {
		return err
	}
	return nil
}
// DeleteBySession removes every mpc_messages row belonging to sessionID.
// Deleting a session with no messages is not an error.
func (r *MessagePostgresRepo) DeleteBySession(ctx context.Context, sessionID uuid.UUID) error {
	const purgeSessionSQL = `DELETE FROM mpc_messages WHERE session_id = $1`
	if _, err := r.db.ExecContext(ctx, purgeSessionSQL, sessionID); err != nil {
		return err
	}
	return nil
}
// DeleteOlderThan removes messages whose created_at is strictly earlier than
// before, across all sessions, and reports how many rows were deleted.
// If the delete itself fails, the count is 0.
func (r *MessagePostgresRepo) DeleteOlderThan(ctx context.Context, before time.Time) (int64, error) {
	const purgeOldSQL = `DELETE FROM mpc_messages WHERE created_at < $1`
	res, execErr := r.db.ExecContext(ctx, purgeOldSQL, before)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}
// scanMessages converts each remaining row of rows into an MPCMessage, keeping
// row order. Columns must be selected in the order id, session_id,
// from_party, to_parties, round_number, message_type, payload, created_at,
// delivered_at.
//
// The slice is nil when there are no rows. A Scan failure aborts immediately
// with (nil, err). Otherwise the collected messages are returned together with
// rows.Err(). This function does not call rows.Close; callers defer that.
func (r *MessagePostgresRepo) scanMessages(rows *sql.Rows) ([]*entities.MPCMessage, error) {
	var out []*entities.MPCMessage
	for rows.Next() {
		m := new(entities.MPCMessage)
		// Recipients arrive as a PostgreSQL array; decode via pq.Array.
		var recipients []string
		if err := rows.Scan(
			&m.ID, &m.SessionID, &m.FromParty, pq.Array(&recipients),
			&m.RoundNumber, &m.MessageType, &m.Payload, &m.CreatedAt, &m.DeliveredAt,
		); err != nil {
			return nil, err
		}
		m.ToParties = recipients
		out = append(out, m)
	}
	return out, rows.Err()
}
// Compile-time check that *MessagePostgresRepo satisfies
// repositories.MessageRepository; the build breaks if a method signature drifts.
var (
	_ repositories.MessageRepository = (*MessagePostgresRepo)(nil)
)