// Source: rwadurian/backend/mpc-system/services/message-router/domain/session_event_broadcaster.go
// (131 lines, 4.1 KiB, Go)

package domain
import (
"sync"
"time"
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
"github.com/rwadurian/mpc-system/pkg/logger"
"go.uber.org/zap"
)
// SessionEventBroadcaster fans session events out to subscribed parties.
//
// Each party ID maps to at most one live subscription channel. Build instances
// with NewSessionEventBroadcaster: the zero value has a nil subscriber map,
// so Subscribe would panic writing to it.
type SessionEventBroadcaster struct {
	// mu guards subscribers. Senders take the read lock; Subscribe and
	// Unsubscribe take the write lock.
	mu sync.RWMutex

	// subscribers maps a party ID to that party's buffered event channel.
	subscribers map[string]chan *pb.SessionEvent
}
// NewSessionEventBroadcaster returns a broadcaster with an empty, ready-to-use
// subscriber table and no active subscriptions.
func NewSessionEventBroadcaster() *SessionEventBroadcaster {
	b := new(SessionEventBroadcaster)
	b.subscribers = map[string]chan *pb.SessionEvent{}
	return b
}
// subscriberBufferSize is the capacity of each subscriber's event channel.
// Sends never block, so once this many events are queued for a party further
// events to it are dropped until it drains the channel.
const subscriberBufferSize = 100

// Process-wide state for nextSubscriptionID. It is shared by every
// broadcaster instance, so it is guarded by its own mutex rather than by a
// broadcaster's mu.
var (
	subscriptionIDMu   sync.Mutex
	lastSubscriptionID int64
)

// nextSubscriptionID returns an ID strictly greater than every ID it has
// returned before in this process. IDs are seeded from the wall clock
// (UnixNano), so they keep the timestamp-like shape of the previous scheme.
// Two calls in the same clock tick no longer yield the same value.
func nextSubscriptionID() int64 {
	subscriptionIDMu.Lock()
	defer subscriptionIDMu.Unlock()

	id := time.Now().UnixNano()
	if id <= lastSubscriptionID {
		// Clock did not advance (coarse resolution) or went backwards:
		// bump past the last issued ID to keep IDs unique and increasing.
		id = lastSubscriptionID + 1
	}
	lastSubscriptionID = id
	return id
}

// Subscribe registers partyID for session events.
//
// It returns a receive-only channel with a buffer of subscriberBufferSize
// events, plus a subscription ID that is unique within the process.
// Unsubscribe identifies a subscription by the returned channel, not by this
// ID.
//
// If partyID already has a subscription (for example after a reconnect), the
// previous channel is closed and replaced. Its receiver observes the close
// (v, ok := <-ch yields ok == false) and can exit.
func (b *SessionEventBroadcaster) Subscribe(partyID string) (<-chan *pb.SessionEvent, int64) {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Closing here is safe: all sends in this type happen under b.mu.RLock,
	// so no sender can be writing to oldCh while we hold the write lock.
	if oldCh, exists := b.subscribers[partyID]; exists {
		close(oldCh)
		logger.Debug("Closed old subscription channel for re-subscribing party",
			zap.String("party_id", partyID))
	}

	ch := make(chan *pb.SessionEvent, subscriberBufferSize)
	b.subscribers[partyID] = ch
	return ch, nextSubscriptionID()
}
// Unsubscribe removes partyID's subscription, but only if ch is still the
// channel currently registered for that party.
//
// The channel check matters when a party re-subscribes. Its old stream may
// call Unsubscribe after Subscribe has already installed a newer channel, and
// that newer subscription must survive. Unknown party IDs are ignored.
func (b *SessionEventBroadcaster) Unsubscribe(partyID string, ch <-chan *pb.SessionEvent) {
	b.mu.Lock()
	defer b.mu.Unlock()

	current, found := b.subscribers[partyID]
	if !found {
		return
	}

	if current != ch {
		logger.Debug("Skipping unsubscribe - channel mismatch (newer subscription exists)",
			zap.String("party_id", partyID))
		return
	}

	// The channel is intentionally left open. Once it is out of the map, no
	// code in this type can send on it again, so closing it is not needed to
	// stop senders.
	delete(b.subscribers, partyID)
	logger.Debug("Unsubscribed party from session events",
		zap.String("party_id", partyID))
}
// Broadcast delivers event to every current subscriber.
//
// Each send is non-blocking. A subscriber whose channel buffer is full misses
// the event, so one slow consumer cannot stall delivery to the others. Misses
// are reported in a single warning rather than dropped silently. The warning
// is emitted after the read lock is released, so logging never delays
// Subscribe or Unsubscribe.
func (b *SessionEventBroadcaster) Broadcast(event *pb.SessionEvent) {
	sentCount, missedParties := b.sendToAll(event)
	if len(missedParties) == 0 {
		return
	}

	// Getters are nil-safe, so a nil event cannot panic while logging.
	logger.Warn("Some parties missed session event broadcast",
		zap.String("event_type", event.GetEventType()),
		zap.String("session_id", event.GetSessionId()),
		zap.Int("sent_count", sentCount),
		zap.Int("missed_count", len(missedParties)),
		zap.Strings("missed_parties", missedParties))
}

// sendToAll performs Broadcast's non-blocking sends while holding the read
// lock. It returns the number of successful sends and an annotated entry for
// each subscriber whose buffer was full.
func (b *SessionEventBroadcaster) sendToAll(event *pb.SessionEvent) (sent int, missed []string) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	for partyID, ch := range b.subscribers {
		select {
		case ch <- event:
			sent++
		default:
			missed = append(missed, partyID+" (channel full)")
		}
	}
	return sent, missed
}
// BroadcastToParties delivers event only to the parties listed in partyIDs.
//
// Each send is non-blocking. A listed party misses the event if it has no
// active subscription or if its channel buffer is full. All misses are
// reported in one warning, so delivery gaps can be diagnosed (for example,
// a session_started event that reached only some participants). The warning
// is emitted after the read lock is released, so log I/O does not hold up
// Subscribe/Unsubscribe, which need the write lock.
func (b *SessionEventBroadcaster) BroadcastToParties(event *pb.SessionEvent, partyIDs []string) {
	sentCount, missedParties := b.sendToParties(event, partyIDs)
	if len(missedParties) == 0 {
		return
	}

	// Getters are nil-safe, so a nil event cannot panic while logging.
	logger.Warn("Some parties missed session event broadcast",
		zap.String("event_type", event.GetEventType()),
		zap.String("session_id", event.GetSessionId()),
		zap.Int("sent_count", sentCount),
		zap.Int("missed_count", len(missedParties)),
		zap.Strings("missed_parties", missedParties))
}

// sendToParties performs BroadcastToParties' non-blocking sends while holding
// the read lock. It returns the number of successful sends and an annotated
// entry ("<id> (not subscribed)" or "<id> (channel full)") for each missed
// party.
func (b *SessionEventBroadcaster) sendToParties(event *pb.SessionEvent, partyIDs []string) (sent int, missed []string) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	for _, partyID := range partyIDs {
		ch, ok := b.subscribers[partyID]
		if !ok {
			missed = append(missed, partyID+" (not subscribed)")
			continue
		}
		select {
		case ch <- event:
			sent++
		default:
			missed = append(missed, partyID+" (channel full)")
		}
	}
	return sent, missed
}
// SubscriberCount reports how many parties currently hold a subscription.
// The value is a snapshot and may be stale as soon as it is returned.
func (b *SessionEventBroadcaster) SubscriberCount() int {
	b.mu.RLock()
	count := len(b.subscribers)
	b.mu.RUnlock()
	return count
}