package tss

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"

	"github.com/bnb-chain/tss-lib/v2/common"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
	"github.com/bnb-chain/tss-lib/v2/tss"
)

// Sentinel errors returned by the signing code in this file.
var (
	ErrSigningTimeout     = errors.New("signing timeout")
	ErrSigningFailed      = errors.New("signing failed")
	ErrInvalidSignerCount = errors.New("invalid signer count")
	ErrInvalidShareData   = errors.New("invalid share data")
)

const (
	// signingDefaultTimeout is used when SigningConfig.Timeout is zero and as
	// the overall deadline of RunLocalSigning.
	signingDefaultTimeout = 5 * time.Minute

	// signatureScalarLen is the width in bytes of each of R and S inside the
	// 64-byte R || S signature encoding.
	signatureScalarLen = 32
)

// SigningResult contains the result of a signing operation.
type SigningResult struct {
	// Signature is the full ECDSA signature (R || S), each component
	// left-padded with zeros to 32 bytes.
	Signature []byte
	// R is the R component of the signature.
	R *big.Int
	// S is the S component of the signature.
	S *big.Int
	// RecoveryID is the recovery ID for ecrecover.
	RecoveryID int
}

// SigningParty represents a party participating in signing.
type SigningParty struct {
	PartyID    string
	PartyIndex int
}

// SigningConfig contains configuration for signing.
type SigningConfig struct {
	Threshold    int           // t in t-of-n threshold value from keygen
	TotalParties int           // n in t-of-n - total parties from keygen (NOT current signers)
	TotalSigners int           // Number of parties participating in this signing session
	Timeout      time.Duration // Signing timeout; zero selects a 5 minute default
}

// SigningSession manages a signing session for a single party.
type SigningSession struct {
	config      SigningConfig
	selfParty   SigningParty
	allParties  []SigningParty
	messageHash *big.Int
	saveData    *keygen.LocalPartySaveData

	tssPartyIDs []*tss.PartyID
	selfTSSID   *tss.PartyID
	params      *tss.Parameters
	localParty  tss.Party

	outCh chan tss.Message
	endCh chan *common.SignatureData
	errCh chan error

	msgHandler MessageHandler

	mu      sync.Mutex
	started bool

	// done is created by Start and closed when Start returns, so the message
	// handling goroutines never outlive the session even if the caller's
	// context stays alive.
	done chan struct{}

	// keygenIndexToSortedIndex maps keygen party index to sorted array index.
	// This is needed because TSS messages use keygen index, but tssPartyIDs is sorted.
	keygenIndexToSortedIndex map[int]int
}

// NewSigningSession creates a new signing session for selfParty.
//
// allParties lists the parties taking part in this signing session (not all
// keygen parties) and must contain selfParty. saveDataBytes is the JSON
// encoding of this party's keygen.LocalPartySaveData and messageHash is the
// big-endian hash to sign.
//
// It returns ErrInvalidSignerCount when the threshold or signer counts are
// inconsistent and an error wrapping ErrInvalidShareData when the share cannot
// be decoded.
func NewSigningSession(
	config SigningConfig,
	selfParty SigningParty,
	allParties []SigningParty,
	messageHash []byte,
	saveDataBytes []byte,
	msgHandler MessageHandler,
) (*SigningSession, error) {
	// A threshold below 1 would turn into a negative tss-lib threshold below.
	if config.Threshold < 1 || config.TotalSigners < config.Threshold {
		return nil, ErrInvalidSignerCount
	}
	if len(allParties) != config.TotalSigners {
		return nil, ErrInvalidSignerCount
	}

	// Deserialize save data
	var saveData keygen.LocalPartySaveData
	if err := json.Unmarshal(saveDataBytes, &saveData); err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidShareData, err)
	}

	// Create TSS party IDs for signers. The key is PartyIndex+1, which is what
	// the sort below orders by.
	tssPartyIDs := make([]*tss.PartyID, len(allParties))
	keygenIndexByID := make(map[string]int, len(allParties))
	var selfTSSID *tss.PartyID
	for i, p := range allParties {
		partyID := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyID
		if p.PartyID == selfParty.PartyID {
			selfTSSID = partyID
		}
		// First occurrence wins, matching a linear search over allParties.
		if _, seen := keygenIndexByID[p.PartyID]; !seen {
			keygenIndexByID[p.PartyID] = p.PartyIndex
		}
	}
	if selfTSSID == nil {
		return nil, errors.New("self party not found in all parties")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Build mapping from keygen index to sorted array index.
	keygenIndexToSortedIndex := make(map[int]int, len(sortedPartyIDs))
	for sortedIdx, partyID := range sortedPartyIDs {
		if keygenIdx, ok := keygenIndexByID[partyID.Id]; ok {
			keygenIndexToSortedIndex[keygenIdx] = sortedIdx
		}
	}
	fmt.Printf("[TSS-SIGN] Built keygen index to sorted index mapping: %v party_id=%s\n", keygenIndexToSortedIndex, selfParty.PartyID)

	// Create peer context and parameters.
	// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required.
	// This MUST match keygen exactly! Both use (Threshold-1).
	// The BuildLocalSaveDataSubset call in Start() will filter the save data to match.
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	tssThreshold := config.Threshold - 1
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
	fmt.Printf("[TSS-SIGN] NewParameters: partyCount=%d, tssThreshold=%d (from config.Threshold=%d) party_id=%s\n", len(sortedPartyIDs), tssThreshold, config.Threshold, selfParty.PartyID)

	return &SigningSession{
		config:                   config,
		selfParty:                selfParty,
		allParties:               allParties,
		messageHash:              new(big.Int).SetBytes(messageHash),
		saveData:                 &saveData,
		tssPartyIDs:              sortedPartyIDs,
		selfTSSID:                selfTSSID,
		params:                   params,
		outCh:                    make(chan tss.Message, config.TotalSigners*10),
		endCh:                    make(chan *common.SignatureData, 1),
		errCh:                    make(chan error, 1),
		msgHandler:               msgHandler,
		keygenIndexToSortedIndex: keygenIndexToSortedIndex,
	}, nil
}

