package rabbitmq

import (
	"context"
	"encoding/json"
	"fmt"
	"sync"

	amqp "github.com/rabbitmq/amqp091-go"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
	"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
	"go.uber.org/zap"
)

const (
	// partyExchange is the direct exchange for point-to-point messages; the
	// routing key is the recipient's party ID.
	partyExchange = "mpc.messages"

	// sessionBroadcastExchange is declared at startup but no method in this
	// file publishes to or consumes from it.
	// NOTE(review): confirm whether anything outside this file still uses it.
	sessionBroadcastExchange = "mpc.session.broadcast"

	// excludePartyHeader carries the sender's party ID on session broadcasts;
	// SubscribeToSessionMessages drops deliveries whose header matches the
	// subscribing party.
	excludePartyHeader = "exclude_party"

	// subscriptionBuffer is the capacity of the channels returned by the
	// Subscribe* methods.
	subscriptionBuffer = 100
)

// MessageBrokerAdapter implements MessageBroker using RabbitMQ.
//
// All broker operations share a single AMQP channel. mu is held around every
// synchronous call (declare, bind, publish, consume, cancel, close) the
// adapter makes on that channel.
type MessageBrokerAdapter struct {
	conn    *amqp.Connection
	channel *amqp.Channel
	mu      sync.Mutex

	// consumerSeq numbers the consumer tags issued on channel; guarded by mu.
	consumerSeq uint64
}

// NewMessageBrokerAdapter opens a channel on conn and declares the shared
// exchanges. If any step fails, the channel is closed again before returning,
// so the caller only has to manage conn itself.
func NewMessageBrokerAdapter(conn *amqp.Connection) (*MessageBrokerAdapter, error) {
	channel, err := conn.Channel()
	if err != nil {
		return nil, fmt.Errorf("failed to create channel: %w", err)
	}

	if err := declareExchanges(channel); err != nil {
		// The adapter is never returned, so nobody else could close this
		// channel. If the broker already closed it because of the failed
		// declare, Close merely reports that, which is irrelevant here.
		_ = channel.Close()
		return nil, err
	}

	return &MessageBrokerAdapter{
		conn:    conn,
		channel: channel,
	}, nil
}

// declareExchanges declares the durable exchanges that are shared by all
// sessions.
func declareExchanges(ch *amqp.Channel) error {
	// Exchange for party messages.
	if err := ch.ExchangeDeclare(
		partyExchange, // name
		"direct",      // type
		true,          // durable
		false,         // auto-deleted
		false,         // internal
		false,         // no-wait
		nil,           // arguments
	); err != nil {
		return fmt.Errorf("failed to declare exchange: %w", err)
	}

	// Exchange for session broadcasts.
	if err := ch.ExchangeDeclare(
		sessionBroadcastExchange, // name
		"fanout",                 // type
		true,                     // durable
		false,                    // auto-deleted
		false,                    // internal
		false,                    // no-wait
		nil,                      // arguments
	); err != nil {
		return fmt.Errorf("failed to declare broadcast exchange: %w", err)
	}
	return nil
}

// ensurePartyQueue declares the durable queue for partyID and binds it to
// partyExchange with partyID as the routing key. It returns the queue name.
// Callers must hold a.mu.
func (a *MessageBrokerAdapter) ensurePartyQueue(partyID string) (string, error) {
	queueName := fmt.Sprintf("mpc.party.%s", partyID)

	if _, err := a.channel.QueueDeclare(
		queueName, // name
		true,      // durable
		false,     // delete when unused
		false,     // exclusive
		false,     // no-wait
		nil,       // arguments
	); err != nil {
		return "", fmt.Errorf("failed to declare queue: %w", err)
	}

	if err := a.channel.QueueBind(
		queueName,     // queue name
		partyID,       // routing key
		partyExchange, // exchange
		false,         // no-wait
		nil,           // arguments
	); err != nil {
		return "", fmt.Errorf("failed to bind queue: %w", err)
	}
	return queueName, nil
}

// declareSessionExchange declares the per-session fanout exchange
// ("mpc.session.<sessionID>") as non-durable and auto-delete, and returns its
// name. Callers must hold a.mu.
func (a *MessageBrokerAdapter) declareSessionExchange(sessionID string) (string, error) {
	exchangeName := fmt.Sprintf("mpc.session.%s", sessionID)

	if err := a.channel.ExchangeDeclare(
		exchangeName, // name
		"fanout",     // type
		false,        // durable (temporary for session)
		true,         // auto-delete when unused
		false,        // internal
		false,        // no-wait
		nil,          // arguments
	); err != nil {
		return "", fmt.Errorf("failed to declare session exchange: %w", err)
	}
	return exchangeName, nil
}

