package use_cases

import (
	"context"
	"errors"
	"os"
	"time"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/pkg/crypto"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"github.com/rwadurian/mpc-system/pkg/tss"
	"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
	"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
	"go.uber.org/zap"
)

// Sentinel errors exposed by the keygen participation use case.
var (
	// ErrKeygenFailed signals that key generation did not succeed.
	ErrKeygenFailed = errors.New("keygen failed")
	// ErrKeygenTimeout signals that key generation ran out of time.
	ErrKeygenTimeout = errors.New("keygen timeout")
	// ErrInvalidSession signals that the joined session cannot be used for keygen.
	ErrInvalidSession = errors.New("invalid session")
	// ErrShareSaveFailed signals that the generated key share could not be stored.
	ErrShareSaveFailed = errors.New("failed to save share")
)

// ParticipateKeygenInput carries the parameters a party needs to take part
// in a keygen session.
type ParticipateKeygenInput struct {
	SessionID           uuid.UUID
	PartyID, JoinToken  string
}

// ParticipateKeygenOutput is the result handed back after taking part in a
// keygen session.
type ParticipateKeygenOutput struct {
	Success   bool
	KeyShare  *entities.PartyKeyShare
	PublicKey []byte
	// ShareForUser is only set for delegate parties: it is the encrypted
	// share that goes back to the user instead of being written to the DB.
	ShareForUser []byte
}

// SessionCoordinatorClient is the contract this use case needs from the
// session coordinator.
type SessionCoordinatorClient interface {
	// JoinSession joins the given session as partyID using joinToken and
	// returns the session information.
	JoinSession(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID, joinToken string,
	) (*SessionInfo, error)

	// ReportCompletion tells the coordinator that partyID finished the
	// session, passing along the resulting public key.
	ReportCompletion(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID string,
		publicKey []byte,
	) error
}

// MessageRouterClient is the contract this use case needs from the message
// router.
type MessageRouterClient interface {
	// RouteMessage hands payload to the router for delivery from fromParty
	// to toParties within the given session and round.
	RouteMessage(
		ctx context.Context,
		sessionID uuid.UUID,
		fromParty string,
		toParties []string,
		roundNumber int,
		payload []byte,
	) error

	// SubscribeMessages returns a receive-only channel of messages
	// addressed to partyID in the given session.
	SubscribeMessages(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID string,
	) (<-chan *MPCMessage, error)
}