// Start runs the signing protocol and blocks until it produces a signature,
// fails, times out, or ctx is cancelled. A session can be started only once.
//
// It returns ctx.Err() on cancellation, ErrSigningTimeout when the configured
// timeout (5 minutes if unset) elapses, an error wrapping ErrSigningFailed
// when the protocol reports an error, and an error wrapping
// ErrInvalidShareData when the share does not cover the signing parties.
// The goroutines started here are told to stop when Start returns.
func (s *SigningSession) Start(ctx context.Context) (*SigningResult, error) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return nil, errors.New("session already started")
	}
	s.started = true
	s.done = make(chan struct{})
	s.mu.Unlock()
	defer close(s.done)

	// CRITICAL: Build a subset of the save data for the current signing parties.
	// When signing with fewer parties than keygen (e.g., 2-of-3 signing with only 2 parties),
	// we must filter the save data to only include the participating parties' data.
	// This ensures TSS-lib's internal indices match the actual signers.
	subset, err := buildSaveDataSubset(*s.saveData, s.tssPartyIDs)
	if err != nil {
		return nil, err
	}
	fmt.Printf("[TSS-SIGN] Built save data subset for %d signing parties (original keygen had %d parties) party_id=%s\n", len(s.tssPartyIDs), len(s.saveData.Ks), s.selfParty.PartyID)

	// Create local party for signing with the SUBSET save data
	s.localParty = signing.NewLocalParty(s.messageHash, s.params, subset, s.outCh, s.endCh)

	// Start the local party
	go func() {
		if err := s.localParty.Start(); err != nil {
			offerSignError(s.errCh, err)
		}
	}()

	go s.handleOutgoingMessages(ctx)
	go s.handleIncomingMessages(ctx)

	// Wait for completion or timeout
	timeout := s.config.Timeout
	if timeout == 0 {
		timeout = signingDefaultTimeout
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return nil, ErrSigningTimeout
	case tssErr := <-s.errCh:
		return nil, fmt.Errorf("%w: %v", ErrSigningFailed, tssErr)
	case signData := <-s.endCh:
		return s.buildResult(signData)
	}
}

// handleOutgoingMessages forwards messages produced by the local party to the
// message handler until ctx is cancelled, a nil message is received, or the
// session ends. When the session ends, messages that are already queued are
// still sent so that peers are not left waiting for this party's last round.
func (s *SigningSession) handleOutgoingMessages(ctx context.Context) {
	fmt.Printf("[TSS-SIGN] handleOutgoingMessages started party_id=%s\n", s.selfParty.PartyID)
	done := s.done
	for {
		select {
		case <-ctx.Done():
			fmt.Printf("[TSS-SIGN] handleOutgoingMessages context cancelled party_id=%s\n", s.selfParty.PartyID)
			return
		case <-done:
			s.flushOutgoing(ctx)
			return
		case msg := <-s.outCh:
			if msg == nil {
				fmt.Printf("[TSS-SIGN] handleOutgoingMessages received nil message, stopping party_id=%s\n", s.selfParty.PartyID)
				return
			}
			s.sendOutgoing(ctx, msg)
		}
	}
}