// nextConsumerTag returns a consumer tag that is unique on a.channel. The tag
// is needed later to cancel the consumer. Callers must hold a.mu.
func (a *MessageBrokerAdapter) nextConsumerTag() string {
	a.consumerSeq++
	return fmt.Sprintf("mpc-router-%d", a.consumerSeq)
}

// PublishToParty publishes message to the queue of partyID.
//
// The party's durable queue is declared and bound on every call. Messages
// sent before the party subscribes are therefore queued rather than
// unroutable.
func (a *MessageBrokerAdapter) PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error {
	// Encode first so that an unencodable message costs no broker round-trips.
	body, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	if _, err := a.ensurePartyQueue(partyID); err != nil {
		return err
	}

	if err := a.channel.PublishWithContext(
		ctx,
		partyExchange, // exchange
		partyID,       // routing key
		false,         // mandatory
		false,         // immediate
		amqp.Publishing{
			ContentType:  "application/json",
			DeliveryMode: amqp.Persistent,
			Body:         body,
		},
	); err != nil {
		return fmt.Errorf("failed to publish message: %w", err)
	}

	logger.Debug("published message to party",
		zap.String("party_id", partyID),
		zap.String("message_id", message.ID))
	return nil
}

// PublishToSession broadcasts message on the session's fanout exchange. The
// excludeParty header lets the excluded party's subscriber skip the delivery.
//
// NOTE(review): the publish is not mandatory. RabbitMQ therefore silently
// discards the message if no session queue is bound yet, so parties that
// subscribe after this call never see it.
func (a *MessageBrokerAdapter) PublishToSession(
	ctx context.Context,
	sessionID string,
	excludeParty string,
	message *entities.MessageDTO,
) error {
	body, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	exchangeName, err := a.declareSessionExchange(sessionID)
	if err != nil {
		return err
	}

	if err := a.channel.PublishWithContext(
		ctx,
		exchangeName, // exchange
		"",           // routing key (ignored for fanout)
		false,        // mandatory
		false,        // immediate
		amqp.Publishing{
			ContentType:  "application/json",
			DeliveryMode: amqp.Persistent,
			Body:         body,
			Headers: amqp.Table{
				excludePartyHeader: excludeParty,
			},
		},
	); err != nil {
		return fmt.Errorf("failed to publish broadcast: %w", err)
	}

	logger.Debug("broadcast message to session",
		zap.String("session_id", sessionID),
		zap.String("message_id", message.ID),
		zap.String("exclude_party", excludeParty))
	return nil
}

// SubscribeToPartyMessages consumes the durable queue of partyID. Decoded
// messages are delivered on the returned channel, which is closed when ctx is
// cancelled or the AMQP channel closes. On cancellation, the broker consumer
// is cancelled too, and unforwarded deliveries are requeued for the next
// subscriber.
func (a *MessageBrokerAdapter) SubscribeToPartyMessages(
	ctx context.Context,
	partyID string,
) (<-chan *entities.MessageDTO, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	queueName, err := a.ensurePartyQueue(partyID)
	if err != nil {
		return nil, err
	}

	tag := a.nextConsumerTag()
	msgs, err := a.channel.Consume(
		queueName, // queue
		tag,       // consumer (explicit so it can be cancelled)
		false,     // auto-ack (we'll ack manually)
		false,     // exclusive
		false,     // no-local
		false,     // no-wait
		nil,       // args
	)
	if err != nil {
		return nil, fmt.Errorf("failed to register consumer: %w", err)
	}

	out := make(chan *entities.MessageDTO, subscriptionBuffer)
	go a.forwardDeliveries(ctx, tag, msgs, out, nil)
	return out, nil
}

