package use_cases

import (
	"context"
	"errors"
	"math/big"
	"time"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/pkg/crypto"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"github.com/rwadurian/mpc-system/pkg/tss"
	"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
	"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
	"go.uber.org/zap"
)

var (
	ErrSigningFailed      = errors.New("signing failed")
	ErrSigningTimeout     = errors.New("signing timeout")
	ErrKeyShareNotFound   = errors.New("key share not found")
	ErrInvalidSignSession = errors.New("invalid sign session")
)

const (
	// subscriptionSettleDelay is how long a party waits after subscribing before
	// starting the TSS protocol. It gives the other parties time to subscribe, so
	// early broadcast messages are not lost. This is a heuristic delay, not a barrier.
	subscriptionSettleDelay = 500 * time.Millisecond

	// signingTimeout is the timeout passed to tss.SigningConfig.
	signingTimeout = 5 * time.Minute

	// tssInboxBuffer is the capacity of the channel that feeds converted
	// messages into the TSS session.
	tssInboxBuffer = 100
)

// ParticipateSigningInput contains input for signing participation.
type ParticipateSigningInput struct {
	SessionID   uuid.UUID
	PartyID     string
	JoinToken   string
	MessageHash []byte
	// For delegate parties: encrypted share provided by user (not loaded from DB)
	UserShareData []byte
}

// ParticipateSigningOutput contains output from signing participation.
type ParticipateSigningOutput struct {
	Success   bool
	Signature []byte
	R         *big.Int
	S         *big.Int
}

// ParticipateSigningUseCase handles signing participation.
type ParticipateSigningUseCase struct {
	keyShareRepo  repositories.KeyShareRepository
	sessionClient SessionCoordinatorClient
	messageRouter MessageRouterClient
	cryptoService *crypto.CryptoService
}

// NewParticipateSigningUseCase creates a new participate signing use case.
func NewParticipateSigningUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateSigningUseCase {
	return &ParticipateSigningUseCase{
		keyShareRepo:  keyShareRepo,
		sessionClient: sessionClient,
		messageRouter: messageRouter,
		cryptoService: cryptoService,
	}
}

// ExecuteWithSessionInfo participates in a signing session using a SessionInfo
// obtained earlier.
//
// It is used by server-party-co-managed, which has already called JoinSession
// in the session_created phase. That service then receives the session_started
// event once all participants have joined. JoinSession is therefore skipped here.
// Only persistent (database-backed) key shares are supported on this path.
//
// Returns ErrInvalidSignSession when sessionInfo is nil or its type is neither
// "sign" nor "co_managed_sign".
func (uc *ParticipateSigningUseCase) ExecuteWithSessionInfo(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
) (*ParticipateSigningOutput, error) {
	if sessionInfo == nil {
		return nil, ErrInvalidSignSession
	}
	if sessionInfo.SessionType != "sign" && sessionInfo.SessionType != "co_managed_sign" {
		return nil, ErrInvalidSignSession
	}

	logger.Info("ExecuteWithSessionInfo: starting signing with pre-obtained session info",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.String("session_type", sessionInfo.SessionType),
		zap.Int("participants", len(sessionInfo.Participants)))

	return uc.executeWithSessionInfo(ctx, sessionID, partyID, sessionInfo)
}

