131 lines
4.1 KiB
Go
131 lines
4.1 KiB
Go
package domain
|
|
|
|
import (
|
|
"sync"
|
|
"time"
|
|
|
|
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// SessionEventBroadcaster manages session event subscriptions and broadcasting
|
|
type SessionEventBroadcaster struct {
|
|
subscribers map[string]chan *pb.SessionEvent // partyID -> event channel
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewSessionEventBroadcaster creates a new session event broadcaster
|
|
func NewSessionEventBroadcaster() *SessionEventBroadcaster {
|
|
return &SessionEventBroadcaster{
|
|
subscribers: make(map[string]chan *pb.SessionEvent),
|
|
}
|
|
}
|
|
|
|
// Subscribe subscribes a party to session events
|
|
// Returns the channel for receiving events and a unique subscription ID
|
|
// The subscription ID is used to safely unsubscribe without affecting newer subscriptions
|
|
func (b *SessionEventBroadcaster) Subscribe(partyID string) (<-chan *pb.SessionEvent, int64) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
// Close existing channel if party is re-subscribing (e.g., after reconnect)
|
|
// This will cause the old gRPC stream to exit cleanly
|
|
if oldCh, exists := b.subscribers[partyID]; exists {
|
|
close(oldCh)
|
|
logger.Debug("Closed old subscription channel for re-subscribing party",
|
|
zap.String("party_id", partyID))
|
|
}
|
|
|
|
// Create buffered channel for this subscriber
|
|
ch := make(chan *pb.SessionEvent, 100)
|
|
b.subscribers[partyID] = ch
|
|
|
|
// Generate a unique subscription ID (using current time in nanoseconds)
|
|
subscriptionID := time.Now().UnixNano()
|
|
|
|
return ch, subscriptionID
|
|
}
|
|
|
|
// Unsubscribe removes a party's subscription only if the channel matches
|
|
// This prevents a race condition where a newer subscription is accidentally removed
|
|
// when an old gRPC stream exits after the party has already re-subscribed
|
|
func (b *SessionEventBroadcaster) Unsubscribe(partyID string, ch <-chan *pb.SessionEvent) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if currentCh, exists := b.subscribers[partyID]; exists {
|
|
// Only delete if the channel matches (i.e., this is still our subscription)
|
|
// If the channel doesn't match, a newer subscription has been created
|
|
// and we should not delete it
|
|
if currentCh == ch {
|
|
// Don't close the channel here - it was already closed by Subscribe
|
|
// when the new subscription was created, or we're the last one
|
|
delete(b.subscribers, partyID)
|
|
logger.Debug("Unsubscribed party from session events",
|
|
zap.String("party_id", partyID))
|
|
} else {
|
|
logger.Debug("Skipping unsubscribe - channel mismatch (newer subscription exists)",
|
|
zap.String("party_id", partyID))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Broadcast sends an event to all subscribers
|
|
func (b *SessionEventBroadcaster) Broadcast(event *pb.SessionEvent) {
|
|
b.mu.RLock()
|
|
defer b.mu.RUnlock()
|
|
|
|
for _, ch := range b.subscribers {
|
|
// Non-blocking send to prevent slow subscribers from blocking
|
|
select {
|
|
case ch <- event:
|
|
default:
|
|
// Channel full, skip this subscriber
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastToParties sends an event to specific parties only
|
|
func (b *SessionEventBroadcaster) BroadcastToParties(event *pb.SessionEvent, partyIDs []string) {
|
|
b.mu.RLock()
|
|
defer b.mu.RUnlock()
|
|
|
|
sentCount := 0
|
|
missedParties := []string{}
|
|
|
|
for _, partyID := range partyIDs {
|
|
if ch, exists := b.subscribers[partyID]; exists {
|
|
// Non-blocking send
|
|
select {
|
|
case ch <- event:
|
|
sentCount++
|
|
default:
|
|
// Channel full, skip this subscriber
|
|
missedParties = append(missedParties, partyID+" (channel full)")
|
|
}
|
|
} else {
|
|
// Party not subscribed - this is a problem for session_started events!
|
|
missedParties = append(missedParties, partyID+" (not subscribed)")
|
|
}
|
|
}
|
|
|
|
// Log if any parties were missed (helps debug event delivery issues)
|
|
if len(missedParties) > 0 {
|
|
logger.Warn("Some parties missed session event broadcast",
|
|
zap.String("event_type", event.EventType),
|
|
zap.String("session_id", event.SessionId),
|
|
zap.Int("sent_count", sentCount),
|
|
zap.Int("missed_count", len(missedParties)),
|
|
zap.Strings("missed_parties", missedParties))
|
|
}
|
|
}
|
|
|
|
// SubscriberCount returns the number of active subscribers
|
|
func (b *SessionEventBroadcaster) SubscriberCount() int {
|
|
b.mu.RLock()
|
|
defer b.mu.RUnlock()
|
|
|
|
return len(b.subscribers)
|
|
}
|