// flushOutgoing sends every message currently queued on outCh without waiting
// for new ones.
func (s *SigningSession) flushOutgoing(ctx context.Context) {
	for {
		select {
		case msg := <-s.outCh:
			if msg == nil {
				return
			}
			s.sendOutgoing(ctx, msg)
		default:
			return
		}
	}
}

// sendOutgoing serializes one protocol message and hands it to the message
// handler. Failures are logged and otherwise ignored (best effort).
func (s *SigningSession) sendOutgoing(ctx context.Context, msg tss.Message) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		fmt.Printf("[TSS-SIGN] Failed to get wire bytes party_id=%s error=%v\n", s.selfParty.PartyID, err)
		return
	}

	// Broadcasts carry no recipient list; point-to-point messages list the
	// recipient party IDs.
	var toParties []string
	isBroadcast := msg.IsBroadcast()
	if !isBroadcast {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	fmt.Printf("[TSS-SIGN] sending outgoing message party_id=%s is_broadcast=%v to_parties=%v msg_type=%s\n", s.selfParty.PartyID, isBroadcast, toParties, msg.Type())
	if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
		fmt.Printf("[TSS-SIGN] Failed to send message party_id=%s error=%v\n", s.selfParty.PartyID, err)
	}
}

// handleIncomingMessages feeds messages from the message handler into the
// local party until ctx is cancelled, the handler's channel is closed, or the
// session ends. Messages that cannot be attributed or parsed are logged and
// skipped; a failed party update is reported to Start.
func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
	fmt.Printf("[TSS-SIGN] handleIncomingMessages started party_id=%s\n", s.selfParty.PartyID)
	done := s.done
	msgCh := s.msgHandler.ReceiveMessages()
	for {
		select {
		case <-ctx.Done():
			fmt.Printf("[TSS-SIGN] handleIncomingMessages context cancelled party_id=%s\n", s.selfParty.PartyID)
			return
		case <-done:
			return
		case msg, ok := <-msgCh:
			if !ok {
				fmt.Printf("[TSS-SIGN] handleIncomingMessages channel closed party_id=%s\n", s.selfParty.PartyID)
				return
			}

			fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_keygen_index=%d is_broadcast=%v msg_len=%d\n", s.selfParty.PartyID, msg.FromPartyIndex, msg.IsBroadcast, len(msg.MsgBytes))

			// Map keygen index to sorted array index.
			// msg.FromPartyIndex is the original keygen party index (e.g., 0, 1, 2).
			// We need the sorted array index (e.g., 0, 1 for a 2-party signing session).
			sortedIndex, exists := s.keygenIndexToSortedIndex[msg.FromPartyIndex]
			if !exists {
				fmt.Printf("[TSS-SIGN] ERROR: unknown keygen index=%d, mapping=%v party_id=%s\n", msg.FromPartyIndex, s.keygenIndexToSortedIndex, s.selfParty.PartyID)
				continue
			}
			fmt.Printf("[TSS-SIGN] mapped keygen_index=%d to sorted_index=%d party_id=%s\n", msg.FromPartyIndex, sortedIndex, s.selfParty.PartyID)

			// Check if sorted index is valid
			if sortedIndex < 0 || sortedIndex >= len(s.tssPartyIDs) {
				fmt.Printf("[TSS-SIGN] ERROR: invalid sortedIndex=%d, len(tssPartyIDs)=%d party_id=%s\n", sortedIndex, len(s.tssPartyIDs), s.selfParty.PartyID)
				continue
			}

			// Parse the message using the sorted index
			parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[sortedIndex], msg.IsBroadcast)
			if err != nil {
				fmt.Printf("[TSS-SIGN] ERROR: failed to parse wire message party_id=%s from_index=%d error=%v\n", s.selfParty.PartyID, msg.FromPartyIndex, err)
				continue
			}
			fmt.Printf("[TSS-SIGN] parsed message successfully party_id=%s msg_type=%s\n", s.selfParty.PartyID, parsedMsg.Type())

			// Update the party in its own goroutine so this loop keeps
			// receiving while the update runs.
			go func() {
				ok, err := s.localParty.Update(parsedMsg)
				if err != nil {
					fmt.Printf("[TSS-SIGN] ERROR: party update failed party_id=%s error=%v\n", s.selfParty.PartyID, err)
					offerSignError(s.errCh, err)
					return
				}
				fmt.Printf("[TSS-SIGN] party update succeeded party_id=%s ok=%v\n", s.selfParty.PartyID, ok)
			}()
		}
	}
}

