rwadurian/backend/mpc-system/services/message-router/adapters/output/rabbitmq/message_broker.go

389 lines
9.5 KiB
Go

package rabbitmq
import (
"context"
"encoding/json"
"fmt"
"sync"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
"go.uber.org/zap"
)
// MessageBrokerAdapter is the RabbitMQ-backed implementation of the
// use_cases.MessageBroker port. Every broker operation goes through the single
// AMQP channel opened by NewMessageBrokerAdapter; the exported methods hold mu
// while they drive that channel.
type MessageBrokerAdapter struct {
	conn    *amqp.Connection // connection handed in by the caller
	channel *amqp.Channel    // channel opened on conn by the constructor
	mu      sync.Mutex       // serializes the exported methods' channel use
}
// NewMessageBrokerAdapter opens a dedicated AMQP channel on conn and declares
// the durable exchanges used by the router:
//   - "mpc.messages" (direct), for messages addressed to a single party;
//   - "mpc.session.broadcast" (fanout).
//
// If any declaration fails, the channel is closed before the error is
// returned, so a failed construction does not leak a channel on the caller's
// connection. The connection itself stays owned by the caller.
func NewMessageBrokerAdapter(conn *amqp.Connection) (*MessageBrokerAdapter, error) {
	if conn == nil {
		return nil, fmt.Errorf("failed to create channel: connection is nil")
	}

	channel, err := conn.Channel()
	if err != nil {
		return nil, fmt.Errorf("failed to create channel: %w", err)
	}

	exchanges := []struct {
		name    string
		kind    string
		errText string
	}{
		{name: "mpc.messages", kind: "direct", errText: "failed to declare exchange"},
		{name: "mpc.session.broadcast", kind: "fanout", errText: "failed to declare broadcast exchange"},
	}
	for _, ex := range exchanges {
		if err := channel.ExchangeDeclare(
			ex.name, // name
			ex.kind, // type
			true,    // durable
			false,   // auto-deleted
			false,   // internal
			false,   // no-wait
			nil,     // arguments
		); err != nil {
			// Release the channel instead of leaking it on the caller's
			// connection. The broker may already have closed it in response
			// to the failed declare, so the Close error is not interesting.
			_ = channel.Close()
			return nil, fmt.Errorf("%s: %w", ex.errText, err)
		}
	}

	return &MessageBrokerAdapter{
		conn:    conn,
		channel: channel,
	}, nil
}
// PublishToParty publishes message to the queue of a single party.
//
// Before publishing, the durable queue "mpc.party.<partyID>" is declared and
// bound to the "mpc.messages" direct exchange with partyID as the routing key.
// The message is therefore queued even if the party is not currently consuming.
// The payload is sent as persistent JSON. A nil message is rejected with an
// error.
func (a *MessageBrokerAdapter) PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error {
	if message == nil {
		return fmt.Errorf("cannot publish nil message to party %q", partyID)
	}

	// Serialize before taking the lock. Encoding needs no broker state, and a
	// message that cannot be encoded should not trigger any queue declarations.
	body, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	// Ensure queue exists for the party
	queueName := fmt.Sprintf("mpc.party.%s", partyID)
	if _, err := a.channel.QueueDeclare(
		queueName, // name
		true,      // durable
		false,     // delete when unused
		false,     // exclusive
		false,     // no-wait
		nil,       // arguments
	); err != nil {
		return fmt.Errorf("failed to declare queue: %w", err)
	}

	// Bind queue to exchange
	if err := a.channel.QueueBind(
		queueName,      // queue name
		partyID,        // routing key
		"mpc.messages", // exchange
		false,          // no-wait
		nil,            // arguments
	); err != nil {
		return fmt.Errorf("failed to bind queue: %w", err)
	}

	err = a.channel.PublishWithContext(
		ctx,
		"mpc.messages", // exchange
		partyID,        // routing key
		false,          // mandatory
		false,          // immediate
		amqp.Publishing{
			ContentType:  "application/json",
			DeliveryMode: amqp.Persistent,
			Body:         body,
		},
	)
	if err != nil {
		return fmt.Errorf("failed to publish message: %w", err)
	}

	logger.Debug("published message to party",
		zap.String("party_id", partyID),
		zap.String("message_id", message.ID))
	return nil
}
// PublishToSession broadcasts message to every party subscribed to the
// session. It tags the message with an "exclude_party" header carrying
// excludeParty, so that party's own session subscription can drop it.
//
// The per-session fanout exchange "mpc.session.<sessionID>" is transient and
// auto-deleted. A fanout exchange only reaches queues that are bound at
// publish time, and this publish is not mandatory. As a result, a broadcast
// sent before any party has subscribed to the session is discarded by the
// broker without an error. A nil message is rejected with an error.
func (a *MessageBrokerAdapter) PublishToSession(
	ctx context.Context,
	sessionID string,
	excludeParty string,
	message *entities.MessageDTO,
) error {
	if message == nil {
		return fmt.Errorf("cannot broadcast nil message to session %q", sessionID)
	}

	// Serialize before taking the lock; encoding needs no broker state.
	body, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)

	// These arguments must match the declaration in SubscribeToSessionMessages.
	// Redeclaring an exchange with different arguments makes the broker close
	// the channel.
	if err := a.channel.ExchangeDeclare(
		exchangeName, // name
		"fanout",     // type
		false,        // durable (temporary for session)
		true,         // auto-delete when unused
		false,        // internal
		false,        // no-wait
		nil,          // arguments
	); err != nil {
		return fmt.Errorf("failed to declare session exchange: %w", err)
	}

	err = a.channel.PublishWithContext(
		ctx,
		exchangeName, // exchange
		"",           // routing key (ignored for fanout)
		false,        // mandatory
		false,        // immediate
		amqp.Publishing{
			ContentType:  "application/json",
			DeliveryMode: amqp.Persistent,
			Body:         body,
			Headers: amqp.Table{
				"exclude_party": excludeParty,
			},
		},
	)
	if err != nil {
		return fmt.Errorf("failed to publish broadcast: %w", err)
	}

	logger.Debug("broadcast message to session",
		zap.String("session_id", sessionID),
		zap.String("message_id", message.ID),
		zap.String("exclude_party", excludeParty))
	return nil
}
// SubscribeToPartyMessages subscribes to messages addressed to a single party.
//
// It ensures the durable queue "mpc.party.<partyID>" exists and is bound to the
// "mpc.messages" exchange with partyID as the routing key, then starts a
// consumer on it. Decoded messages are delivered on the returned channel. That
// channel is closed when ctx is cancelled (after the consumer has been
// cancelled on the broker) or when the AMQP delivery stream ends.
func (a *MessageBrokerAdapter) SubscribeToPartyMessages(
	ctx context.Context,
	partyID string,
) (<-chan *entities.MessageDTO, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	queueName := fmt.Sprintf("mpc.party.%s", partyID)

	// Ensure queue exists
	if _, err := a.channel.QueueDeclare(
		queueName, // name
		true,      // durable
		false,     // delete when unused
		false,     // exclusive
		false,     // no-wait
		nil,       // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to declare queue: %w", err)
	}

	// Bind queue to exchange
	if err := a.channel.QueueBind(
		queueName,      // queue name
		partyID,        // routing key
		"mpc.messages", // exchange
		false,          // no-wait
		nil,            // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to bind queue: %w", err)
	}

	// An explicit tag is needed so the consumer can be cancelled when ctx is
	// done. Without cancellation it stays registered on the shared channel and
	// keeps receiving messages that nobody reads or acknowledges.
	consumerTag := nextConsumerTag(queueName)
	msgs, err := a.channel.Consume(
		queueName,   // queue
		consumerTag, // consumer
		false,       // auto-ack (acked manually after hand-off)
		false,       // exclusive
		false,       // no-local
		false,       // no-wait
		nil,         // args
	)
	if err != nil {
		return nil, fmt.Errorf("failed to register consumer: %w", err)
	}

	out := make(chan *entities.MessageDTO, 100)
	go a.forwardDeliveries(ctx, consumerTag, msgs, out, nil)
	return out, nil
}

