package tss

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"

	"github.com/bnb-chain/tss-lib/v2/common"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
	"github.com/bnb-chain/tss-lib/v2/tss"
)

var (
	// ErrSigningTimeout is returned when signing does not complete before the timeout.
	ErrSigningTimeout = errors.New("signing timeout")
	// ErrSigningFailed wraps protocol/transport errors reported during a session
	// and malformed signature output.
	ErrSigningFailed = errors.New("signing failed")
	// ErrInvalidSignerCount is returned when the signer set is inconsistent
	// (too few signers, count mismatch, duplicate parties).
	ErrInvalidSignerCount = errors.New("invalid signer count")
	// ErrInvalidShareData is returned when key share data is missing or cannot be decoded.
	ErrInvalidShareData = errors.New("invalid share data")
)

// SigningResult contains the result of a signing operation.
type SigningResult struct {
	// Signature is the 64-byte encoding R || S, each half left-padded to 32 bytes.
	Signature []byte
	// R is the R component of the signature.
	R *big.Int
	// S is the S component of the signature.
	S *big.Int
	// RecoveryID is the first byte of tss-lib's SignatureRecovery, for ecrecover.
	RecoveryID int
}

// SigningParty identifies a party participating in signing.
type SigningParty struct {
	// PartyID is the unique identifier used as the tss-lib PartyID.Id.
	PartyID string
	// PartyIndex determines the tss-lib party key (PartyIndex+1); presumably it
	// must equal the index used during keygen.
	PartyIndex int
}

// SigningConfig contains configuration for signing.
type SigningConfig struct {
	// Threshold is passed to tss-lib unchanged.
	// NOTE(review): tss-lib interprets this as t where t+1 signers are needed;
	// confirm this agrees with the convention used at keygen.
	Threshold int
	// TotalSigners is the number of parties participating in this signing.
	TotalSigners int
	// Timeout bounds Start; a non-positive value means 5 minutes.
	Timeout time.Duration
}

// SigningSession manages a signing session for a single party.
type SigningSession struct {
	config      SigningConfig
	selfParty   SigningParty
	allParties  []SigningParty
	messageHash *big.Int
	saveData    *keygen.LocalPartySaveData
	tssPartyIDs []*tss.PartyID // sorted by key (see tss.SortPartyIDs)
	selfTSSID   *tss.PartyID
	params      *tss.Parameters
	localParty  tss.Party

	outCh chan tss.Message            // messages produced by localParty
	endCh chan *common.SignatureData  // final signature from localParty
	errCh chan error                  // first protocol/transport error wins

	msgHandler MessageHandler

	mu      sync.Mutex
	started bool
}

// NewSigningSession creates a signing session for selfParty.
//
// Requirements:
//   - allParties lists exactly config.TotalSigners parties, each with a distinct
//     PartyID and PartyIndex.
//   - allParties contains selfParty, matched by PartyID.
//
// Inputs:
//   - messageHash is interpreted as a big-endian unsigned integer.
//   - saveDataBytes is the JSON encoding of this party's keygen.LocalPartySaveData.
//
// The protocol itself does not run until Start is called.
func NewSigningSession(
	config SigningConfig,
	selfParty SigningParty,
	allParties []SigningParty,
	messageHash []byte,
	saveDataBytes []byte,
	msgHandler MessageHandler,
) (*SigningSession, error) {
	if config.TotalSigners < config.Threshold {
		return nil, ErrInvalidSignerCount
	}
	if len(allParties) != config.TotalSigners {
		return nil, ErrInvalidSignerCount
	}

	var saveData keygen.LocalPartySaveData
	if err := json.Unmarshal(saveDataBytes, &saveData); err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidShareData, err)
	}

	// Build tss-lib party IDs. Keys (PartyIndex+1) drive the sort order, and IDs
	// route messages, so duplicates are rejected here rather than surfacing as
	// a confused protocol run later.
	tssPartyIDs := make([]*tss.PartyID, len(allParties))
	seenIDs := make(map[string]bool, len(allParties))
	seenIndices := make(map[int]bool, len(allParties))
	var selfTSSID *tss.PartyID
	for i, p := range allParties {
		if seenIDs[p.PartyID] || seenIndices[p.PartyIndex] {
			return nil, fmt.Errorf("%w: duplicate party %q (index %d)", ErrInvalidSignerCount, p.PartyID, p.PartyIndex)
		}
		seenIDs[p.PartyID] = true
		seenIndices[p.PartyIndex] = true

		partyID := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyID
		if p.PartyID == selfParty.PartyID {
			selfTSSID = partyID
		}
	}
	if selfTSSID == nil {
		return nil, errors.New("self party not found in all parties")
	}

	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)

	return &SigningSession{
		config:      config,
		selfParty:   selfParty,
		allParties:  allParties,
		messageHash: new(big.Int).SetBytes(messageHash),
		saveData:    &saveData,
		tssPartyIDs: sortedPartyIDs,
		selfTSSID:   selfTSSID,
		params:      params,
		outCh:       make(chan tss.Message, config.TotalSigners*10),
		endCh:       make(chan *common.SignatureData, 1),
		errCh:       make(chan error, 1),
		msgHandler:  msgHandler,
	}, nil
}