// SessionInfo is the session description returned by the coordinator.
type SessionInfo struct {
	SessionID              uuid.UUID
	SessionType            string
	ThresholdN, ThresholdT int
	MessageHash            []byte
	// KeygenSessionID applies to signing sessions: it names the keygen
	// session whose share should be used.
	KeygenSessionID uuid.UUID
	Participants    []ParticipantInfo
}
// ParticipantInfo contains participant information type ParticipantInfo struct { PartyID string PartyIndex int } // MPCMessage represents an MPC message from the router type MPCMessage struct { MessageID string // Unique message ID for acknowledgment SessionID string // Session ID for acknowledgment FromParty string IsBroadcast bool RoundNumber int Payload []byte } // ParticipateKeygenUseCase handles keygen participation type ParticipateKeygenUseCase struct { keyShareRepo repositories.KeyShareRepository sessionClient SessionCoordinatorClient messageRouter MessageRouterClient cryptoService *crypto.CryptoService } // NewParticipateKeygenUseCase creates a new participate keygen use case func NewParticipateKeygenUseCase( keyShareRepo repositories.KeyShareRepository, sessionClient SessionCoordinatorClient, messageRouter MessageRouterClient, cryptoService *crypto.CryptoService, ) *ParticipateKeygenUseCase { return &ParticipateKeygenUseCase{ keyShareRepo: keyShareRepo, sessionClient: sessionClient, messageRouter: messageRouter, cryptoService: cryptoService, } } // Execute participates in a keygen session using real TSS protocol func (uc *ParticipateKeygenUseCase) Execute( ctx context.Context, input ParticipateKeygenInput, ) (*ParticipateKeygenOutput, error) { // 1. Join session via coordinator sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken) if err != nil { return nil, err } if sessionInfo.SessionType != "keygen" { return nil, ErrInvalidSession } // 2. 
Find self in participants and build party index map var selfIndex int partyIndexMap := make(map[string]int) for _, p := range sessionInfo.Participants { partyIndexMap[p.PartyID] = p.PartyIndex if p.PartyID == input.PartyID { selfIndex = p.PartyIndex } logger.Debug("Added participant to index map", zap.String("party_id", p.PartyID), zap.Int("party_index", p.PartyIndex)) } logger.Info("Built party index map", zap.String("session_id", input.SessionID.String()), zap.String("self_party_id", input.PartyID), zap.Int("self_index", selfIndex), zap.Int("total_participants", len(sessionInfo.Participants))) // 3. Subscribe to messages msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID) if err != nil { return nil, err } // 4. Run TSS Keygen protocol saveData, publicKey, err := uc.runKeygenProtocol( ctx, input.SessionID, input.PartyID, selfIndex, sessionInfo.Participants, sessionInfo.ThresholdN, sessionInfo.ThresholdT, msgChan, partyIndexMap, ) if err != nil { return nil, err } // 5. Encrypt the share encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID) if err != nil { return nil, err } keyShare := entities.NewPartyKeyShare( input.PartyID, selfIndex, input.SessionID, sessionInfo.ThresholdN, sessionInfo.ThresholdT, encryptedShare, publicKey, ) // 6. 
Handle share based on party role partyRole := uc.getPartyRole() var shareForUser []byte switch partyRole { case "persistent": // Persistent Party: save share to database if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil { return nil, ErrShareSaveFailed } logger.Info("Share saved to database (persistent party)", zap.String("party_id", input.PartyID), zap.String("session_id", input.SessionID.String())) case "delegate": // Delegate Party: do NOT save to database, return to user shareForUser = encryptedShare logger.Info("Share NOT saved, will be returned to user (delegate party)", zap.String("party_id", input.PartyID), zap.String("session_id", input.SessionID.String()), zap.Int("share_size", len(shareForUser))) case "temporary": // Temporary Party: optionally save to temp storage (not implemented yet) logger.Info("Temporary party - share not saved", zap.String("party_id", input.PartyID)) default: // Default to persistent for safety if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil { return nil, ErrShareSaveFailed } logger.Warn("Unknown party role, defaulting to persistent", zap.String("party_id", input.PartyID), zap.String("role", partyRole)) } // 7. 
Report completion to coordinator if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil { logger.Error("failed to report completion", zap.Error(err)) // Don't fail - share is handled } return &ParticipateKeygenOutput{ Success: true, KeyShare: keyShare, PublicKey: publicKey, ShareForUser: shareForUser, // Only populated for delegate parties }, nil } // runKeygenProtocol runs the TSS keygen protocol using tss-lib func (uc *ParticipateKeygenUseCase) runKeygenProtocol( ctx context.Context, sessionID uuid.UUID, partyID string, selfIndex int, participants []ParticipantInfo, n, t int, msgChan <-chan *MPCMessage, partyIndexMap map[string]int, ) ([]byte, []byte, error) { logger.Info("Running keygen protocol", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID), zap.Int("self_index", selfIndex), zap.Int("n", n), zap.Int("t", t)) // Create message handler adapter msgHandler := &keygenMessageHandler{ sessionID: sessionID, partyID: partyID, messageRouter: uc.messageRouter, msgChan: make(chan *tss.ReceivedMessage, 100), partyIndexMap: partyIndexMap, } // Start message conversion goroutine go msgHandler.convertMessages(ctx, msgChan) // Create keygen config config := tss.KeygenConfig{ Threshold: t, TotalParties: n, Timeout: 10 * time.Minute, } // Create party list allParties := make([]tss.KeygenParty, len(participants)) for i, p := range participants { allParties[i] = tss.KeygenParty{ PartyID: p.PartyID, PartyIndex: p.PartyIndex, } } selfParty := tss.KeygenParty{ PartyID: partyID, PartyIndex: selfIndex, } // Create keygen session session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler) if err != nil { return nil, nil, err } // Run keygen result, err := session.Start(ctx) if err != nil { return nil, nil, err } logger.Info("Keygen completed successfully", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID)) return result.LocalPartySaveData, 
result.PublicKeyBytes, nil } // keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler type keygenMessageHandler struct { sessionID uuid.UUID partyID string messageRouter MessageRouterClient msgChan chan *tss.ReceivedMessage partyIndexMap map[string]int } func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error { return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes) } func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage { return h.msgChan } func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) { for { select { case <-ctx.Done(): logger.Debug("convertMessages context cancelled", zap.String("session_id", h.sessionID.String())) close(h.msgChan) return case msg, ok := <-inChan: if !ok { logger.Debug("convertMessages inChan closed", zap.String("session_id", h.sessionID.String())) close(h.msgChan) return } logger.Debug("Received MPC message for conversion", zap.String("session_id", h.sessionID.String()), zap.String("from_party", msg.FromParty), zap.Bool("is_broadcast", msg.IsBroadcast), zap.Int("payload_size", len(msg.Payload))) fromIndex, exists := h.partyIndexMap[msg.FromParty] if !exists { logger.Warn("Message from unknown party - dropping", zap.String("session_id", h.sessionID.String()), zap.String("from_party", msg.FromParty), zap.Any("known_parties", h.partyIndexMap)) continue } tssMsg := &tss.ReceivedMessage{ FromPartyIndex: fromIndex, IsBroadcast: msg.IsBroadcast, MsgBytes: msg.Payload, } logger.Debug("Converted message, sending to TSS", zap.String("session_id", h.sessionID.String()), zap.String("from_party", msg.FromParty), zap.Int("from_index", fromIndex)) select { case h.msgChan <- tssMsg: logger.Debug("Message sent to TSS successfully", zap.String("session_id", h.sessionID.String())) case <-ctx.Done(): return } } } } // getPartyRole gets the party role from environment variable // 
Returns "persistent" (default), "delegate", or "temporary" func (uc *ParticipateKeygenUseCase) getPartyRole() string { role := os.Getenv("PARTY_ROLE") if role == "" { return "persistent" // Default to persistent for safety } return role }