// Execute joins a signing session through the coordinator and runs the TSS
// signing protocol.
//
// The key share comes from input.UserShareData when it is set (delegate party).
// Otherwise it is loaded from the key share repository (persistent party).
// input.MessageHash, when empty, falls back to the hash carried by the session.
// Only sessions of type "sign" are accepted here. Any other type, or a nil
// SessionInfo, yields ErrInvalidSignSession.
func (uc *ParticipateSigningUseCase) Execute(
	ctx context.Context,
	input ParticipateSigningInput,
) (*ParticipateSigningOutput, error) {
	// 1. Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	if sessionInfo == nil || sessionInfo.SessionType != "sign" {
		return nil, ErrInvalidSignSession
	}

	// 2. Get share data: either from user input (delegate) or from the database (persistent).
	var (
		shareData          []byte
		keyShare           *entities.PartyKeyShare // nil for delegate parties
		originalThresholdN int                     // total parties at keygen time
	)
	if len(input.UserShareData) > 0 {
		shareData, err = uc.cryptoService.DecryptShare(input.UserShareData, input.PartyID)
		if err != nil {
			return nil, err
		}
		// Delegate parties have no stored share, so take n from the session.
		originalThresholdN = sessionInfo.ThresholdN
		logger.Info("Using user-provided share (delegate party)",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()))
	} else {
		keyShare, shareData, err = uc.loadPersistentShare(ctx, input.SessionID, input.PartyID, sessionInfo)
		if err != nil {
			return nil, err
		}
		originalThresholdN = keyShare.ThresholdN
	}

	// 3. Use the message hash from the session if the caller did not provide one.
	messageHash := input.MessageHash
	if len(messageHash) == 0 {
		messageHash = sessionInfo.MessageHash
	}

	return uc.signAndFinalize(ctx, input.SessionID, input.PartyID, sessionInfo,
		shareData, originalThresholdN, messageHash, keyShare)
}

// executeWithSessionInfo is the internal logic for ExecuteWithSessionInfo.
// It handles persistent parties only: the key share is always loaded from the
// repository, and the message hash always comes from the session.
func (uc *ParticipateSigningUseCase) executeWithSessionInfo(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
) (*ParticipateSigningOutput, error) {
	keyShare, shareData, err := uc.loadPersistentShare(ctx, sessionID, partyID, sessionInfo)
	if err != nil {
		return nil, err
	}
	return uc.signAndFinalize(ctx, sessionID, partyID, sessionInfo,
		shareData, keyShare.ThresholdN, sessionInfo.MessageHash, keyShare)
}

// loadPersistentShare loads and decrypts this party's key share from the repository.
//
// If sessionInfo.KeygenSessionID is set, the share for that keygen session is
// used. Otherwise the last element of ListByParty is used as a
// backward-compatibility fallback. Any lookup failure, including a nil result,
// is reported as ErrKeyShareNotFound. Decryption errors are returned unchanged.
func (uc *ParticipateSigningUseCase) loadPersistentShare(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
) (*entities.PartyKeyShare, []byte, error) {
	var keyShare *entities.PartyKeyShare

	if sessionInfo.KeygenSessionID != uuid.Nil {
		found, err := uc.keyShareRepo.FindBySessionAndParty(ctx, sessionInfo.KeygenSessionID, partyID)
		if err != nil || found == nil {
			// Guard against (nil, nil) as well: dereferencing it later would panic.
			logger.Error("Failed to find keyshare for keygen session",
				zap.String("party_id", partyID),
				zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()),
				zap.Error(err))
			return nil, nil, ErrKeyShareNotFound
		}
		keyShare = found
		logger.Info("Using specific keyshare by keygen_session_id",
			zap.String("party_id", partyID),
			zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()))
	} else {
		// Fallback: use the most recent key share.
		// NOTE(review): "most recent" assumes ListByParty returns shares
		// oldest-first; confirm against the repository implementation.
		// TODO: This should be removed once all signing sessions provide keygen_session_id
		keyShares, err := uc.keyShareRepo.ListByParty(ctx, partyID)
		if err != nil {
			logger.Error("Failed to list keyshares for party",
				zap.String("party_id", partyID),
				zap.Error(err))
			return nil, nil, ErrKeyShareNotFound
		}
		if len(keyShares) == 0 || keyShares[len(keyShares)-1] == nil {
			return nil, nil, ErrKeyShareNotFound
		}
		keyShare = keyShares[len(keyShares)-1]
		logger.Warn("Using most recent keyshare (keygen_session_id not provided)",
			zap.String("party_id", partyID),
			zap.String("fallback_session_id", keyShare.SessionID.String()))
	}

	shareData, err := uc.cryptoService.DecryptShare(keyShare.ShareData, partyID)
	if err != nil {
		return nil, nil, err
	}

	logger.Info("Using database share (persistent party)",
		zap.String("party_id", partyID),
		zap.String("session_id", sessionID.String()),
		zap.String("keygen_session_id", keyShare.SessionID.String()),
		zap.Int("original_threshold_n", keyShare.ThresholdN),
		zap.Int("threshold_t", keyShare.ThresholdT))

	return keyShare, shareData, nil
}