// Start runs the signing protocol and blocks until it finishes.
//
// Start returns as soon as one of the following happens:
//   - a signature is produced;
//   - the protocol or transport reports an error (wrapped in ErrSigningFailed);
//   - the configured timeout expires (ErrSigningTimeout);
//   - ctx is done (ctx.Err()).
//
// A session can be started only once.
func (s *SigningSession) Start(ctx context.Context) (*SigningResult, error) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return nil, errors.New("session already started")
	}
	s.started = true
	s.mu.Unlock()

	// The message pumps run on this derived context so they stop when Start
	// returns, whatever the outcome; otherwise they would outlive the session.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	s.localParty = signing.NewLocalParty(s.messageHash, s.params, *s.saveData, s.outCh, s.endCh)

	go s.handleOutgoingMessages(ctx)
	go s.handleIncomingMessages(ctx)
	go func() {
		if err := s.localParty.Start(); err != nil {
			s.reportError(err)
		}
	}()

	timeout := s.config.Timeout
	if timeout <= 0 {
		timeout = 5 * time.Minute
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return nil, ErrSigningTimeout
	case tssErr := <-s.errCh:
		return nil, fmt.Errorf("%w: %v", ErrSigningFailed, tssErr)
	case signData := <-s.endCh:
		return s.buildResult(signData)
	}
}

// reportError records err for Start without blocking. Only the first error is
// kept (errCh has capacity 1); later errors are dropped so reporting
// goroutines never block.
func (s *SigningSession) reportError(err error) {
	trySendSignErr(s.errCh, err)
}

// handleOutgoingMessages forwards messages produced by the local party to the
// message handler until ctx is done. Encoding or send failures are reported,
// because tss-lib has no retransmission and the protocol would otherwise stall
// until the timeout.
func (s *SigningSession) handleOutgoingMessages(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case msg := <-s.outCh:
			if msg == nil {
				return
			}
			wireBytes, _, err := msg.WireBytes()
			if err != nil {
				s.reportError(fmt.Errorf("encoding outgoing message: %w", err))
				continue
			}

			// Broadcasts are sent with a nil recipient list.
			var toParties []string
			isBroadcast := msg.IsBroadcast()
			if !isBroadcast {
				for _, to := range msg.GetTo() {
					toParties = append(toParties, to.Id)
				}
			}

			if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, wireBytes); err != nil {
				s.reportError(fmt.Errorf("sending message: %w", err))
			}
		}
	}
}

