// Source: rwadurian/backend/mpc-system/services/message-router/application/use_cases/route_message.go

package use_cases
import (
"context"
"errors"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
"go.uber.org/zap"
)
// Sentinel validation errors returned by the use cases in this file.
// They are plain errors.New values, so callers can match them with errors.Is.
var (
	// ErrInvalidSessionID is returned when a session ID does not parse as a UUID.
	ErrInvalidSessionID = errors.New("invalid session ID")

	// ErrInvalidPartyID is returned when a required party ID is empty.
	ErrInvalidPartyID = errors.New("invalid party ID")

	// ErrEmptyPayload is returned when a message carries no payload bytes.
	ErrEmptyPayload = errors.New("empty payload")
)
// RouteMessageInput carries the parameters for RouteMessageUseCase.Execute.
type RouteMessageInput struct {
	// SessionID identifies the session; Execute requires it to parse as a UUID.
	SessionID string
	// FromParty is the sender's party ID; Execute requires it to be non-empty.
	FromParty string
	// ToParties lists recipient party IDs. A nil or empty slice means the
	// message is a broadcast.
	ToParties []string
	// RoundNumber is passed through to the message entity unchanged
	// (presumably the MPC protocol round — confirm against entities).
	RoundNumber int
	// MessageType is passed through to the message entity unchanged.
	MessageType string
	// Payload is the raw message body; Execute requires it to be non-empty.
	Payload []byte
}
// RouteMessageOutput is the result returned by RouteMessageUseCase.Execute.
type RouteMessageOutput struct {
	// MessageID is the string form of the persisted message's ID.
	MessageID string
	// Success is set to true by Execute whenever it returns without error.
	Success bool
}
// MessageBroker is the delivery port RouteMessageUseCase uses to push a
// message to its recipients after it has been persisted.
type MessageBroker interface {
	// PublishToParty delivers message to the single party identified by partyID.
	PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error

	// PublishToSession delivers message to every party in sessionID except
	// excludeParty (the sender).
	PublishToSession(ctx context.Context, sessionID, excludeParty string, message *entities.MessageDTO) error
}
// RouteMessageUseCase validates, persists and forwards MPC messages.
// Construct it with NewRouteMessageUseCase.
type RouteMessageUseCase struct {
	messageRepo   repositories.MessageRepository // persistence, so offline recipients can poll later
	messageBroker MessageBroker                  // immediate delivery to recipients
}
// NewRouteMessageUseCase builds a RouteMessageUseCase that saves messages to
// messageRepo and delivers them through messageBroker. Neither dependency is
// checked for nil here.
func NewRouteMessageUseCase(
	messageRepo repositories.MessageRepository,
	messageBroker MessageBroker,
) *RouteMessageUseCase {
	uc := new(RouteMessageUseCase)
	uc.messageRepo = messageRepo
	uc.messageBroker = messageBroker
	return uc
}
// Execute validates an MPC message, persists it, and then attempts immediate
// delivery through the MessageBroker.
//
// Validation runs before anything is persisted and fails with:
//   - ErrInvalidSessionID if input.SessionID does not parse as a UUID;
//   - ErrInvalidPartyID if input.FromParty or any entry of input.ToParties is empty;
//   - ErrEmptyPayload if input.Payload has no bytes.
//
// If saving the message fails, the repository error is logged and returned.
// Broker publish failures are only logged: the message is already persisted,
// so recipients can still retrieve it by polling. On success the output
// carries the persisted message's ID and Success is true.
func (uc *RouteMessageUseCase) Execute(ctx context.Context, input RouteMessageInput) (*RouteMessageOutput, error) {
	sessionID, err := uuid.Parse(input.SessionID)
	if err != nil {
		return nil, ErrInvalidSessionID
	}
	if input.FromParty == "" {
		return nil, ErrInvalidPartyID
	}
	// An empty recipient ID cannot be routed. Reject it up front instead of
	// persisting a message addressed to "" and publishing to party "".
	for _, toParty := range input.ToParties {
		if toParty == "" {
			return nil, ErrInvalidPartyID
		}
	}
	if len(input.Payload) == 0 {
		return nil, ErrEmptyPayload
	}

	// uuid.Parse also accepts non-canonical spellings (upper case, braces,
	// "urn:uuid:" prefix, no hyphens). Route and log with the canonical form so
	// the broker's session key matches the parsed session ID given to the entity.
	canonicalSessionID := sessionID.String()

	msg := entities.NewMPCMessage(
		sessionID,
		input.FromParty,
		input.ToParties,
		input.RoundNumber,
		input.MessageType,
		input.Payload,
	)

	// Persist first so the message survives even if real-time delivery fails
	// (offline recipients).
	if err := uc.messageRepo.Save(ctx, msg); err != nil {
		logger.Error("failed to save message",
			zap.String("session_id", canonicalSessionID),
			zap.String("from_party", input.FromParty),
			zap.Error(err))
		return nil, err
	}

	messageID := msg.ID.String()
	dto := msg.ToDTO()

	if msg.IsBroadcast() {
		// Broadcast to all parties in the session except the sender.
		if err := uc.messageBroker.PublishToSession(ctx, canonicalSessionID, input.FromParty, &dto); err != nil {
			logger.Error("failed to broadcast message",
				zap.String("session_id", canonicalSessionID),
				zap.String("message_id", messageID),
				zap.Error(err))
			// Don't fail - the message is persisted and can be retrieved via polling.
		}
	} else {
		// Unicast to each listed recipient.
		for _, toParty := range input.ToParties {
			if err := uc.messageBroker.PublishToParty(ctx, toParty, &dto); err != nil {
				logger.Error("failed to send message to party",
					zap.String("session_id", canonicalSessionID),
					zap.String("message_id", messageID),
					zap.String("party_id", toParty),
					zap.Error(err))
				// Don't fail - keep delivering to the remaining parties.
			}
		}
	}

	return &RouteMessageOutput{
		MessageID: messageID,
		Success:   true,
	}, nil
}
// GetPendingMessagesInput carries the parameters for GetPendingMessagesUseCase.Execute.
type GetPendingMessagesInput struct {
	SessionID      string // must parse as a UUID
	PartyID        string // party whose pending messages are requested; must be non-empty
	AfterTimestamp int64  // Unix milliseconds; values <= 0 pass the zero time.Time to the repository
}
// GetPendingMessagesUseCase lets a party fetch its stored messages from the
// repository (the polling path when real-time delivery was missed).
// Construct it with NewGetPendingMessagesUseCase.
type GetPendingMessagesUseCase struct {
	messageRepo repositories.MessageRepository // source of persisted messages
}
// NewGetPendingMessagesUseCase builds a GetPendingMessagesUseCase that reads
// from messageRepo. The repository is not checked for nil here.
func NewGetPendingMessagesUseCase(messageRepo repositories.MessageRepository) *GetPendingMessagesUseCase {
	uc := new(GetPendingMessagesUseCase)
	uc.messageRepo = messageRepo
	return uc
}
// Execute returns, as DTOs, the messages the repository reports as pending
// for input.PartyID in session input.SessionID.
//
// It fails with ErrInvalidSessionID when the session ID is not a UUID, and
// otherwise with ErrInvalidPartyID when the party ID is empty. The
// AfterTimestamp field is read as Unix milliseconds; if it is <= 0 the zero
// time.Time is handed to the repository instead. Repository errors are
// returned unchanged. On success the returned slice is non-nil, even when empty.
func (uc *GetPendingMessagesUseCase) Execute(ctx context.Context, input GetPendingMessagesInput) ([]*entities.MessageDTO, error) {
	sid, parseErr := uuid.Parse(input.SessionID)
	switch {
	case parseErr != nil:
		return nil, ErrInvalidSessionID
	case input.PartyID == "":
		return nil, ErrInvalidPartyID
	}

	var since time.Time
	if input.AfterTimestamp > 0 {
		since = time.UnixMilli(input.AfterTimestamp)
	}

	pending, err := uc.messageRepo.GetPendingMessages(ctx, sid, input.PartyID, since)
	if err != nil {
		return nil, err
	}

	result := make([]*entities.MessageDTO, 0, len(pending))
	for _, m := range pending {
		converted := m.ToDTO() // fresh variable per iteration, so each pointer is distinct
		result = append(result, &converted)
	}
	return result, nil
}