// signAndFinalize is the execution path shared by Execute and executeWithSessionInfo.
//
// It resolves this party's index, subscribes to session messages, waits briefly
// for peers to subscribe, and runs the TSS signing protocol. On success it then
// marks keyShare as used (when non-nil) and reports completion to the coordinator.
// Failures to update the key share or report completion are logged, not returned.
//
// The subscription and the message-conversion goroutine are bound to a context
// cancelled when this function returns. They therefore stop once signing is over,
// even if the caller's ctx stays alive.
func (uc *ParticipateSigningUseCase) signAndFinalize(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
	shareData []byte,
	originalThresholdN int,
	messageHash []byte,
	keyShare *entities.PartyKeyShare,
) (*ParticipateSigningOutput, error) {
	selfIndex, partyIndexMap, err := buildPartyIndex(sessionInfo.Participants, partyID)
	if err != nil {
		logger.Error("Party not found among session participants",
			zap.String("session_id", sessionID.String()),
			zap.String("party_id", partyID),
			zap.Int("participants", len(sessionInfo.Participants)))
		return nil, err
	}

	protocolCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	msgChan, err := uc.messageRouter.SubscribeMessages(protocolCtx, sessionID, partyID)
	if err != nil {
		return nil, err
	}

	// Wait for all parties to subscribe before starting the TSS protocol. This
	// reduces the race where some parties broadcast before others have subscribed.
	logger.Info("Waiting for all parties to subscribe",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("expected_parties", len(sessionInfo.Participants)))
	if err := waitOrCancel(protocolCtx, subscriptionSettleDelay); err != nil {
		return nil, err
	}

	signature, r, s, err := uc.runSigningProtocol(
		protocolCtx,
		sessionID,
		partyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdT,
		originalThresholdN,
		shareData,
		messageHash,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// Update key share last used (persistent parties only).
	if keyShare != nil {
		keyShare.MarkUsed()
		if err := uc.keyShareRepo.Update(ctx, keyShare); err != nil {
			logger.Warn("failed to update key share last used", zap.Error(err))
		}
	}

	// Report completion to coordinator; a failure here does not invalidate the signature.
	if err := uc.sessionClient.ReportCompletion(ctx, sessionID, partyID, signature); err != nil {
		logger.Error("failed to report signing completion", zap.Error(err))
	}

	return &ParticipateSigningOutput{
		Success:   true,
		Signature: signature,
		R:         r,
		S:         s,
	}, nil
}

// buildPartyIndex maps every participant's PartyID to its PartyIndex and
// returns the index belonging to partyID.
//
// It returns ErrInvalidSignSession when partyID is not among the participants.
// Silently defaulting to the zero value could collide with a real party index
// and run the protocol under the wrong identity.
func buildPartyIndex(participants []ParticipantInfo, partyID string) (int, map[string]int, error) {
	indexByParty := make(map[string]int, len(participants))
	selfIndex, found := 0, false
	for _, p := range participants {
		indexByParty[p.PartyID] = p.PartyIndex
		if p.PartyID == partyID {
			selfIndex, found = p.PartyIndex, true
		}
	}
	if !found {
		return 0, nil, ErrInvalidSignSession
	}
	return selfIndex, indexByParty, nil
}

// waitOrCancel blocks for d, or until ctx is done, whichever comes first.
// It returns ctx.Err() if the context ended first, and nil otherwise.
func waitOrCancel(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// runSigningProtocol runs the TSS signing protocol using tss-lib.
//
// t is the signing threshold. n is the ORIGINAL total number of parties from
// keygen, not the number of current signers: for 2-of-3, t=2 and n=3, but only
// 2 parties take part. Incoming MPCMessages on msgChan are converted to
// tss.ReceivedMessage by a goroutine that exits when ctx is done or msgChan closes.
func (uc *ParticipateSigningUseCase) runSigningProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	t int,
	n int, // Original total parties from keygen
	shareData []byte,
	messageHash []byte,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, *big.Int, *big.Int, error) {
	logger.Info("Running signing protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("t", t),
		zap.Int("n", n),
		zap.Int("current_signers", len(participants)),
		zap.Int("message_hash_len", len(messageHash)))

	// Adapter between the message router and the tss.MessageHandler interface.
	msgHandler := &signingMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, tssInboxBuffer),
		partyIndexMap: partyIndexMap,
	}
	go msgHandler.convertMessages(ctx, msgChan)

	// IMPORTANT: TotalParties must be the original n from keygen, not current signers.
	config := tss.SigningConfig{
		Threshold:    t,
		TotalParties: n,
		TotalSigners: len(participants),
		Timeout:      signingTimeout,
	}

	allParties := make([]tss.SigningParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.SigningParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}
	selfParty := tss.SigningParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
	if err != nil {
		return nil, nil, nil, err
	}

	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	logger.Info("Signing completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))

	return result.Signature, result.R, result.S, nil
}