// buildResult converts the signature data produced by the local party into a
// SigningResult. It returns an error wrapping ErrSigningFailed when the data
// cannot be encoded as a 64-byte R || S signature with a recovery ID.
func (s *SigningSession) buildResult(signData *common.SignatureData) (*SigningResult, error) {
	signature, r, sVal, recoveryID, err := decodeSignData(signData)
	if err != nil {
		return nil, err
	}
	return &SigningResult{
		Signature:  signature,
		R:          r,
		S:          sVal,
		RecoveryID: recoveryID,
	}, nil
}

// decodeSignData validates signData and returns the zero-padded R || S
// signature, R and S as integers, and the recovery ID (the first byte of
// SignatureRecovery).
func decodeSignData(signData *common.SignatureData) (signature []byte, r, s *big.Int, recoveryID int, err error) {
	if signData == nil {
		return nil, nil, nil, 0, fmt.Errorf("%w: no signature data produced", ErrSigningFailed)
	}
	if len(signData.R) > signatureScalarLen || len(signData.S) > signatureScalarLen {
		return nil, nil, nil, 0, fmt.Errorf("%w: signature component longer than %d bytes (len(R)=%d, len(S)=%d)",
			ErrSigningFailed, signatureScalarLen, len(signData.R), len(signData.S))
	}
	if len(signData.SignatureRecovery) == 0 {
		return nil, nil, nil, 0, fmt.Errorf("%w: missing signature recovery byte", ErrSigningFailed)
	}

	// Right-align each component inside its 32-byte half.
	signature = make([]byte, 2*signatureScalarLen)
	copy(signature[signatureScalarLen-len(signData.R):signatureScalarLen], signData.R)
	copy(signature[2*signatureScalarLen-len(signData.S):], signData.S)

	r = new(big.Int).SetBytes(signData.R)
	s = new(big.Int).SetBytes(signData.S)
	return signature, r, s, int(signData.SignatureRecovery[0]), nil
}

// buildSaveDataSubset reduces data to the entries belonging to signers.
// keygen.BuildLocalSaveDataSubset panics when a signer's key is not present in
// the share; that panic is converted into an error wrapping
// ErrInvalidShareData.
func buildSaveDataSubset(data keygen.LocalPartySaveData, signers tss.SortedPartyIDs) (subset keygen.LocalPartySaveData, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("%w: %v", ErrInvalidShareData, r)
		}
	}()
	return keygen.BuildLocalSaveDataSubset(data, signers), nil
}

// offerSignError delivers err on errCh without blocking. Only the first error
// matters to the waiting caller, so later errors are dropped rather than
// leaving the reporting goroutine stuck on a full channel.
func offerSignError(errCh chan<- error, err error) {
	select {
	case errCh <- err:
	default:
	}
}

// LocalSigningResult contains local signing result for standalone testing.
type LocalSigningResult struct {
	Signature  []byte
	R          *big.Int
	S          *big.Int
	RecoveryID int
}

