package use_cases

import (
	"context"
	"errors"
	"math/big"
	"time"

	"github.com/google/uuid"
	"github.com/rwadurian/mpc-system/pkg/crypto"
	"github.com/rwadurian/mpc-system/pkg/logger"
	"github.com/rwadurian/mpc-system/pkg/tss"
	"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
	"go.uber.org/zap"
)

var (
	ErrSigningFailed      = errors.New("signing failed")
	ErrSigningTimeout     = errors.New("signing timeout")
	ErrKeyShareNotFound   = errors.New("key share not found")
	ErrInvalidSignSession = errors.New("invalid sign session")
)

// ParticipateSigningInput contains input for signing participation.
type ParticipateSigningInput struct {
	SessionID uuid.UUID
	PartyID   string
	JoinToken string
	// MessageHash is the hash to sign; when empty, the hash carried by the
	// joined session is used instead.
	MessageHash []byte
}

// ParticipateSigningOutput contains output from signing participation.
type ParticipateSigningOutput struct {
	Success   bool
	Signature []byte
	R         *big.Int
	S         *big.Int
}

// ParticipateSigningUseCase handles signing participation: it joins a signing
// session, loads and decrypts this party's key share, and runs the TSS
// signing protocol with the other participants.
type ParticipateSigningUseCase struct {
	keyShareRepo  repositories.KeyShareRepository
	sessionClient SessionCoordinatorClient
	messageRouter MessageRouterClient
	cryptoService *crypto.CryptoService
}

// NewParticipateSigningUseCase creates a new participate signing use case.
func NewParticipateSigningUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateSigningUseCase {
	return &ParticipateSigningUseCase{
		keyShareRepo:  keyShareRepo,
		sessionClient: sessionClient,
		messageRouter: messageRouter,
		cryptoService: cryptoService,
	}
}

