409 lines
10 KiB
Go
409 lines
10 KiB
Go
package tss
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
|
)
|
|
|
|
var (
|
|
ErrKeygenTimeout = errors.New("keygen timeout")
|
|
ErrKeygenFailed = errors.New("keygen failed")
|
|
ErrInvalidPartyCount = errors.New("invalid party count")
|
|
ErrInvalidThreshold = errors.New("invalid threshold")
|
|
)
|
|
|
|
// KeygenResult contains the result of a keygen operation for one party, as
// produced by KeygenSession.Start.
type KeygenResult struct {
	// LocalPartySaveData is the JSON-encoded keygen.LocalPartySaveData for
	// this party. NOTE(review): presumably includes this party's private key
	// share — treat as secret material and confirm storage requirements.
	LocalPartySaveData []byte

	// PublicKey is the group ECDSA public key
	PublicKey *ecdsa.PublicKey

	// PublicKeyBytes is the 33-byte compressed encoding of PublicKey:
	// a 0x02/0x03 prefix (Y even/odd) followed by X, left-padded to 32 bytes.
	PublicKeyBytes []byte
}
|
|
|
|
// KeygenParty represents a party participating in keygen.
type KeygenParty struct {
	// PartyID is the party's identifier. It becomes the tss-lib PartyID.Id,
	// is used to locate the local party among all parties, and is the value
	// passed as a recipient to MessageHandler.SendMessage.
	PartyID string

	// PartyIndex determines the party's tss-lib key (PartyIndex+1) and its
	// moniker ("party-<PartyIndex>").
	PartyIndex int
}
|
|
|
|
// KeygenConfig contains configuration for keygen.
type KeygenConfig struct {
	// Threshold is the t in "t-of-n": the number of signers required.
	// NewKeygenSession passes Threshold-1 to tss-lib, whose own threshold
	// value means "threshold+1 signers".
	Threshold int // t in t-of-n
	// TotalParties is the n in "t-of-n"; it must equal the number of
	// parties handed to NewKeygenSession.
	TotalParties int // n in t-of-n
	// Timeout bounds KeygenSession.Start; zero selects Start's default of
	// 10 minutes.
	Timeout time.Duration // Keygen timeout
}
|
|
|
|
// KeygenSession manages a keygen session for a single party. Create it with
// NewKeygenSession and run it once with Start.
type KeygenSession struct {
	config     KeygenConfig
	selfParty  KeygenParty
	allParties []KeygenParty

	// tssPartyIDs holds the tss-lib IDs of allParties, ordered by
	// tss.SortPartyIDs (i.e. not necessarily in allParties order).
	tssPartyIDs []*tss.PartyID
	// selfTSSID is the entry of tssPartyIDs that belongs to selfParty.
	selfTSSID *tss.PartyID
	params    *tss.Parameters
	// localParty is the tss-lib keygen party; it is created in Start.
	localParty tss.Party

	// outCh receives messages emitted by localParty; handleOutgoingMessages
	// drains it into msgHandler.
	outCh chan tss.Message
	// endCh receives localParty's final save data when keygen completes.
	endCh chan *keygen.LocalPartySaveData
	// errCh has capacity 1; Start returns the first error read from it.
	errCh chan error

	msgHandler MessageHandler

	// mu guards started, which prevents Start from running twice.
	mu      sync.Mutex
	started bool
}
|
|
|
|
// MessageHandler handles outgoing and incoming TSS messages. It is the
// transport a KeygenSession uses to talk to the other parties.
type MessageHandler interface {
	// SendMessage sends a message to other parties.
	//
	// For a broadcast, isBroadcast is true and toParties is nil. Otherwise
	// toParties lists the recipients' KeygenParty.PartyID values. msgBytes
	// are the tss-lib wire bytes of the message. The session currently
	// ignores the returned error (delivery is best-effort from its view).
	SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error

	// ReceiveMessages returns a channel for receiving messages. The session
	// reads it until it is closed or the session's context is cancelled;
	// closing it stops incoming message processing.
	ReceiveMessages() <-chan *ReceivedMessage
}
|
|
|
|
// ReceivedMessage represents a received TSS message.
type ReceivedMessage struct {
	// FromPartyIndex identifies the sender. NOTE(review): expected to be the
	// sender's KeygenParty.PartyIndex — confirm that MessageHandler
	// implementations fill it that way.
	FromPartyIndex int

	// IsBroadcast must match how the sender sent the message; it is passed
	// through to tss.ParseWireMessage.
	IsBroadcast bool

	// MsgBytes are the wire bytes the sender passed to SendMessage.
	MsgBytes []byte
}
|
|
|
|
// NewKeygenSession creates a new keygen session for selfParty.
//
// config.Threshold is the number of signers required ("2-of-3" means
// Threshold=2). tss-lib is given Threshold-1, because a tss-lib threshold t
// means t+1 signers. allParties must hold exactly config.TotalParties entries
// with unique PartyID values and unique, non-negative PartyIndex values, and
// must include selfParty (matched by PartyID).
//
// Errors:
//   - ErrInvalidPartyCount for fewer than two parties or a party-list length
//     that does not match config.TotalParties.
//   - ErrInvalidThreshold when Threshold is outside 2..TotalParties.
//   - A descriptive error for a nil msgHandler, negative or duplicate party
//     data, or a selfParty that is not in allParties.
//
// No goroutines are started until Start is called.
func NewKeygenSession(
	config KeygenConfig,
	selfParty KeygenParty,
	allParties []KeygenParty,
	msgHandler MessageHandler,
) (*KeygenSession, error) {
	if config.TotalParties < 2 {
		return nil, ErrInvalidPartyCount
	}
	// Threshold-1 is handed to tss-lib below, so Threshold=1 would become a
	// tss-lib threshold of 0. NOTE(review): tss-lib's VSS rejects thresholds
	// below 1, so such a session could never complete; reject it up front.
	if config.Threshold < 2 || config.Threshold > config.TotalParties {
		return nil, ErrInvalidThreshold
	}
	if len(allParties) != config.TotalParties {
		return nil, ErrInvalidPartyCount
	}
	// Start's message goroutines call into the handler unconditionally; a
	// nil handler would panic there instead of failing here.
	if msgHandler == nil {
		return nil, errors.New("message handler is nil")
	}

	// Create TSS party IDs, validating that each party maps to a distinct,
	// positive tss-lib key (PartyIndex+1).
	tssPartyIDs := make([]*tss.PartyID, len(allParties))
	seenIDs := make(map[string]struct{}, len(allParties))
	seenIndexes := make(map[int]struct{}, len(allParties))
	var selfTSSID *tss.PartyID
	for i, p := range allParties {
		if p.PartyIndex < 0 {
			return nil, fmt.Errorf("party %q has negative index %d", p.PartyID, p.PartyIndex)
		}
		if _, dup := seenIDs[p.PartyID]; dup {
			return nil, fmt.Errorf("duplicate party ID %q", p.PartyID)
		}
		if _, dup := seenIndexes[p.PartyIndex]; dup {
			return nil, fmt.Errorf("duplicate party index %d", p.PartyIndex)
		}
		seenIDs[p.PartyID] = struct{}{}
		seenIndexes[p.PartyIndex] = struct{}{}

		// The key is PartyIndex+1 so that no party is assigned key 0.
		partyID := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyID
		if p.PartyID == selfParty.PartyID {
			selfTSSID = partyID
		}
	}

	if selfTSSID == nil {
		return nil, errors.New("self party not found in all parties")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Create peer context and parameters
	// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
	// User says "2-of-3" meaning 2 signers needed, so we pass (Threshold-1) to TSS-lib
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	tssThreshold := config.Threshold - 1
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)

	return &KeygenSession{
		config:      config,
		selfParty:   selfParty,
		allParties:  allParties,
		tssPartyIDs: sortedPartyIDs,
		selfTSSID:   selfTSSID,
		params:      params,
		outCh:       make(chan tss.Message, config.TotalParties*10),
		endCh:       make(chan *keygen.LocalPartySaveData, 1),
		errCh:       make(chan error, 1),
		msgHandler:  msgHandler,
	}, nil
}
|
|
|
|
// Start runs the keygen protocol for this party and blocks until it
// completes, fails, ctx is cancelled, or the timeout elapses.
//
// A session can be started only once; later calls return an error. The
// timeout is config.Timeout, or 10 minutes when that is zero or negative.
//
// Returns:
//   - the KeygenResult on success;
//   - ctx.Err() on cancellation;
//   - ErrKeygenTimeout on timeout;
//   - an error wrapping ErrKeygenFailed when tss-lib reports a failure.
//
// The outgoing/incoming message goroutines are bound to ctx, not to this
// call. They keep running after Start returns, until ctx is cancelled (or
// the handler's receive channel closes for the incoming side). Cancel ctx
// once the session is no longer needed. NOTE(review): they are deliberately
// not stopped here, since messages this party has already queued may still
// need to reach peers after it has finished.
func (s *KeygenSession) Start(ctx context.Context) (*KeygenResult, error) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return nil, errors.New("session already started")
	}
	s.started = true
	s.mu.Unlock()

	// Create local party
	s.localParty = keygen.NewLocalParty(s.params, s.outCh, s.endCh)

	// Start the local party. errCh holds a single error and only one is ever
	// consumed, so report without blocking; a blocking send could leave this
	// goroutine stuck forever when another error got there first.
	go func() {
		if err := s.localParty.Start(); err != nil {
			select {
			case s.errCh <- err:
			default:
			}
		}
	}()

	go s.handleOutgoingMessages(ctx)
	go s.handleIncomingMessages(ctx)

	// Wait for completion or timeout
	timeout := s.config.Timeout
	if timeout <= 0 {
		timeout = 10 * time.Minute
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return nil, ErrKeygenTimeout
	case tssErr := <-s.errCh:
		return nil, fmt.Errorf("%w: %v", ErrKeygenFailed, tssErr)
	case saveData := <-s.endCh:
		return s.buildResult(saveData)
	}
}
|
|
|
|
// handleOutgoingMessages forwards messages produced by the local tss-lib
// party to msgHandler until ctx is cancelled or a nil message is received.
//
// A message that cannot be serialized can never reach the peers, so the
// protocol cannot finish. That failure is reported on errCh, so Start fails
// fast instead of waiting out its timeout. Transport errors from SendMessage
// are treated as best-effort and ignored, as before.
func (s *KeygenSession) handleOutgoingMessages(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case msg := <-s.outCh:
			if msg == nil {
				return
			}

			msgBytes, _, err := msg.WireBytes()
			if err != nil {
				// Non-blocking: errCh holds one error and only one is read.
				select {
				case s.errCh <- fmt.Errorf("serializing outgoing message: %w", err):
				default:
				}
				continue
			}

			// Broadcasts carry no recipient list; point-to-point messages list
			// the tss-lib IDs (== KeygenParty.PartyID) of their recipients.
			var toParties []string
			isBroadcast := msg.IsBroadcast()
			if !isBroadcast {
				for _, to := range msg.GetTo() {
					toParties = append(toParties, to.Id)
				}
			}

			// Best-effort delivery: a send error is deliberately not fatal here.
			// NOTE(review): a dropped message will stall the protocol until the
			// session timeout; consider logging or retrying in the handler.
			_ = s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes)
		}
	}
}
|
|
|
|
// handleIncomingMessages feeds messages from msgHandler into the local
// tss-lib party until ctx is cancelled or the receive channel is closed.
//
// Each message's FromPartyIndex is matched against the PartyIndex of the
// configured parties. The original code indexed the key-sorted tssPartyIDs
// slice directly, which only worked for indices exactly 0..n-1 and panicked
// on any out-of-range value arriving from the network. Messages that are nil,
// come from an unknown sender, or fail to parse are dropped.
//
// Update errors other than duplicate-message errors are reported on errCh
// without blocking, so no goroutine is left stuck after Start has returned.
func (s *KeygenSession) handleIncomingMessages(ctx context.Context) {
	// Resolve PartyIndex -> *tss.PartyID once, going through the party ID
	// string because tssPartyIDs is ordered by key, not by allParties.
	byID := make(map[string]*tss.PartyID, len(s.tssPartyIDs))
	for _, pid := range s.tssPartyIDs {
		byID[pid.Id] = pid
	}
	senders := make(map[int]*tss.PartyID, len(s.allParties))
	for _, p := range s.allParties {
		if pid, found := byID[p.PartyID]; found {
			senders[p.PartyIndex] = pid
		}
	}

	msgCh := s.msgHandler.ReceiveMessages()
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-msgCh:
			if !ok {
				return
			}
			if msg == nil {
				continue
			}

			from, known := senders[msg.FromPartyIndex]
			if !known {
				continue
			}

			// Parse the message
			parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, from, msg.IsBroadcast)
			if err != nil {
				continue
			}

			// Update the party concurrently so a slow round does not stall
			// the receive loop.
			go func(m tss.ParsedMessage) {
				if _, err := s.localParty.Update(m); err != nil && !isDuplicateMessageError(err) {
					select {
					case s.errCh <- err:
					default:
					}
				}
			}(parsedMsg)
		}
	}
}
|
|
|
|
// buildResult converts tss-lib save data into a KeygenResult: the save data
// JSON-encoded, the group public key, and that key's 33-byte compressed form.
func (s *KeygenSession) buildResult(saveData *keygen.LocalPartySaveData) (*KeygenResult, error) {
	encoded, err := json.Marshal(saveData)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize save data: %w", err)
	}

	groupKey := saveData.ECDSAPub.ToECDSAPubKey()

	// Compressed point: prefix 0x02 for even Y, 0x03 for odd Y, followed by
	// X as a big-endian integer right-aligned (zero-padded) in 32 bytes.
	compressed := make([]byte, 33)
	compressed[0] = 0x02
	if groupKey.Y.Bit(0) == 1 {
		compressed[0] = 0x03
	}
	x := groupKey.X.Bytes()
	copy(compressed[len(compressed)-len(x):], x)

	return &KeygenResult{
		LocalPartySaveData: encoded,
		PublicKey:          groupKey,
		PublicKeyBytes:     compressed,
	}, nil
}
|
|
|
|
// LocalKeygenResult contains local keygen result for standalone testing.
// RunLocalKeygen returns one per party.
type LocalKeygenResult struct {
	// SaveData is the raw tss-lib save data for this party.
	SaveData *keygen.LocalPartySaveData
	// PublicKey is the group public key derived from SaveData.ECDSAPub.
	PublicKey *ecdsa.PublicKey
	// PartyIndex is the party's position in the sorted party list used by
	// RunLocalKeygen.
	PartyIndex int
}
|
|
|
|
// RunLocalKeygen runs keygen locally with all parties in the same process (for testing).
//
// It creates totalParties parties named "party-0".."party-{n-1}" with keys
// 1..n, starts them, and routes every outgoing message directly into the
// recipients' Update. Results are returned in sorted party order. The whole
// run is bounded by a single 5-minute deadline (previously the deadline was
// re-armed for every party, letting the total wait reach 5·n minutes). The
// first error reported by any party aborts the run.
//
// threshold is handed to tss-lib unchanged (tss-lib threshold t means t+1
// signers). NOTE(review): NewKeygenSession instead passes Threshold-1, so
// the same number means different things in the two entry points — confirm
// which convention callers of this helper rely on before unifying them.
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
	if totalParties < 2 {
		return nil, ErrInvalidPartyCount
	}
	if threshold < 1 || threshold > totalParties {
		return nil, ErrInvalidThreshold
	}

	// Create party IDs
	partyIDs := make([]*tss.PartyID, totalParties)
	for i := 0; i < totalParties; i++ {
		name := fmt.Sprintf("party-%d", i)
		partyIDs[i] = tss.NewPartyID(name, name, big.NewInt(int64(i+1)))
	}
	sortedPartyIDs := tss.SortPartyIDs(partyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)

	// Create channels and a tss-lib party for each participant
	outChs := make([]chan tss.Message, totalParties)
	endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
	parties := make([]tss.Party, totalParties)
	for i := 0; i < totalParties; i++ {
		outChs[i] = make(chan tss.Message, totalParties*10)
		endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
		params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], totalParties, threshold)
		parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
	}

	// Start all parties
	errCh := make(chan error, totalParties)
	for i := 0; i < totalParties; i++ {
		go func(idx int) {
			if err := parties[idx].Start(); err != nil {
				errCh <- err
			}
		}(i)
	}

	// Route messages between parties. The routers are stopped and awaited on
	// every return path so none of them outlive this call.
	var routeWg sync.WaitGroup
	doneCh := make(chan struct{})
	defer func() {
		close(doneCh)
		routeWg.Wait()
	}()

	for i := 0; i < totalParties; i++ {
		routeWg.Add(1)
		go func(idx int) {
			defer routeWg.Done()
			for {
				select {
				case <-doneCh:
					return
				case msg := <-outChs[idx]:
					if msg == nil {
						return
					}
					if msg.IsBroadcast() {
						for j := 0; j < totalParties; j++ {
							if j != idx {
								go updateParty(parties[j], msg, errCh)
							}
						}
						continue
					}
					for _, d := range msg.GetTo() {
						for j := 0; j < totalParties; j++ {
							if sortedPartyIDs[j].Id == d.Id {
								go updateParty(parties[j], msg, errCh)
								break
							}
						}
					}
				}
			}
		}(i)
	}

	// Collect results under one overall deadline
	deadline := time.NewTimer(5 * time.Minute)
	defer deadline.Stop()

	results := make([]*LocalKeygenResult, totalParties)
	for i := 0; i < totalParties; i++ {
		select {
		case saveData := <-endChs[i]:
			results[i] = &LocalKeygenResult{
				SaveData:   saveData,
				PublicKey:  saveData.ECDSAPub.ToECDSAPubKey(),
				PartyIndex: i,
			}
		case err := <-errCh:
			return nil, err
		case <-deadline.C:
			return nil, ErrKeygenTimeout
		}
	}

	return results, nil
}
|
|
|
|
func updateParty(party tss.Party, msg tss.Message, errCh chan error) {
|
|
bytes, routing, err := msg.WireBytes()
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
if _, err := party.Update(parsedMsg); err != nil {
|
|
// Only send error if it's not a duplicate message error
|
|
// Check if error message contains "duplicate message" indication
|
|
if err.Error() != "" && !isDuplicateMessageError(err) {
|
|
errCh <- err
|
|
}
|
|
}
|
|
}
|
|
|
|
// isDuplicateMessageError reports whether err's message indicates a message
// that was already delivered, i.e. contains "duplicate" or
// "already received" (case-sensitive). A nil error is never a duplicate.
func isDuplicateMessageError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range [...]string{"duplicate", "already received"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|