package use_cases import ( "context" "errors" "time" "github.com/google/uuid" "github.com/rwadurian/mpc-system/pkg/logger" "github.com/rwadurian/mpc-system/services/message-router/domain/entities" "github.com/rwadurian/mpc-system/services/message-router/domain/repositories" "go.uber.org/zap" ) var ( ErrInvalidSessionID = errors.New("invalid session ID") ErrInvalidPartyID = errors.New("invalid party ID") ErrEmptyPayload = errors.New("empty payload") ) // RouteMessageInput contains input for routing a message type RouteMessageInput struct { SessionID string FromParty string ToParties []string // nil/empty means broadcast RoundNumber int MessageType string Payload []byte } // RouteMessageOutput contains output from routing a message type RouteMessageOutput struct { MessageID string Success bool } // MessageBroker defines the interface for message delivery type MessageBroker interface { // PublishToParty publishes a message to a specific party PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error // PublishToSession publishes a message to all parties in a session (except sender) PublishToSession(ctx context.Context, sessionID string, excludeParty string, message *entities.MessageDTO) error } // RouteMessageUseCase handles message routing type RouteMessageUseCase struct { messageRepo repositories.MessageRepository messageBroker MessageBroker } // NewRouteMessageUseCase creates a new route message use case func NewRouteMessageUseCase( messageRepo repositories.MessageRepository, messageBroker MessageBroker, ) *RouteMessageUseCase { return &RouteMessageUseCase{ messageRepo: messageRepo, messageBroker: messageBroker, } } // Execute routes an MPC message func (uc *RouteMessageUseCase) Execute(ctx context.Context, input RouteMessageInput) (*RouteMessageOutput, error) { // Validate input sessionID, err := uuid.Parse(input.SessionID) if err != nil { return nil, ErrInvalidSessionID } if input.FromParty == "" { return nil, ErrInvalidPartyID } if len(input.Payload) == 0 { return nil, ErrEmptyPayload } // Create message entity msg := entities.NewMPCMessage( sessionID, input.FromParty, input.ToParties, input.RoundNumber, input.MessageType, input.Payload, ) // Persist message for reliability (offline scenarios) if err := uc.messageRepo.Save(ctx, msg); err != nil { logger.Error("failed to save message", zap.Error(err)) return nil, err } // Route message dto := msg.ToDTO() if msg.IsBroadcast() { // Broadcast to all parties except sender if err := uc.messageBroker.PublishToSession(ctx, input.SessionID, input.FromParty, &dto); err != nil { logger.Error("failed to broadcast message", zap.String("session_id", input.SessionID), zap.Error(err)) // Don't fail - message is persisted and can be retrieved via polling } } else { // Unicast to specific parties for _, toParty := range input.ToParties { if err := uc.messageBroker.PublishToParty(ctx, toParty, &dto); err != nil { logger.Error("failed to send message to party", zap.String("party_id", toParty), zap.Error(err)) // Don't fail - continue sending to other parties } } } return &RouteMessageOutput{ MessageID: msg.ID.String(), Success: true, }, nil } // GetPendingMessagesInput contains input for getting pending messages type GetPendingMessagesInput struct { SessionID string PartyID string AfterTimestamp int64 } // GetPendingMessagesUseCase retrieves pending messages for a party type GetPendingMessagesUseCase struct { messageRepo repositories.MessageRepository } // NewGetPendingMessagesUseCase creates a new get pending messages use case func NewGetPendingMessagesUseCase(messageRepo repositories.MessageRepository) *GetPendingMessagesUseCase { return &GetPendingMessagesUseCase{ messageRepo: messageRepo, } } // Execute retrieves pending messages func (uc *GetPendingMessagesUseCase) Execute(ctx context.Context, input GetPendingMessagesInput) ([]*entities.MessageDTO, error) { sessionID, err := uuid.Parse(input.SessionID) if err != nil { return nil, ErrInvalidSessionID } if input.PartyID == "" { return nil, ErrInvalidPartyID } afterTime := time.Time{} if input.AfterTimestamp > 0 { afterTime = time.UnixMilli(input.AfterTimestamp) } messages, err := uc.messageRepo.GetPendingMessages(ctx, sessionID, input.PartyID, afterTime) if err != nil { return nil, err } // Convert to DTOs dtos := make([]*entities.MessageDTO, len(messages)) for i, msg := range messages { dto := msg.ToDTO() dtos[i] = &dto } return dtos, nil }