// Execute participates in a signing session using the real TSS protocol.
//
// It joins the session identified by input.SessionID, loads and decrypts this
// party's key share, runs the signing protocol and reports the resulting
// signature to the coordinator.
//
// It returns ErrInvalidSignSession when the joined session is not a "sign"
// session, does not list input.PartyID as a participant, or has no message
// hash to sign. It returns ErrKeyShareNotFound when no key share can be loaded
// for the party. Otherwise it returns any error from the coordinator, crypto
// service, message router or TSS protocol unchanged. Failures to record
// key-share usage or to report completion are logged and do not fail the call.
func (uc *ParticipateSigningUseCase) Execute(
	ctx context.Context,
	input ParticipateSigningInput,
) (*ParticipateSigningOutput, error) {
	// 1. Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	if sessionInfo.SessionType != "sign" {
		return nil, ErrInvalidSignSession
	}

	// 2. Find self in participants and build the party ID -> party index map
	// used to attribute incoming protocol messages.
	selfIndex, found := 0, false
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
		if p.PartyID == input.PartyID {
			selfIndex = p.PartyIndex
			found = true
		}
	}
	if !found {
		// Without this check selfIndex would silently default to 0, which may
		// be another participant's index.
		return nil, ErrInvalidSignSession
	}

	// 3. Resolve the message hash: the caller's value wins, otherwise use the
	// hash carried by the session.
	// NOTE(review): a caller-supplied hash is not compared against
	// sessionInfo.MessageHash — confirm that is intended.
	messageHash := input.MessageHash
	if len(messageHash) == 0 {
		messageHash = sessionInfo.MessageHash
	}
	if len(messageHash) == 0 {
		return nil, ErrInvalidSignSession
	}

	// 4. Load key share for this party.
	keyShares, err := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
	if err != nil {
		// Callers keep seeing the ErrKeyShareNotFound sentinel, but the
		// underlying repository failure is logged instead of being dropped.
		logger.Error("failed to list key shares for signing",
			zap.String("session_id", input.SessionID.String()),
			zap.String("party_id", input.PartyID),
			zap.Error(err))
		return nil, ErrKeyShareNotFound
	}
	if len(keyShares) == 0 {
		return nil, ErrKeyShareNotFound
	}
	// Use the last key share returned (in production, would match by public
	// key or session reference).
	// NOTE(review): assumes ListByParty returns shares oldest-first so the last
	// one is the most recent — confirm against the repository implementation.
	keyShare := keyShares[len(keyShares)-1]

	// 5. Decrypt share data.
	shareData, err := uc.cryptoService.DecryptShare(keyShare.ShareData, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 6. Subscribe to messages and run the protocol under a context that is
	// cancelled when Execute returns. This keeps the subscription from
	// outliving this call when ctx is long-lived (presumably the router stops
	// delivering on cancellation — confirm).
	protoCtx, cancelProto := context.WithCancel(ctx)
	defer cancelProto()

	msgChan, err := uc.messageRouter.SubscribeMessages(protoCtx, input.SessionID, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 7. Run TSS signing protocol.
	signature, r, s, err := uc.runSigningProtocol(
		protoCtx,
		input.SessionID,
		input.PartyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdT,
		shareData,
		messageHash,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 8. Update key share last used (best effort).
	keyShare.MarkUsed()
	if err := uc.keyShareRepo.Update(ctx, keyShare); err != nil {
		logger.Warn("failed to update key share last used", zap.Error(err))
	}

	// 9. Report completion to coordinator (best effort: the signature has
	// already been produced).
	if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
		logger.Error("failed to report signing completion", zap.Error(err))
	}

	return &ParticipateSigningOutput{
		Success:   true,
		Signature: signature,
		R:         r,
		S:         s,
	}, nil
}

// runSigningProtocol runs the TSS signing protocol using the tss package and
// returns the signature together with its R and S components.
//
// Incoming MPCMessages from msgChan are translated into tss.ReceivedMessage
// values by a background goroutine. That goroutine is bound to a context
// cancelled when this function returns, so it cannot outlive the protocol run.
func (uc *ParticipateSigningUseCase) runSigningProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	t int,
	shareData []byte,
	messageHash []byte,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, *big.Int, *big.Int, error) {
	logger.Info("Running signing protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("t", t),
		zap.Int("message_hash_len", len(messageHash)))

	// Stop the message-conversion goroutine once the protocol finishes,
	// regardless of how long the caller's ctx lives.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Create message handler adapter.
	msgHandler := &signingMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine.
	go msgHandler.convertMessages(ctx, msgChan)

	// Create signing config.
	config := tss.SigningConfig{
		Threshold:    t,
		TotalSigners: len(participants),
		Timeout:      5 * time.Minute,
	}

	// Create party list.
	allParties := make([]tss.SigningParty, 0, len(participants))
	for _, p := range participants {
		allParties = append(allParties, tss.SigningParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		})
	}

	selfParty := tss.SigningParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	// Create signing session.
	session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
	if err != nil {
		return nil, nil, nil, err
	}

	// Run signing.
	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	logger.Info("Signing completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))

	return result.Signature, result.R, result.S, nil
}

// signingMessageHandler adapts the MPCMessage channel and MessageRouterClient
// to the tss.MessageHandler interface used by tss.NewSigningSession.
type signingMessageHandler struct {
	sessionID     uuid.UUID
	partyID       string
	messageRouter MessageRouterClient
	// msgChan is written to and closed only by convertMessages.
	msgChan       chan *tss.ReceivedMessage
	partyIndexMap map[string]int
}

// SendMessage routes an outgoing protocol message through the message router.
// isBroadcast is not forwarded: only toParties is passed, with 0 as the
// router's numeric argument.
// NOTE(review): presumably the router treats an empty toParties as a broadcast
// — confirm.
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
}

// ReceiveMessages returns the channel of converted incoming messages. The
// channel is closed when conversion stops.
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	return h.msgChan
}

// convertMessages forwards messages from inChan to h.msgChan, replacing the
// sender's party ID with its party index. Nil messages and messages from
// senders absent from partyIndexMap are dropped. It returns when ctx is done
// or inChan is closed, and always closes h.msgChan on exit so readers of
// ReceiveMessages are released on every path.
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	defer close(h.msgChan)

	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-inChan:
			if !ok {
				return
			}
			if msg == nil {
				continue
			}

			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				logger.Warn("dropping signing message from unknown party",
					zap.String("session_id", h.sessionID.String()),
					zap.String("from_party", msg.FromParty))
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}

			select {
			case h.msgChan <- tssMsg:
			case <-ctx.Done():
				return
			}
		}
	}
}