package use_cases

import (
	"context"
	"errors"
	"time"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/pkg/crypto"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"github.com/rwadurian/mpc-system/pkg/tss"
	"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
	"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
	"go.uber.org/zap"
)

// Sentinel errors for keygen participation.
var (
	// ErrKeygenFailed reports a failed keygen run.
	// NOTE(review): not returned anywhere in this file; protocol errors from
	// runKeygenProtocol are passed through unwrapped.
	ErrKeygenFailed = errors.New("keygen failed")
	// ErrKeygenTimeout reports a keygen run that timed out.
	// NOTE(review): not returned anywhere in this file.
	ErrKeygenTimeout = errors.New("keygen timeout")
	// ErrInvalidSession is returned by Execute when the joined session is not
	// a "keygen" session, when the coordinator returns no session info, or
	// when the calling party is not listed among the session participants.
	ErrInvalidSession = errors.New("invalid session")
	// ErrShareSaveFailed is returned by Execute when the encrypted key share
	// cannot be persisted. The underlying repository error is logged.
	ErrShareSaveFailed = errors.New("failed to save share")
)

// ParticipateKeygenInput contains input for keygen participation.
type ParticipateKeygenInput struct {
	SessionID uuid.UUID
	PartyID   string
	JoinToken string
}

// ParticipateKeygenOutput contains output from keygen participation.
type ParticipateKeygenOutput struct {
	Success   bool
	KeyShare  *entities.PartyKeyShare
	PublicKey []byte
}

// SessionCoordinatorClient defines the interface for session coordinator communication.
type SessionCoordinatorClient interface {
	JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
	ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
}

// MessageRouterClient defines the interface for message router communication.
type MessageRouterClient interface {
	RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
	SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
}

// SessionInfo contains session information from the coordinator.
type SessionInfo struct {
	SessionID    uuid.UUID
	SessionType  string
	ThresholdN   int
	ThresholdT   int
	MessageHash  []byte
	Participants []ParticipantInfo
}

// ParticipantInfo contains participant information.
type ParticipantInfo struct {
	PartyID    string
	PartyIndex int
}

// MPCMessage represents an MPC message from the router.
type MPCMessage struct {
	FromParty   string
	IsBroadcast bool
	RoundNumber int
	Payload     []byte
}

// ParticipateKeygenUseCase handles keygen participation.
type ParticipateKeygenUseCase struct {
	keyShareRepo  repositories.KeyShareRepository
	sessionClient SessionCoordinatorClient
	messageRouter MessageRouterClient
	cryptoService *crypto.CryptoService
}

// NewParticipateKeygenUseCase creates a new participate keygen use case.
func NewParticipateKeygenUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateKeygenUseCase {
	return &ParticipateKeygenUseCase{
		keyShareRepo:  keyShareRepo,
		sessionClient: sessionClient,
		messageRouter: messageRouter,
		cryptoService: cryptoService,
	}
}