// signingMessageHandler adapts an MPCMessage channel to tss.MessageHandler.
type signingMessageHandler struct {
	sessionID     uuid.UUID
	partyID       string
	messageRouter MessageRouterClient
	msgChan       chan *tss.ReceivedMessage
	partyIndexMap map[string]int
}

// SendMessage forwards an outgoing TSS message through the message router.
// isBroadcast is not forwarded, and the round argument is always 0.
// NOTE(review): broadcast vs. point-to-point is presumably conveyed by
// toParties; confirm against RouteMessage.
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
}

// ReceiveMessages returns the channel of converted incoming messages. The
// channel is closed when convertMessages exits.
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	return h.msgChan
}

// convertMessages translates MPCMessages from inChan into tss.ReceivedMessages
// on h.msgChan. It resolves the sender's party index via partyIndexMap and
// drops messages from unknown parties or nil messages.
//
// It is the only writer to h.msgChan, so it closes that channel on every exit
// path: context cancellation, including while blocked on a send, or inChan
// being closed.
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	defer close(h.msgChan)

	logger.Debug("convertMessages started, waiting for messages",
		zap.String("session_id", h.sessionID.String()),
		zap.String("party_id", h.partyID))

	for {
		select {
		case <-ctx.Done():
			logger.Debug("convertMessages context cancelled",
				zap.String("session_id", h.sessionID.String()))
			return

		case msg, ok := <-inChan:
			if !ok {
				logger.Debug("convertMessages inChan closed",
					zap.String("session_id", h.sessionID.String()))
				return
			}
			if msg == nil {
				// Defensive: a nil message would panic on field access below.
				continue
			}

			logger.Debug("Received MPC message for conversion",
				zap.String("session_id", h.sessionID.String()),
				zap.String("from_party", msg.FromParty),
				zap.Bool("is_broadcast", msg.IsBroadcast),
				zap.Int("payload_size", len(msg.Payload)))

			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				logger.Warn("Message from unknown party - dropping",
					zap.String("session_id", h.sessionID.String()),
					zap.String("from_party", msg.FromParty),
					zap.Any("known_parties", h.partyIndexMap))
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}

			logger.Debug("Converted message, sending to TSS",
				zap.String("session_id", h.sessionID.String()),
				zap.String("from_party", msg.FromParty),
				zap.Int("from_index", fromIndex))

			select {
			case h.msgChan <- tssMsg:
				logger.Debug("Message sent to TSS successfully",
					zap.String("session_id", h.sessionID.String()))
			case <-ctx.Done():
				logger.Debug("convertMessages context cancelled",
					zap.String("session_id", h.sessionID.String()))
				return
			}
		}
	}
}