// SubscribeToSessionMessages subscribes partyID to all broadcasts of a session.
//
// It declares the transient fanout exchange "mpc.session.<sessionID>" and an
// exclusive, auto-delete queue "mpc.session.<sessionID>.<partyID>" bound to it,
// then consumes that queue exclusively. Messages whose "exclude_party" header
// equals partyID (the party's own broadcasts) are acknowledged and dropped.
// The returned channel is closed when ctx is cancelled (after the consumer has
// been cancelled on the broker) or when the AMQP delivery stream ends.
func (a *MessageBrokerAdapter) SubscribeToSessionMessages(
	ctx context.Context,
	sessionID string,
	partyID string,
) (<-chan *entities.MessageDTO, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
	queueName := fmt.Sprintf("mpc.session.%s.%s", sessionID, partyID)

	// Declare session-specific fanout exchange; must match PublishToSession.
	if err := a.channel.ExchangeDeclare(
		exchangeName, // name
		"fanout",     // type
		false,        // durable
		true,         // auto-delete
		false,        // internal
		false,        // no-wait
		nil,          // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to declare session exchange: %w", err)
	}

	// Declare temporary queue for this subscriber
	if _, err := a.channel.QueueDeclare(
		queueName, // name
		false,     // durable
		true,      // delete when unused
		true,      // exclusive
		false,     // no-wait
		nil,       // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to declare queue: %w", err)
	}

	// Bind queue to session exchange
	if err := a.channel.QueueBind(
		queueName,    // queue name
		"",           // routing key (ignored for fanout)
		exchangeName, // exchange
		false,        // no-wait
		nil,          // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to bind queue: %w", err)
	}

	// The consumer is exclusive. If it were left registered after ctx ends, a
	// later subscription for the same session and party would be refused by
	// the broker, and that refusal closes the whole shared channel.
	consumerTag := nextConsumerTag(queueName)
	msgs, err := a.channel.Consume(
		queueName,   // queue
		consumerTag, // consumer
		false,       // auto-ack
		true,        // exclusive
		false,       // no-local
		false,       // no-wait
		nil,         // args
	)
	if err != nil {
		return nil, fmt.Errorf("failed to register consumer: %w", err)
	}

	// Drop the party's own broadcasts.
	skipOwn := func(msg amqp.Delivery) bool {
		excludeParty, ok := msg.Headers["exclude_party"].(string)
		return ok && excludeParty == partyID
	}

	out := make(chan *entities.MessageDTO, 100)
	go a.forwardDeliveries(ctx, consumerTag, msgs, out, skipOwn)
	return out, nil
}