// Execute participates in a keygen session using the TSS protocol.
//
// Steps:
//  1. Join the session via the coordinator.
//  2. Locate this party among the participants.
//  3. Subscribe to routed messages.
//  4. Run the keygen protocol.
//  5. Encrypt and persist the resulting share.
//  6. Report completion.
//
// Errors:
//   - Errors from JoinSession, SubscribeMessages, the protocol run and
//     EncryptShare are returned as-is.
//   - ErrInvalidSession is returned for a non-keygen, missing, or
//     non-participating session.
//   - ErrShareSaveFailed is returned if persisting the share fails.
//
// A failure to report completion is logged but does not fail the call,
// because the share has already been saved.
func (uc *ParticipateKeygenUseCase) Execute(
	ctx context.Context,
	input ParticipateKeygenInput,
) (*ParticipateKeygenOutput, error) {
	// 1. Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	if sessionInfo == nil || sessionInfo.SessionType != "keygen" {
		return nil, ErrInvalidSession
	}

	// 2. Find self in participants and build the party index map.
	var (
		selfIndex int
		selfFound bool
	)
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
		if p.PartyID == input.PartyID {
			selfIndex = p.PartyIndex
			selfFound = true
		}
	}
	if !selfFound {
		// Without this check the party would silently run keygen as index 0.
		return nil, ErrInvalidSession
	}

	// Bound the lifetime of the message subscription and of the
	// message-conversion goroutine to this call, so they do not outlive it
	// when the caller's ctx is long-lived.
	keygenCtx, cancelKeygen := context.WithCancel(ctx)
	defer cancelKeygen()

	// 3. Subscribe to messages.
	msgChan, err := uc.messageRouter.SubscribeMessages(keygenCtx, input.SessionID, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 4. Run TSS keygen protocol.
	saveData, publicKey, err := uc.runKeygenProtocol(
		keygenCtx,
		input.SessionID,
		input.PartyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 5. Encrypt and save the share.
	encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
	if err != nil {
		return nil, err
	}

	keyShare := entities.NewPartyKeyShare(
		input.PartyID,
		selfIndex,
		input.SessionID,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		encryptedShare,
		publicKey,
	)

	if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
		// Callers receive the sentinel; log the cause so it is not lost.
		logger.Error("failed to save key share",
			zap.String("session_id", input.SessionID.String()),
			zap.String("party_id", input.PartyID),
			zap.Error(err))
		return nil, ErrShareSaveFailed
	}

	// 6. Report completion to coordinator.
	if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
		logger.Error("failed to report completion",
			zap.String("session_id", input.SessionID.String()),
			zap.String("party_id", input.PartyID),
			zap.Error(err))
		// Don't fail - share is saved.
	}

	return &ParticipateKeygenOutput{
		Success:   true,
		KeyShare:  keyShare,
		PublicKey: publicKey,
	}, nil
}

// runKeygenProtocol runs the TSS keygen protocol via the project's tss package.
//
// It starts a goroutine that converts router messages from msgChan into
// tss.ReceivedMessage values. That goroutine exits when ctx is done or
// msgChan is closed.
//
// It returns the local party save data and the public key bytes from the
// keygen result.
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	n, t int,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, []byte, error) {
	logger.Info("Running keygen protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("n", n),
		zap.Int("t", t))

	// Create message handler adapter. The 100-slot buffer lets the converter
	// run ahead of the TSS session without blocking on every message.
	msgHandler := &keygenMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine.
	go msgHandler.convertMessages(ctx, msgChan)

	// Create keygen config.
	config := tss.KeygenConfig{
		Threshold:    t,
		TotalParties: n,
		Timeout:      10 * time.Minute,
	}

	// Create party list.
	allParties := make([]tss.KeygenParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.KeygenParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}

	selfParty := tss.KeygenParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	// Create keygen session.
	session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
	if err != nil {
		return nil, nil, err
	}

	// Run keygen.
	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, err
	}

	logger.Info("Keygen completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))

	return result.LocalPartySaveData, result.PublicKeyBytes, nil
}

// keygenMessageHandler adapts an MPCMessage channel to tss.MessageHandler.
type keygenMessageHandler struct {
	sessionID     uuid.UUID
	partyID       string
	messageRouter MessageRouterClient
	msgChan       chan *tss.ReceivedMessage
	partyIndexMap map[string]int
}

// SendMessage routes an outgoing protocol message through the message router.
//
// NOTE(review): isBroadcast is ignored and the round number is always sent
// as 0. This presumably relies on the router treating an empty toParties
// as a broadcast; confirm against the router implementation.
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
}

// ReceiveMessages returns the channel of converted incoming messages.
//
// The channel is closed when the conversion goroutine stops.
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	return h.msgChan
}

// convertMessages forwards router messages from inChan to h.msgChan.
//
// It translates each sender's PartyID into its party index. Messages are
// skipped when they are nil or when the sender is not in partyIndexMap.
// It returns when ctx is done or inChan is closed.
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	// This goroutine is the only sender on h.msgChan, so it owns closing it.
	// Closing on every exit path lets readers of ReceiveMessages observe the
	// end of input, including when ctx is cancelled mid-send.
	defer close(h.msgChan)

	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-inChan:
			if !ok {
				return
			}
			if msg == nil {
				continue
			}

			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				// Drop messages from parties that are not session participants.
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}

			select {
			case h.msgChan <- tssMsg:
			case <-ctx.Done():
				return
			}
		}
	}
}