// RunLocalSigning runs signing locally with all parties in the same process
// (for testing). keygenResults holds one entry per signer, which may be a
// subset of the keygen parties.
//
// It returns ErrInvalidSignerCount when there are no signers or fewer than
// threshold, an error wrapping ErrInvalidShareData when a keygen result is
// missing or does not cover the signers, ErrSigningTimeout after 5 minutes
// overall, and the first protocol error otherwise.
//
// NOTE(review): threshold is passed to tss-lib unchanged, whereas
// NewSigningSession passes Threshold-1 — confirm this matches the threshold
// used by the local keygen that produced keygenResults.
func RunLocalSigning(
	threshold int,
	keygenResults []*LocalKeygenResult,
	messageHash []byte,
) (*LocalSigningResult, error) {
	signerCount := len(keygenResults)
	if signerCount == 0 || signerCount < threshold {
		return nil, ErrInvalidSignerCount
	}

	// Create party IDs for signers using their ORIGINAL party indices from keygen.
	// This is critical for subset signing - party IDs must match the original keygen party IDs.
	partyIDs := make([]*tss.PartyID, signerCount)
	resultsByID := make(map[string]*LocalKeygenResult, signerCount)
	for i, result := range keygenResults {
		if result == nil || result.SaveData == nil {
			return nil, fmt.Errorf("%w: keygen result %d has no save data", ErrInvalidShareData, i)
		}
		id := fmt.Sprintf("party-%d", result.PartyIndex)
		partyIDs[i] = tss.NewPartyID(id, id, big.NewInt(int64(result.PartyIndex+1)))
		// First occurrence wins, matching a linear search over keygenResults.
		if _, seen := resultsByID[id]; !seen {
			resultsByID[id] = result
		}
	}

	sortedPartyIDs := tss.SortPartyIDs(partyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)

	// Convert message hash to big.Int
	msgHash := new(big.Int).SetBytes(messageHash)

	// Create one party, with its own channels, per sorted party ID.
	outChs := make([]chan tss.Message, signerCount)
	endChs := make([]chan *common.SignatureData, signerCount)
	parties := make([]tss.Party, signerCount)
	for i, pid := range sortedPartyIDs {
		// Reduce each share to the signers, as SigningSession.Start does, so
		// signing with a subset of the keygen parties uses matching indices.
		saveData, err := buildSaveDataSubset(*resultsByID[pid.Id].SaveData, sortedPartyIDs)
		if err != nil {
			return nil, err
		}
		outChs[i] = make(chan tss.Message, signerCount*10)
		endChs[i] = make(chan *common.SignatureData, 1)
		params := tss.NewParameters(tss.S256(), peerCtx, pid, signerCount, threshold)
		parties[i] = signing.NewLocalParty(msgHash, params, saveData, outChs[i], endChs[i])
	}

	// Start all parties
	errCh := make(chan error, signerCount)
	for i := 0; i < signerCount; i++ {
		go func(idx int) {
			if err := parties[idx].Start(); err != nil {
				offerSignError(errCh, err)
			}
		}(i)
	}

	// Route messages between parties until this function returns.
	doneCh := make(chan struct{})
	defer close(doneCh)
	for i := 0; i < signerCount; i++ {
		go func(idx int) {
			for {
				select {
				case <-doneCh:
					return
				case msg := <-outChs[idx]:
					if msg == nil {
						return
					}
					if msg.IsBroadcast() {
						// Deliver to every party except the sender.
						for j := 0; j < signerCount; j++ {
							if j != idx {
								go updateSignParty(parties[j], msg, errCh)
							}
						}
						continue
					}
					for _, d := range msg.GetTo() {
						for j := 0; j < signerCount; j++ {
							if sortedPartyIDs[j].Id == d.Id {
								go updateSignParty(parties[j], msg, errCh)
								break
							}
						}
					}
				}
			}
		}(i)
	}

	// Wait for every party to finish; keep the first result (all parties
	// should produce the same signature).
	deadline := time.NewTimer(signingDefaultTimeout)
	defer deadline.Stop()

	var result *LocalSigningResult
	for i := range endChs {
		select {
		case signData := <-endChs[i]:
			if result != nil {
				continue
			}
			signature, r, sVal, recoveryID, err := decodeSignData(signData)
			if err != nil {
				return nil, err
			}
			result = &LocalSigningResult{
				Signature:  signature,
				R:          r,
				S:          sVal,
				RecoveryID: recoveryID,
			}
		case err := <-errCh:
			return nil, err
		case <-deadline.C:
			return nil, ErrSigningTimeout
		}
	}

	return result, nil
}

// updateSignParty delivers msg to party by serializing it and parsing it back,
// the same path a message takes over a network. Serialization, parsing and
// update errors are offered to errCh without blocking; update errors that look
// like duplicate deliveries are ignored.
func updateSignParty(party tss.Party, msg tss.Message, errCh chan error) {
	wire, routing, err := msg.WireBytes()
	if err != nil {
		offerSignError(errCh, err)
		return
	}

	parsedMsg, err := tss.ParseWireMessage(wire, msg.GetFrom(), routing.IsBroadcast)
	if err != nil {
		offerSignError(errCh, err)
		return
	}

	if _, updateErr := party.Update(parsedMsg); updateErr != nil && !isSignDuplicateMessageError(updateErr) {
		offerSignError(errCh, updateErr)
	}
}

// isSignDuplicateMessageError reports whether err's message contains
// "duplicate" or "already received".
func isSignDuplicateMessageError(err error) bool {
	if err == nil {
		return false
	}
	errStr := err.Error()
	return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
}