// handleIncomingMessages feeds messages from the handler into the local party
// until ctx is done or the handler's channel is closed. Malformed messages are
// dropped, as are messages with an out-of-range sender index. Duplicate-message
// errors from the party are ignored, matching the in-process router.
func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
	msgCh := s.msgHandler.ReceiveMessages()
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-msgCh:
			if !ok {
				return
			}

			// NOTE(review): FromPartyIndex is treated as a position in the sorted
			// signer list. That coincides with PartyIndex only when the signers
			// are exactly indices 0..n-1; confirm the sender-side convention for
			// subset signing.
			from := int(msg.FromPartyIndex)
			if from < 0 || from >= len(s.tssPartyIDs) {
				continue
			}

			parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[from], msg.IsBroadcast)
			if err != nil {
				continue
			}

			go func(pm tss.ParsedMessage) {
				if _, updateErr := s.localParty.Update(pm); updateErr != nil && !isSignDuplicateMessageError(updateErr) {
					s.reportError(updateErr)
				}
			}(parsedMsg)
		}
	}
}

// buildResult converts tss-lib output into a SigningResult. It fails with
// ErrSigningFailed if the data is missing or malformed.
func (s *SigningSession) buildResult(signData *common.SignatureData) (*SigningResult, error) {
	return signingResultFromData(signData)
}

// signingResultFromData validates tss-lib signature output and encodes it as
// R || S (each left-padded to 32 bytes) plus the recovery ID. Malformed output
// returns an error instead of panicking in copy/indexing.
func signingResultFromData(signData *common.SignatureData) (*SigningResult, error) {
	if signData == nil {
		return nil, fmt.Errorf("%w: missing signature data", ErrSigningFailed)
	}
	if len(signData.R) > 32 || len(signData.S) > 32 {
		return nil, fmt.Errorf("%w: signature component longer than 32 bytes", ErrSigningFailed)
	}
	if len(signData.SignatureRecovery) == 0 {
		return nil, fmt.Errorf("%w: missing signature recovery id", ErrSigningFailed)
	}

	signature := make([]byte, 64)
	copy(signature[32-len(signData.R):32], signData.R)
	copy(signature[64-len(signData.S):64], signData.S)

	return &SigningResult{
		Signature:  signature,
		R:          new(big.Int).SetBytes(signData.R),
		S:          new(big.Int).SetBytes(signData.S),
		RecoveryID: int(signData.SignatureRecovery[0]),
	}, nil
}

// LocalSigningResult contains the local signing result for standalone testing.
type LocalSigningResult struct {
	Signature  []byte
	R          *big.Int
	S          *big.Int
	RecoveryID int
}