// SubscribeToSessionMessages binds an exclusive, auto-delete queue named
// "mpc.session.<sessionID>.<partyID>" to the session's fanout exchange and
// consumes it. Broadcasts whose exclude_party header equals partyID are acked
// and dropped. The returned channel closes when ctx is cancelled or the AMQP
// channel closes.
//
// NOTE(review): the consumer is exclusive. A second subscription for the same
// session/party on this channel is refused by the broker, and that channel
// error closes the adapter's only channel.
func (a *MessageBrokerAdapter) SubscribeToSessionMessages(
	ctx context.Context,
	sessionID string,
	partyID string,
) (<-chan *entities.MessageDTO, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	exchangeName, err := a.declareSessionExchange(sessionID)
	if err != nil {
		return nil, err
	}

	// Temporary queue for this subscriber.
	queueName := fmt.Sprintf("mpc.session.%s.%s", sessionID, partyID)
	if _, err := a.channel.QueueDeclare(
		queueName, // name
		false,     // durable
		true,      // delete when unused
		true,      // exclusive
		false,     // no-wait
		nil,       // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to declare queue: %w", err)
	}

	if err := a.channel.QueueBind(
		queueName,    // queue name
		"",           // routing key (ignored for fanout)
		exchangeName, // exchange
		false,        // no-wait
		nil,          // arguments
	); err != nil {
		return nil, fmt.Errorf("failed to bind queue: %w", err)
	}

	tag := a.nextConsumerTag()
	msgs, err := a.channel.Consume(
		queueName, // queue
		tag,       // consumer (explicit so it can be cancelled)
		false,     // auto-ack
		true,      // exclusive
		false,     // no-local
		false,     // no-wait
		nil,       // args
	)
	if err != nil {
		return nil, fmt.Errorf("failed to register consumer: %w", err)
	}

	// Drop broadcasts this party sent itself.
	isOwnMessage := func(msg amqp.Delivery) bool {
		excludeParty, ok := msg.Headers[excludePartyHeader].(string)
		return ok && excludeParty == partyID
	}

	out := make(chan *entities.MessageDTO, subscriptionBuffer)
	go a.forwardDeliveries(ctx, tag, msgs, out, isOwnMessage)
	return out, nil
}

// forwardDeliveries decodes deliveries from msgs onto out until ctx is
// cancelled or msgs is closed, and then closes out.
//
// If skip is non-nil and returns true for a delivery, that delivery is acked
// and dropped. Undecodable deliveries are nacked without requeue. A delivery
// is acked only after it has been handed to out.
//
// Ack/Nack results are ignored on purpose: they fail only when the channel is
// gone, and RabbitMQ then requeues unacked deliveries on its own.
func (a *MessageBrokerAdapter) forwardDeliveries(
	ctx context.Context,
	consumerTag string,
	msgs <-chan amqp.Delivery,
	out chan<- *entities.MessageDTO,
	skip func(amqp.Delivery) bool,
) {
	defer close(out)

	for {
		select {
		case <-ctx.Done():
			a.cancelConsumer(consumerTag, msgs)
			return
		case msg, ok := <-msgs:
			if !ok {
				return
			}

			if skip != nil && skip(msg) {
				msg.Ack(false)
				continue
			}

			var dto entities.MessageDTO
			if err := json.Unmarshal(msg.Body, &dto); err != nil {
				logger.Error("failed to unmarshal message", zap.Error(err))
				msg.Nack(false, false)
				continue
			}

			select {
			case out <- &dto:
				msg.Ack(false)
			case <-ctx.Done():
				msg.Nack(false, true) // Requeue
				a.cancelConsumer(consumerTag, msgs)
				return
			}
		}
	}
}

// cancelConsumer deregisters the broker consumer consumerTag so that RabbitMQ
// stops pushing deliveries to a subscriber nobody reads any more. It then
// requeues whatever was already delivered to this process.
func (a *MessageBrokerAdapter) cancelConsumer(consumerTag string, msgs <-chan amqp.Delivery) {
	// Cancel is a synchronous RPC on the shared channel, so it is serialised
	// with the adapter's other channel calls.
	a.mu.Lock()
	err := a.channel.Cancel(consumerTag, false)
	a.mu.Unlock()

	if err != nil {
		// Typically the channel is already closed. msgs is then closed by the
		// client library, and the broker requeues unacked deliveries itself.
		logger.Debug("failed to cancel consumer",
			zap.String("consumer_tag", consumerTag),
			zap.Error(err))
		return
	}

	// After a successful cancel, msgs yields the deliveries that were already
	// in flight and is then closed. Hand each one back to the queue.
	for msg := range msgs {
		msg.Nack(false, true)
	}
}

// Close closes the adapter's AMQP channel, which also closes the delivery
// streams of all subscriptions made through it. The *amqp.Connection passed to
// NewMessageBrokerAdapter is NOT closed; it stays the caller's to close.
func (a *MessageBrokerAdapter) Close() error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.channel == nil {
		return nil
	}
	return a.channel.Close()
}

// Ensure interface compliance
var _ use_cases.MessageBroker = (*MessageBrokerAdapter)(nil)