package tss import ( "context" "crypto/ecdsa" "encoding/json" "errors" "fmt" "math/big" "strings" "sync" "time" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/tss" ) var ( ErrKeygenTimeout = errors.New("keygen timeout") ErrKeygenFailed = errors.New("keygen failed") ErrInvalidPartyCount = errors.New("invalid party count") ErrInvalidThreshold = errors.New("invalid threshold") ) // KeygenResult contains the result of a keygen operation type KeygenResult struct { // LocalPartySaveData is the serialized save data for this party LocalPartySaveData []byte // PublicKey is the group ECDSA public key PublicKey *ecdsa.PublicKey // PublicKeyBytes is the compressed public key bytes PublicKeyBytes []byte } // KeygenParty represents a party participating in keygen type KeygenParty struct { PartyID string PartyIndex int } // KeygenConfig contains configuration for keygen type KeygenConfig struct { Threshold int // t in t-of-n TotalParties int // n in t-of-n Timeout time.Duration // Keygen timeout } // KeygenSession manages a keygen session for a single party type KeygenSession struct { config KeygenConfig selfParty KeygenParty allParties []KeygenParty tssPartyIDs []*tss.PartyID selfTSSID *tss.PartyID params *tss.Parameters localParty tss.Party outCh chan tss.Message endCh chan *keygen.LocalPartySaveData errCh chan error msgHandler MessageHandler mu sync.Mutex started bool } // MessageHandler handles outgoing and incoming TSS messages type MessageHandler interface { // SendMessage sends a message to other parties SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error // ReceiveMessages returns a channel for receiving messages ReceiveMessages() <-chan *ReceivedMessage } // ReceivedMessage represents a received TSS message type ReceivedMessage struct { FromPartyIndex int IsBroadcast bool MsgBytes []byte } // NewKeygenSession creates a new keygen session func NewKeygenSession( config KeygenConfig, selfParty KeygenParty, allParties []KeygenParty, msgHandler MessageHandler, ) (*KeygenSession, error) { if config.TotalParties < 2 { return nil, ErrInvalidPartyCount } if config.Threshold < 1 || config.Threshold > config.TotalParties { return nil, ErrInvalidThreshold } if len(allParties) != config.TotalParties { return nil, ErrInvalidPartyCount } // Create TSS party IDs tssPartyIDs := make([]*tss.PartyID, len(allParties)) var selfTSSID *tss.PartyID for i, p := range allParties { partyID := tss.NewPartyID( p.PartyID, fmt.Sprintf("party-%d", p.PartyIndex), big.NewInt(int64(p.PartyIndex+1)), ) tssPartyIDs[i] = partyID if p.PartyID == selfParty.PartyID { selfTSSID = partyID } } if selfTSSID == nil { return nil, errors.New("self party not found in all parties") } // Sort party IDs sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs) // Create peer context and parameters peerCtx := tss.NewPeerContext(sortedPartyIDs) params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold) return &KeygenSession{ config: config, selfParty: selfParty, allParties: allParties, tssPartyIDs: sortedPartyIDs, selfTSSID: selfTSSID, params: params, outCh: make(chan tss.Message, config.TotalParties*10), endCh: make(chan *keygen.LocalPartySaveData, 1), errCh: make(chan error, 1), msgHandler: msgHandler, }, nil } // Start begins the keygen protocol func (s *KeygenSession) Start(ctx context.Context) (*KeygenResult, error) { s.mu.Lock() if s.started { s.mu.Unlock() return nil, errors.New("session already started") } s.started = true s.mu.Unlock() // Create local party s.localParty = keygen.NewLocalParty(s.params, s.outCh, s.endCh) // Start the local party go func() { if err := s.localParty.Start(); err != nil { s.errCh <- err } }() // Handle outgoing messages go s.handleOutgoingMessages(ctx) // Handle incoming messages go s.handleIncomingMessages(ctx) // Wait for completion or timeout timeout := s.config.Timeout if timeout == 0 { timeout = 10 * time.Minute } select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(timeout): return nil, ErrKeygenTimeout case tssErr := <-s.errCh: return nil, fmt.Errorf("%w: %v", ErrKeygenFailed, tssErr) case saveData := <-s.endCh: return s.buildResult(saveData) } } func (s *KeygenSession) handleOutgoingMessages(ctx context.Context) { for { select { case <-ctx.Done(): return case msg := <-s.outCh: if msg == nil { return } msgBytes, _, err := msg.WireBytes() if err != nil { continue } var toParties []string isBroadcast := msg.IsBroadcast() if !isBroadcast { for _, to := range msg.GetTo() { toParties = append(toParties, to.Id) } } if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil { // Log error but continue continue } } } } func (s *KeygenSession) handleIncomingMessages(ctx context.Context) { msgCh := s.msgHandler.ReceiveMessages() for { select { case <-ctx.Done(): return case msg, ok := <-msgCh: if !ok { return } // Parse the message parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast) if err != nil { continue } // Update the party go func() { ok, err := s.localParty.Update(parsedMsg) if err != nil { s.errCh <- err } _ = ok }() } } } func (s *KeygenSession) buildResult(saveData *keygen.LocalPartySaveData) (*KeygenResult, error) { // Serialize save data saveDataBytes, err := json.Marshal(saveData) if err != nil { return nil, fmt.Errorf("failed to serialize save data: %w", err) } // Get public key pubKey := saveData.ECDSAPub.ToECDSAPubKey() // Compress public key pubKeyBytes := make([]byte, 33) pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0)) xBytes := pubKey.X.Bytes() copy(pubKeyBytes[33-len(xBytes):], xBytes) return &KeygenResult{ LocalPartySaveData: saveDataBytes, PublicKey: pubKey, PublicKeyBytes: pubKeyBytes, }, nil } // LocalKeygenResult contains local keygen result for standalone testing type LocalKeygenResult struct { SaveData *keygen.LocalPartySaveData PublicKey *ecdsa.PublicKey PartyIndex int } // RunLocalKeygen runs keygen locally with all parties in the same process (for testing) func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) { if totalParties < 2 { return nil, ErrInvalidPartyCount } if threshold < 1 || threshold > totalParties { return nil, ErrInvalidThreshold } // Create party IDs partyIDs := make([]*tss.PartyID, totalParties) for i := 0; i < totalParties; i++ { partyIDs[i] = tss.NewPartyID( fmt.Sprintf("party-%d", i), fmt.Sprintf("party-%d", i), big.NewInt(int64(i+1)), ) } sortedPartyIDs := tss.SortPartyIDs(partyIDs) peerCtx := tss.NewPeerContext(sortedPartyIDs) // Create channels for each party outChs := make([]chan tss.Message, totalParties) endChs := make([]chan *keygen.LocalPartySaveData, totalParties) parties := make([]tss.Party, totalParties) for i := 0; i < totalParties; i++ { outChs[i] = make(chan tss.Message, totalParties*10) endChs[i] = make(chan *keygen.LocalPartySaveData, 1) params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], totalParties, threshold) parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i]) } // Start all parties var wg sync.WaitGroup errCh := make(chan error, totalParties) for i := 0; i < totalParties; i++ { wg.Add(1) go func(idx int) { defer wg.Done() if err := parties[idx].Start(); err != nil { errCh <- err } }(i) } // Route messages between parties var routeWg sync.WaitGroup doneCh := make(chan struct{}) for i := 0; i < totalParties; i++ { routeWg.Add(1) go func(idx int) { defer routeWg.Done() for { select { case <-doneCh: return case msg := <-outChs[idx]: if msg == nil { return } dest := msg.GetTo() if msg.IsBroadcast() { for j := 0; j < totalParties; j++ { if j != idx { go updateParty(parties[j], msg, errCh) } } } else { for _, d := range dest { for j := 0; j < totalParties; j++ { if sortedPartyIDs[j].Id == d.Id { go updateParty(parties[j], msg, errCh) break } } } } } } }(i) } // Collect results results := make([]*LocalKeygenResult, totalParties) for i := 0; i < totalParties; i++ { select { case saveData := <-endChs[i]: results[i] = &LocalKeygenResult{ SaveData: saveData, PublicKey: saveData.ECDSAPub.ToECDSAPubKey(), PartyIndex: i, } case err := <-errCh: close(doneCh) return nil, err case <-time.After(5 * time.Minute): close(doneCh) return nil, ErrKeygenTimeout } } close(doneCh) return results, nil } func updateParty(party tss.Party, msg tss.Message, errCh chan error) { bytes, routing, err := msg.WireBytes() if err != nil { errCh <- err return } parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast) if err != nil { errCh <- err return } if _, err := party.Update(parsedMsg); err != nil { // Only send error if it's not a duplicate message error // Check if error message contains "duplicate message" indication if err.Error() != "" && !isDuplicateMessageError(err) { errCh <- err } } } // isDuplicateMessageError checks if an error is a duplicate message error func isDuplicateMessageError(err error) bool { if err == nil { return false } errStr := err.Error() return strings.Contains(errStr, "duplicate") || strings.Contains(errStr, "already received") }