389 lines
9.5 KiB
Go
389 lines
9.5 KiB
Go
package rabbitmq
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
|
|
amqp "github.com/rabbitmq/amqp091-go"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
|
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// MessageBrokerAdapter implements MessageBroker using RabbitMQ
|
|
type MessageBrokerAdapter struct {
|
|
conn *amqp.Connection
|
|
channel *amqp.Channel
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewMessageBrokerAdapter creates a new RabbitMQ message broker
|
|
func NewMessageBrokerAdapter(conn *amqp.Connection) (*MessageBrokerAdapter, error) {
|
|
channel, err := conn.Channel()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create channel: %w", err)
|
|
}
|
|
|
|
// Declare exchange for party messages
|
|
err = channel.ExchangeDeclare(
|
|
"mpc.messages", // name
|
|
"direct", // type
|
|
true, // durable
|
|
false, // auto-deleted
|
|
false, // internal
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to declare exchange: %w", err)
|
|
}
|
|
|
|
// Declare exchange for session broadcasts
|
|
err = channel.ExchangeDeclare(
|
|
"mpc.session.broadcast", // name
|
|
"fanout", // type
|
|
true, // durable
|
|
false, // auto-deleted
|
|
false, // internal
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to declare broadcast exchange: %w", err)
|
|
}
|
|
|
|
return &MessageBrokerAdapter{
|
|
conn: conn,
|
|
channel: channel,
|
|
}, nil
|
|
}
|
|
|
|
// PublishToParty publishes a message to a specific party
|
|
func (a *MessageBrokerAdapter) PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
// Ensure queue exists for the party
|
|
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
|
_, err := a.channel.QueueDeclare(
|
|
queueName, // name
|
|
true, // durable
|
|
false, // delete when unused
|
|
false, // exclusive
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to declare queue: %w", err)
|
|
}
|
|
|
|
// Bind queue to exchange
|
|
err = a.channel.QueueBind(
|
|
queueName, // queue name
|
|
partyID, // routing key
|
|
"mpc.messages", // exchange
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to bind queue: %w", err)
|
|
}
|
|
|
|
body, err := json.Marshal(message)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal message: %w", err)
|
|
}
|
|
|
|
err = a.channel.PublishWithContext(
|
|
ctx,
|
|
"mpc.messages", // exchange
|
|
partyID, // routing key
|
|
false, // mandatory
|
|
false, // immediate
|
|
amqp.Publishing{
|
|
ContentType: "application/json",
|
|
DeliveryMode: amqp.Persistent,
|
|
Body: body,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to publish message: %w", err)
|
|
}
|
|
|
|
logger.Debug("published message to party",
|
|
zap.String("party_id", partyID),
|
|
zap.String("message_id", message.ID))
|
|
|
|
return nil
|
|
}
|
|
|
|
// PublishToSession publishes a message to all parties in a session (except sender)
|
|
func (a *MessageBrokerAdapter) PublishToSession(
|
|
ctx context.Context,
|
|
sessionID string,
|
|
excludeParty string,
|
|
message *entities.MessageDTO,
|
|
) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
// Use session-specific exchange
|
|
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
|
|
|
// Declare session-specific fanout exchange
|
|
err := a.channel.ExchangeDeclare(
|
|
exchangeName, // name
|
|
"fanout", // type
|
|
false, // durable (temporary for session)
|
|
true, // auto-delete when unused
|
|
false, // internal
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to declare session exchange: %w", err)
|
|
}
|
|
|
|
body, err := json.Marshal(message)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal message: %w", err)
|
|
}
|
|
|
|
err = a.channel.PublishWithContext(
|
|
ctx,
|
|
exchangeName, // exchange
|
|
"", // routing key (ignored for fanout)
|
|
false, // mandatory
|
|
false, // immediate
|
|
amqp.Publishing{
|
|
ContentType: "application/json",
|
|
DeliveryMode: amqp.Persistent,
|
|
Body: body,
|
|
Headers: amqp.Table{
|
|
"exclude_party": excludeParty,
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to publish broadcast: %w", err)
|
|
}
|
|
|
|
logger.Debug("broadcast message to session",
|
|
zap.String("session_id", sessionID),
|
|
zap.String("message_id", message.ID),
|
|
zap.String("exclude_party", excludeParty))
|
|
|
|
return nil
|
|
}
|
|
|
|
// SubscribeToPartyMessages subscribes to messages for a specific party
|
|
func (a *MessageBrokerAdapter) SubscribeToPartyMessages(
|
|
ctx context.Context,
|
|
partyID string,
|
|
) (<-chan *entities.MessageDTO, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
queueName := fmt.Sprintf("mpc.party.%s", partyID)
|
|
|
|
// Ensure queue exists
|
|
_, err := a.channel.QueueDeclare(
|
|
queueName, // name
|
|
true, // durable
|
|
false, // delete when unused
|
|
false, // exclusive
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
|
}
|
|
|
|
// Bind queue to exchange
|
|
err = a.channel.QueueBind(
|
|
queueName, // queue name
|
|
partyID, // routing key
|
|
"mpc.messages", // exchange
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
|
}
|
|
|
|
// Start consuming
|
|
msgs, err := a.channel.Consume(
|
|
queueName, // queue
|
|
"", // consumer
|
|
false, // auto-ack (we'll ack manually)
|
|
false, // exclusive
|
|
false, // no-local
|
|
false, // no-wait
|
|
nil, // args
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
|
}
|
|
|
|
// Create output channel
|
|
out := make(chan *entities.MessageDTO, 100)
|
|
|
|
// Start goroutine to forward messages
|
|
go func() {
|
|
defer close(out)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-msgs:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var dto entities.MessageDTO
|
|
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
|
logger.Error("failed to unmarshal message", zap.Error(err))
|
|
msg.Nack(false, false)
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case out <- &dto:
|
|
msg.Ack(false)
|
|
case <-ctx.Done():
|
|
msg.Nack(false, true) // Requeue
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// SubscribeToSessionMessages subscribes to all messages in a session
|
|
func (a *MessageBrokerAdapter) SubscribeToSessionMessages(
|
|
ctx context.Context,
|
|
sessionID string,
|
|
partyID string,
|
|
) (<-chan *entities.MessageDTO, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)
|
|
queueName := fmt.Sprintf("mpc.session.%s.%s", sessionID, partyID)
|
|
|
|
// Declare session-specific fanout exchange
|
|
err := a.channel.ExchangeDeclare(
|
|
exchangeName, // name
|
|
"fanout", // type
|
|
false, // durable
|
|
true, // auto-delete
|
|
false, // internal
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to declare session exchange: %w", err)
|
|
}
|
|
|
|
// Declare temporary queue for this subscriber
|
|
_, err = a.channel.QueueDeclare(
|
|
queueName, // name
|
|
false, // durable
|
|
true, // delete when unused
|
|
true, // exclusive
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to declare queue: %w", err)
|
|
}
|
|
|
|
// Bind queue to session exchange
|
|
err = a.channel.QueueBind(
|
|
queueName, // queue name
|
|
"", // routing key (ignored for fanout)
|
|
exchangeName, // exchange
|
|
false, // no-wait
|
|
nil, // arguments
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to bind queue: %w", err)
|
|
}
|
|
|
|
// Start consuming
|
|
msgs, err := a.channel.Consume(
|
|
queueName, // queue
|
|
"", // consumer
|
|
false, // auto-ack
|
|
true, // exclusive
|
|
false, // no-local
|
|
false, // no-wait
|
|
nil, // args
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to register consumer: %w", err)
|
|
}
|
|
|
|
// Create output channel
|
|
out := make(chan *entities.MessageDTO, 100)
|
|
|
|
// Start goroutine to forward messages
|
|
go func() {
|
|
defer close(out)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-msgs:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
// Check if this message should be excluded for this party
|
|
if excludeParty, ok := msg.Headers["exclude_party"].(string); ok {
|
|
if excludeParty == partyID {
|
|
msg.Ack(false)
|
|
continue
|
|
}
|
|
}
|
|
|
|
var dto entities.MessageDTO
|
|
if err := json.Unmarshal(msg.Body, &dto); err != nil {
|
|
logger.Error("failed to unmarshal message", zap.Error(err))
|
|
msg.Nack(false, false)
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case out <- &dto:
|
|
msg.Ack(false)
|
|
case <-ctx.Done():
|
|
msg.Nack(false, true)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return out, nil
|
|
}
|
|
|
|
// Close closes the connection
|
|
func (a *MessageBrokerAdapter) Close() error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.channel != nil {
|
|
return a.channel.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Ensure interface compliance
|
|
var _ use_cases.MessageBroker = (*MessageBrokerAdapter)(nil)
|