// RunLocalSigning runs signing with all signer parties in the same process,
// routing their messages over in-memory channels (for testing).
//
// Inputs:
//   - keygenResults is the signer subset. Party IDs and keys are rebuilt from
//     each result's PartyIndex, so it must be the party's original keygen index.
//   - threshold is passed to tss-lib unchanged.
//
// The whole run is bounded by one five-minute timeout. The first party's
// signature is returned; all parties are expected to produce the same one.
func RunLocalSigning(
	threshold int,
	keygenResults []*LocalKeygenResult,
	messageHash []byte,
) (*LocalSigningResult, error) {
	signerCount := len(keygenResults)
	// NOTE(review): tss-lib needs threshold+1 signers for threshold t; this check
	// only enforces signerCount >= threshold. Confirm against keygen convention.
	if signerCount == 0 || signerCount < threshold {
		return nil, ErrInvalidSignerCount
	}

	// Party IDs must use the ORIGINAL keygen indices. This is critical for
	// subset signing: the IDs have to match the ones used during keygen.
	partyIDs := make([]*tss.PartyID, signerCount)
	resultsByID := make(map[string]*LocalKeygenResult, signerCount)
	for i, result := range keygenResults {
		if result == nil || result.SaveData == nil {
			return nil, ErrInvalidShareData
		}
		id := fmt.Sprintf("party-%d", result.PartyIndex)
		if _, dup := resultsByID[id]; dup {
			return nil, fmt.Errorf("%w: duplicate party index %d", ErrInvalidSignerCount, result.PartyIndex)
		}
		resultsByID[id] = result
		partyIDs[i] = tss.NewPartyID(id, id, big.NewInt(int64(result.PartyIndex+1)))
	}

	sortedPartyIDs := tss.SortPartyIDs(partyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	msgHash := new(big.Int).SetBytes(messageHash)

	// One party per sorted ID. indexByID maps a party ID to its slot, so
	// point-to-point messages can be routed with a single lookup.
	outChs := make([]chan tss.Message, signerCount)
	endChs := make([]chan *common.SignatureData, signerCount)
	parties := make([]tss.Party, signerCount)
	indexByID := make(map[string]int, signerCount)
	for i, pid := range sortedPartyIDs {
		indexByID[pid.Id] = i
		outChs[i] = make(chan tss.Message, signerCount*10)
		endChs[i] = make(chan *common.SignatureData, 1)
		params := tss.NewParameters(tss.S256(), peerCtx, pid, signerCount, threshold)
		parties[i] = signing.NewLocalParty(msgHash, params, *resultsByID[pid.Id].SaveData, outChs[i], endChs[i])
	}

	// Every send to errCh is non-blocking (trySendSignErr), so late errors
	// cannot strand goroutines after this function returns.
	errCh := make(chan error, signerCount)
	doneCh := make(chan struct{})
	var routeWg sync.WaitGroup
	defer func() {
		close(doneCh)
		routeWg.Wait()
	}()

	// Route messages between parties until this function returns.
	for i := range parties {
		routeWg.Add(1)
		go func(idx int) {
			defer routeWg.Done()
			for {
				select {
				case <-doneCh:
					return
				case msg := <-outChs[idx]:
					if msg == nil {
						return
					}
					if msg.IsBroadcast() {
						for j := range parties {
							if j != idx {
								go updateSignParty(parties[j], msg, errCh)
							}
						}
						continue
					}
					for _, dest := range msg.GetTo() {
						if j, ok := indexByID[dest.Id]; ok {
							go updateSignParty(parties[j], msg, errCh)
						}
					}
				}
			}
		}(i)
	}

	for _, p := range parties {
		go func(party tss.Party) {
			if err := party.Start(); err != nil {
				trySendSignErr(errCh, err)
			}
		}(p)
	}

	// One deadline for the whole run, not one per party.
	timer := time.NewTimer(5 * time.Minute)
	defer timer.Stop()

	var result *LocalSigningResult
	for i := range endChs {
		select {
		case signData := <-endChs[i]:
			if result == nil {
				res, err := signingResultFromData(signData)
				if err != nil {
					return nil, err
				}
				result = &LocalSigningResult{
					Signature:  res.Signature,
					R:          res.R,
					S:          res.S,
					RecoveryID: res.RecoveryID,
				}
			}
		case err := <-errCh:
			return nil, err
		case <-timer.C:
			return nil, ErrSigningTimeout
		}
	}
	return result, nil
}

// updateSignParty re-parses msg from wire bytes (as a real transport would)
// and delivers it to party. Errors are reported on errCh without blocking;
// duplicate-message errors are ignored.
func updateSignParty(party tss.Party, msg tss.Message, errCh chan error) {
	wireBytes, routing, err := msg.WireBytes()
	if err != nil {
		trySendSignErr(errCh, err)
		return
	}
	parsedMsg, err := tss.ParseWireMessage(wireBytes, msg.GetFrom(), routing.IsBroadcast)
	if err != nil {
		trySendSignErr(errCh, err)
		return
	}
	if _, updateErr := party.Update(parsedMsg); updateErr != nil && !isSignDuplicateMessageError(updateErr) {
		trySendSignErr(errCh, updateErr)
	}
}

// trySendSignErr delivers err to errCh without blocking. Once the buffer is
// full, further errors are dropped so that sending goroutines never leak.
func trySendSignErr(errCh chan error, err error) {
	select {
	case errCh <- err:
	default:
	}
}

// isSignDuplicateMessageError reports whether err's message mentions a
// duplicate or already-received message (a substring heuristic).
func isSignDuplicateMessageError(err error) bool {
	if err == nil {
		return false
	}
	errStr := err.Error()
	return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received")
}