171 lines
4.7 KiB
Go
171 lines
4.7 KiB
Go
package use_cases
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"github.com/rwadurian/mpc-system/services/message-router/domain/entities"
|
|
"github.com/rwadurian/mpc-system/services/message-router/domain/repositories"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Sentinel errors returned by the use cases in this file when input
// validation fails.
var (
	// ErrInvalidSessionID is returned when a session ID does not parse as a UUID.
	ErrInvalidSessionID = errors.New("invalid session ID")

	// ErrInvalidPartyID is returned when a required party ID is empty.
	ErrInvalidPartyID = errors.New("invalid party ID")

	// ErrEmptyPayload is returned when a message to route has no payload bytes.
	ErrEmptyPayload = errors.New("empty payload")
)
|
|
|
|
// RouteMessageInput carries the parameters of a single
// RouteMessageUseCase.Execute call.
type RouteMessageInput struct {
	// SessionID is the session identifier; it must parse as a UUID.
	SessionID string
	// FromParty is the sending party's ID; it must not be empty.
	FromParty string
	// ToParties lists the recipient party IDs. A nil or empty slice means
	// the message is a broadcast.
	ToParties []string
	// RoundNumber is handed to the message entity unchanged.
	RoundNumber int
	// MessageType is handed to the message entity unchanged.
	MessageType string
	// Payload is the message body; it must contain at least one byte.
	Payload []byte
}
|
|
|
|
// RouteMessageOutput is the result returned by RouteMessageUseCase.Execute.
type RouteMessageOutput struct {
	// MessageID is the string form of the ID assigned to the routed message.
	MessageID string
	// Success reports whether the message was accepted for routing.
	Success bool
}
|
|
|
|
// MessageBroker defines the interface for message delivery
|
|
type MessageBroker interface {
|
|
// PublishToParty publishes a message to a specific party
|
|
PublishToParty(ctx context.Context, partyID string, message *entities.MessageDTO) error
|
|
// PublishToSession publishes a message to all parties in a session (except sender)
|
|
PublishToSession(ctx context.Context, sessionID string, excludeParty string, message *entities.MessageDTO) error
|
|
}
|
|
|
|
// RouteMessageUseCase handles message routing
|
|
type RouteMessageUseCase struct {
|
|
messageRepo repositories.MessageRepository
|
|
messageBroker MessageBroker
|
|
}
|
|
|
|
// NewRouteMessageUseCase creates a new route message use case
|
|
func NewRouteMessageUseCase(
|
|
messageRepo repositories.MessageRepository,
|
|
messageBroker MessageBroker,
|
|
) *RouteMessageUseCase {
|
|
return &RouteMessageUseCase{
|
|
messageRepo: messageRepo,
|
|
messageBroker: messageBroker,
|
|
}
|
|
}
|
|
|
|
// Execute routes an MPC message
|
|
func (uc *RouteMessageUseCase) Execute(ctx context.Context, input RouteMessageInput) (*RouteMessageOutput, error) {
|
|
// Validate input
|
|
sessionID, err := uuid.Parse(input.SessionID)
|
|
if err != nil {
|
|
return nil, ErrInvalidSessionID
|
|
}
|
|
|
|
if input.FromParty == "" {
|
|
return nil, ErrInvalidPartyID
|
|
}
|
|
|
|
if len(input.Payload) == 0 {
|
|
return nil, ErrEmptyPayload
|
|
}
|
|
|
|
// Create message entity
|
|
msg := entities.NewMPCMessage(
|
|
sessionID,
|
|
input.FromParty,
|
|
input.ToParties,
|
|
input.RoundNumber,
|
|
input.MessageType,
|
|
input.Payload,
|
|
)
|
|
|
|
// Persist message for reliability (offline scenarios)
|
|
if err := uc.messageRepo.Save(ctx, msg); err != nil {
|
|
logger.Error("failed to save message", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
// Route message
|
|
dto := msg.ToDTO()
|
|
if msg.IsBroadcast() {
|
|
// Broadcast to all parties except sender
|
|
if err := uc.messageBroker.PublishToSession(ctx, input.SessionID, input.FromParty, &dto); err != nil {
|
|
logger.Error("failed to broadcast message",
|
|
zap.String("session_id", input.SessionID),
|
|
zap.Error(err))
|
|
// Don't fail - message is persisted and can be retrieved via polling
|
|
}
|
|
} else {
|
|
// Unicast to specific parties
|
|
for _, toParty := range input.ToParties {
|
|
if err := uc.messageBroker.PublishToParty(ctx, toParty, &dto); err != nil {
|
|
logger.Error("failed to send message to party",
|
|
zap.String("party_id", toParty),
|
|
zap.Error(err))
|
|
// Don't fail - continue sending to other parties
|
|
}
|
|
}
|
|
}
|
|
|
|
return &RouteMessageOutput{
|
|
MessageID: msg.ID.String(),
|
|
Success: true,
|
|
}, nil
|
|
}
|
|
|
|
// GetPendingMessagesInput carries the parameters of a single
// GetPendingMessagesUseCase.Execute call.
type GetPendingMessagesInput struct {
	// SessionID is the session identifier; it must parse as a UUID.
	SessionID string
	// PartyID is the party whose pending messages are requested; it must
	// not be empty.
	PartyID string
	// AfterTimestamp is a cutoff in Unix milliseconds. Values <= 0 cause the
	// zero time.Time to be passed to the repository instead.
	AfterTimestamp int64
}
|
|
|
|
// GetPendingMessagesUseCase retrieves pending messages for a party
|
|
type GetPendingMessagesUseCase struct {
|
|
messageRepo repositories.MessageRepository
|
|
}
|
|
|
|
// NewGetPendingMessagesUseCase creates a new get pending messages use case
|
|
func NewGetPendingMessagesUseCase(messageRepo repositories.MessageRepository) *GetPendingMessagesUseCase {
|
|
return &GetPendingMessagesUseCase{
|
|
messageRepo: messageRepo,
|
|
}
|
|
}
|
|
|
|
// Execute retrieves pending messages
|
|
func (uc *GetPendingMessagesUseCase) Execute(ctx context.Context, input GetPendingMessagesInput) ([]*entities.MessageDTO, error) {
|
|
sessionID, err := uuid.Parse(input.SessionID)
|
|
if err != nil {
|
|
return nil, ErrInvalidSessionID
|
|
}
|
|
|
|
if input.PartyID == "" {
|
|
return nil, ErrInvalidPartyID
|
|
}
|
|
|
|
afterTime := time.Time{}
|
|
if input.AfterTimestamp > 0 {
|
|
afterTime = time.UnixMilli(input.AfterTimestamp)
|
|
}
|
|
|
|
messages, err := uc.messageRepo.GetPendingMessages(ctx, sessionID, input.PartyID, afterTime)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert to DTOs
|
|
dtos := make([]*entities.MessageDTO, len(messages))
|
|
for i, msg := range messages {
|
|
dto := msg.ToDTO()
|
|
dtos[i] = &dto
|
|
}
|
|
|
|
return dtos, nil
|
|
}
|