// consumerSeqMu guards consumerSeq, the counter that makes consumer tags
// unique across every adapter in the process.
var (
	consumerSeqMu sync.Mutex
	consumerSeq   uint64
)

// nextConsumerTag returns a process-unique AMQP consumer tag derived from queueName.
func nextConsumerTag(queueName string) string {
	consumerSeqMu.Lock()
	defer consumerSeqMu.Unlock()
	consumerSeq++
	return fmt.Sprintf("%s.consumer-%d", queueName, consumerSeq)
}

// forwardDeliveries decodes deliveries from msgs into MessageDTOs and sends them
// on out until ctx is cancelled or msgs is closed. out is closed on return.
//
// Delivery handling:
//   - If skip is non-nil and returns true for a delivery, that delivery is
//     acknowledged without being forwarded.
//   - A delivery is acknowledged only after it has been handed to out.
//   - A payload that fails to decode is rejected without requeue.
//
// On cancellation the consumer is stopped via stopConsumer.
func (a *MessageBrokerAdapter) forwardDeliveries(
	ctx context.Context,
	consumerTag string,
	msgs <-chan amqp.Delivery,
	out chan<- *entities.MessageDTO,
	skip func(amqp.Delivery) bool,
) {
	defer close(out)
	for {
		select {
		case <-ctx.Done():
			a.stopConsumer(consumerTag, msgs)
			return
		case msg, ok := <-msgs:
			if !ok {
				return
			}
			if skip != nil && skip(msg) {
				logSettleError(msg.Ack(false), consumerTag)
				continue
			}
			var dto entities.MessageDTO
			if err := json.Unmarshal(msg.Body, &dto); err != nil {
				logger.Error("failed to unmarshal message", zap.Error(err))
				logSettleError(msg.Nack(false, false), consumerTag)
				continue
			}
			select {
			case out <- &dto:
				logSettleError(msg.Ack(false), consumerTag)
			case <-ctx.Done():
				logSettleError(msg.Nack(false, true), consumerTag) // Requeue
				a.stopConsumer(consumerTag, msgs)
				return
			}
		}
	}
}

// stopConsumer cancels the consumer identified by consumerTag. It then nacks,
// with requeue, every delivery the client library had already buffered for
// that consumer, so those messages go back to the queue. amqp091-go closes
// msgs once the cancellation is confirmed, which ends the drain loop.
//
// If Cancel fails, the channel is most likely already closed. In that case
// msgs is closed by the library and the broker requeues the unacked
// deliveries itself, so no drain is attempted.
func (a *MessageBrokerAdapter) stopConsumer(consumerTag string, msgs <-chan amqp.Delivery) {
	// a.mu is deliberately not taken. Cancel waits for the broker's reply and
	// must not stall concurrent publishers.
	if err := a.channel.Cancel(consumerTag, false); err != nil {
		logger.Debug("failed to cancel consumer",
			zap.String("consumer_tag", consumerTag),
			zap.Error(err))
		return
	}
	for msg := range msgs {
		logSettleError(msg.Nack(false, true), consumerTag)
	}
}

// logSettleError logs a failed Ack/Nack instead of silently dropping the error.
// Settlement typically fails only when the channel is no longer usable.
func logSettleError(err error, consumerTag string) {
	if err != nil {
		logger.Error("failed to settle delivery",
			zap.String("consumer_tag", consumerTag),
			zap.Error(err))
	}
}
// Close closes the AMQP channel this adapter opened in its constructor. The
// underlying *amqp.Connection was supplied by the caller and is left open; the
// caller remains responsible for closing it.
func (a *MessageBrokerAdapter) Close() error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.channel == nil {
		return nil
	}
	return a.channel.Close()
}
// Compile-time check: *MessageBrokerAdapter must satisfy the
// use_cases.MessageBroker port.
var _ use_cases.MessageBroker = &MessageBrokerAdapter{}