// rwadurian/backend/mpc-system/pkg/tss/keygen.go

package tss
import (
"context"
"crypto/ecdsa"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"
"sync"
"time"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/tss"
)
// Sentinel errors returned by this package. Compare against them with
// errors.Is rather than ==, because some are returned wrapped
// (ErrKeygenFailed is wrapped with the underlying tss-lib error).
var (
	// ErrKeygenTimeout is returned when keygen does not finish within its
	// time budget.
	ErrKeygenTimeout = errors.New("keygen timeout")

	// ErrKeygenFailed is wrapped around a protocol error reported by tss-lib.
	ErrKeygenFailed = errors.New("keygen failed")

	// ErrInvalidPartyCount is returned when fewer than two parties are
	// configured or the supplied party list does not match the configuration.
	ErrInvalidPartyCount = errors.New("invalid party count")

	// ErrInvalidThreshold is returned when the threshold is outside [1, n].
	ErrInvalidThreshold = errors.New("invalid threshold")
)
// KeygenResult contains the result of a keygen operation.
// It is assembled by (*KeygenSession).buildResult once the local party
// has produced its save data.
type KeygenResult struct {
// LocalPartySaveData is the serialized save data for this party
// (JSON encoding of keygen.LocalPartySaveData).
// NOTE(review): tss-lib save data presumably includes this party's secret
// key share — treat these bytes as sensitive; confirm storage requirements.
LocalPartySaveData []byte
// PublicKey is the group ECDSA public key
PublicKey *ecdsa.PublicKey
// PublicKeyBytes is the compressed public key bytes: 33 bytes, a 0x02 (even Y)
// or 0x03 (odd Y) prefix followed by X as a left-zero-padded 32-byte
// big-endian integer.
PublicKeyBytes []byte
}
// KeygenParty represents a party participating in keygen.
type KeygenParty struct {
// PartyID identifies the party; it becomes the Id of the party's tss-lib
// PartyID and is the value used to address point-to-point messages.
PartyID string
// PartyIndex determines the party's tss-lib key (PartyIndex+1) and its
// moniker ("party-<PartyIndex>").
PartyIndex int
}
// KeygenConfig contains configuration for keygen.
type KeygenConfig struct {
// Threshold is passed unchanged to tss.NewParameters as the tss-lib
// threshold. NOTE(review): tss-lib requires threshold+1 parties to sign;
// confirm that the "t in t-of-n" wording below matches how callers pick it.
Threshold int // t in t-of-n
// TotalParties must equal the number of parties passed to NewKeygenSession.
TotalParties int // n in t-of-n
// Timeout bounds Start; a zero value makes Start use a 10-minute default.
Timeout time.Duration // Keygen timeout
}
// KeygenSession manages a keygen session for a single party.
// It is created by NewKeygenSession and driven by Start, which may run only
// once per session (guarded by mu and started).
type KeygenSession struct {
// config is the validated configuration passed to NewKeygenSession.
config KeygenConfig
// selfParty is the local party as passed to NewKeygenSession.
selfParty KeygenParty
// allParties is every participant (including selfParty) in caller order.
allParties []KeygenParty
// tssPartyIDs holds the tss-lib party IDs after tss.SortPartyIDs; incoming
// messages look up their sender in it by ReceivedMessage.FromPartyIndex.
tssPartyIDs []*tss.PartyID
// selfTSSID is the tss-lib party ID built for selfParty.
selfTSSID *tss.PartyID
// params holds the tss-lib parameters (tss.S256() curve, peers, threshold).
params *tss.Parameters
// localParty is the tss-lib keygen party; it is created in Start.
localParty tss.Party
// outCh receives messages the local party wants to send
// (buffered to TotalParties*10).
outCh chan tss.Message
// endCh receives the local party's save data when keygen completes.
endCh chan *keygen.LocalPartySaveData
// errCh carries protocol errors to Start (buffer of 1).
errCh chan error
// msgHandler is the transport used to send and receive wire messages.
msgHandler MessageHandler
// mu guards started.
mu sync.Mutex
// started records whether Start has been called.
started bool
}
// MessageHandler handles outgoing and incoming TSS messages.
// It is the transport a KeygenSession uses to reach the other parties.
type MessageHandler interface {
// SendMessage sends a message to other parties.
// For broadcasts toParties is nil; otherwise it lists the recipients' tss
// party IDs (KeygenParty.PartyID values). msgBytes is the tss-lib wire
// encoding of the message.
SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error
// ReceiveMessages returns a channel for receiving messages.
// The session stops reading when this channel is closed.
ReceiveMessages() <-chan *ReceivedMessage
}
// ReceivedMessage represents a received TSS message.
type ReceivedMessage struct {
// FromPartyIndex identifies the sender. The session uses it as a position
// in its sorted tss party-ID list. NOTE(review): that position equals
// KeygenParty.PartyIndex only when indexes are 0..n-1 — confirm with the
// transport implementation.
FromPartyIndex int
// IsBroadcast reports whether the sender broadcast the message.
IsBroadcast bool
// MsgBytes is the tss-lib wire encoding of the message.
MsgBytes []byte
}
// NewKeygenSession validates the keygen configuration and prepares a session
// for selfParty. The protocol itself is run by Start.
//
// allParties must contain exactly config.TotalParties entries, including
// selfParty (matched by PartyID). PartyIDs must be distinct, and PartyIndex
// values must be distinct and non-negative. Each party becomes a tss-lib
// PartyID with Id = PartyID, moniker "party-<PartyIndex>" and key
// PartyIndex+1. The slice is copied, so later changes by the caller do not
// affect the session.
//
// Errors:
//   - ErrInvalidPartyCount if TotalParties < 2 or len(allParties) differs
//     from it.
//   - ErrInvalidThreshold if Threshold is outside [1, TotalParties].
//   - A plain error for a negative or duplicate index, a duplicate PartyID,
//     or a selfParty that is not in allParties.
//
// NOTE(review): config.Threshold is passed unchanged to tss.NewParameters,
// and tss-lib requires threshold+1 parties to sign. With Threshold ==
// TotalParties no signing quorum would exist. Confirm the intended convention
// before tightening the bound.
func NewKeygenSession(
	config KeygenConfig,
	selfParty KeygenParty,
	allParties []KeygenParty,
	msgHandler MessageHandler,
) (*KeygenSession, error) {
	if config.TotalParties < 2 {
		return nil, ErrInvalidPartyCount
	}
	if config.Threshold < 1 || config.Threshold > config.TotalParties {
		return nil, ErrInvalidThreshold
	}
	if len(allParties) != config.TotalParties {
		return nil, ErrInvalidPartyCount
	}

	seenIDs := make(map[string]struct{}, len(allParties))
	seenIndexes := make(map[int]struct{}, len(allParties))
	tssPartyIDs := make([]*tss.PartyID, len(allParties))
	var selfTSSID *tss.PartyID
	for i, p := range allParties {
		// A negative index yields key 0 or a key that aliases another party,
		// because big.Int.Bytes drops the sign. Duplicate keys or IDs make
		// the protocol unusable.
		if p.PartyIndex < 0 {
			return nil, fmt.Errorf("invalid party index %d for party %q", p.PartyIndex, p.PartyID)
		}
		if _, dup := seenIDs[p.PartyID]; dup {
			return nil, fmt.Errorf("duplicate party id %q", p.PartyID)
		}
		if _, dup := seenIndexes[p.PartyIndex]; dup {
			return nil, fmt.Errorf("duplicate party index %d", p.PartyIndex)
		}
		seenIDs[p.PartyID] = struct{}{}
		seenIndexes[p.PartyIndex] = struct{}{}

		partyID := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			// Widen before adding so a large index cannot overflow a 32-bit int.
			big.NewInt(int64(p.PartyIndex)+1),
		)
		tssPartyIDs[i] = partyID
		if p.PartyID == selfParty.PartyID {
			selfTSSID = partyID
		}
	}
	if selfTSSID == nil {
		return nil, errors.New("self party not found in all parties")
	}

	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), config.Threshold)

	return &KeygenSession{
		config:      config,
		selfParty:   selfParty,
		allParties:  append([]KeygenParty(nil), allParties...),
		tssPartyIDs: sortedPartyIDs,
		selfTSSID:   selfTSSID,
		params:      params,
		outCh:       make(chan tss.Message, config.TotalParties*10),
		endCh:       make(chan *keygen.LocalPartySaveData, 1),
		errCh:       make(chan error, 1),
		msgHandler:  msgHandler,
	}, nil
}
// Start runs the keygen protocol for this party and blocks until it ends.
//
// A session can be started only once; a second call returns an error.
// Start returns:
//   - the assembled KeygenResult when the local party produces its save data;
//   - ctx.Err() if ctx is cancelled first;
//   - ErrKeygenTimeout if s.config.Timeout elapses (10 minutes when <= 0);
//   - an error wrapping ErrKeygenFailed if tss-lib reports a protocol error.
//
// The message pumps run on a context derived from ctx:
//   - On any failure that context is cancelled before Start returns, so a
//     dead session does not leave goroutines behind.
//   - On success the pumps keep running until ctx is cancelled. This lets
//     this party's final-round messages that are still queued, or in flight
//     in SendMessage, reach the other parties. Callers should cancel ctx
//     once the whole ceremony is over.
func (s *KeygenSession) Start(ctx context.Context) (*KeygenResult, error) {
	s.mu.Lock()
	if s.started {
		s.mu.Unlock()
		return nil, errors.New("session already started")
	}
	s.started = true
	s.mu.Unlock()

	runCtx, cancel := context.WithCancel(ctx)
	succeeded := false
	defer func() {
		if !succeeded {
			cancel()
		}
	}()

	s.localParty = keygen.NewLocalParty(s.params, s.outCh, s.endCh)

	go func() {
		if err := s.localParty.Start(); err != nil {
			// Never block: if an error is already pending, that one wins and
			// this goroutine still exits.
			select {
			case s.errCh <- err:
			default:
			}
		}
	}()

	go s.handleOutgoingMessages(runCtx)
	go s.handleIncomingMessages(runCtx)

	timeout := s.config.Timeout
	if timeout <= 0 {
		timeout = 10 * time.Minute
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timer.C:
		return nil, ErrKeygenTimeout
	case tssErr := <-s.errCh:
		return nil, fmt.Errorf("%w: %v", ErrKeygenFailed, tssErr)
	case saveData := <-s.endCh:
		succeeded = true
		return s.buildResult(saveData)
	}
}
// handleOutgoingMessages forwards every message the local tss-lib party puts
// on s.outCh to the transport. It stops when ctx is cancelled or a nil message
// is read (for example from a closed channel).
//
// Delivery is best-effort. A message that cannot be serialized, or whose
// SendMessage call fails, is dropped and the loop moves on.
func (s *KeygenSession) handleOutgoingMessages(ctx context.Context) {
	for {
		var msg tss.Message
		select {
		case <-ctx.Done():
			return
		case msg = <-s.outCh:
		}
		if msg == nil {
			return
		}

		wire, _, err := msg.WireBytes()
		if err != nil {
			continue // unserializable message: drop it
		}

		broadcast := msg.IsBroadcast()
		var recipients []string
		if !broadcast {
			for _, pid := range msg.GetTo() {
				recipients = append(recipients, pid.Id)
			}
		}

		// Send failures are intentionally ignored (best-effort delivery).
		_ = s.msgHandler.SendMessage(ctx, broadcast, recipients, wire)
	}
}
// handleIncomingMessages feeds messages from the transport into the local
// tss-lib party. It stops when ctx is cancelled or the receive channel is
// closed.
//
// The sender is resolved as s.tssPartyIDs[msg.FromPartyIndex]. A nil message
// or an index outside that slice is dropped instead of panicking: the index
// comes from the transport, and a panic in this goroutine would crash the
// whole process. Messages that fail to parse are dropped.
//
// Each Update runs on its own goroutine. Errors that look like duplicate
// deliveries (see isDuplicateMessageError) are ignored, as in updateParty.
// Any other error is offered to s.errCh without blocking: the first one
// reaches Start, and later ones are discarded so no goroutine is left
// waiting on the channel.
func (s *KeygenSession) handleIncomingMessages(ctx context.Context) {
	msgCh := s.msgHandler.ReceiveMessages()
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-msgCh:
			if !ok {
				return
			}
			if msg == nil || msg.FromPartyIndex < 0 || msg.FromPartyIndex >= len(s.tssPartyIDs) {
				continue
			}

			parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
			if err != nil {
				continue
			}

			go func(pm tss.ParsedMessage) {
				if _, err := s.localParty.Update(pm); err != nil && !isDuplicateMessageError(err) {
					select {
					case s.errCh <- err:
					default:
					}
				}
			}(parsedMsg)
		}
	}
}
// buildResult converts the local party's tss-lib save data into a KeygenResult.
//
// The save data is JSON-serialized as-is. The group public key is also
// encoded in 33-byte compressed form: a 0x02 (even Y) or 0x03 (odd Y) prefix,
// followed by X as a 32-byte big-endian integer left-padded with zeros.
//
// It returns an error wrapping ErrKeygenFailed if the save data or its public
// key is missing, or if X does not fit in 32 bytes. It returns a
// serialization error if JSON encoding fails.
func (s *KeygenSession) buildResult(saveData *keygen.LocalPartySaveData) (*KeygenResult, error) {
	if saveData == nil || saveData.ECDSAPub == nil {
		return nil, fmt.Errorf("%w: save data has no public key", ErrKeygenFailed)
	}

	saveDataBytes, err := json.Marshal(saveData)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize save data: %w", err)
	}

	pubKey := saveData.ECDSAPub.ToECDSAPubKey()

	const coordLen = 32
	xBytes := pubKey.X.Bytes()
	if len(xBytes) > coordLen {
		return nil, fmt.Errorf("%w: public key X coordinate is %d bytes, want at most %d",
			ErrKeygenFailed, len(xBytes), coordLen)
	}
	pubKeyBytes := make([]byte, 1+coordLen)
	pubKeyBytes[0] = 0x02 | byte(pubKey.Y.Bit(0)) // parity of Y selects 0x02/0x03
	copy(pubKeyBytes[1+coordLen-len(xBytes):], xBytes)

	return &KeygenResult{
		LocalPartySaveData: saveDataBytes,
		PublicKey:          pubKey,
		PublicKeyBytes:     pubKeyBytes,
	}, nil
}
// LocalKeygenResult contains local keygen result for standalone testing.
// One is produced per party by RunLocalKeygen.
type LocalKeygenResult struct {
// SaveData is the party's raw tss-lib save data.
SaveData *keygen.LocalPartySaveData
// PublicKey is the group public key, taken from SaveData.ECDSAPub.
PublicKey *ecdsa.PublicKey
// PartyIndex is the party's 0-based position in the sorted party-ID list.
PartyIndex int
}
// RunLocalKeygen runs keygen locally with all parties in the same process
// (for testing).
//
// It creates totalParties tss-lib keygen parties with keys 1..n and routes
// their messages to each other in memory. It returns one result per party, in
// the order of the sorted party IDs. threshold is passed unchanged to tss-lib.
//
// Errors:
//   - ErrInvalidPartyCount or ErrInvalidThreshold for bad arguments.
//   - The first error reported by a party or by message routing.
//   - ErrKeygenTimeout if the parties have not all finished within 5 minutes
//     overall (one deadline, not one per party).
//
// The message routers are stopped and waited for on every return path.
// Errors are recorded without blocking, so failing parties never leave
// goroutines stuck on the error channel.
func RunLocalKeygen(threshold, totalParties int) ([]*LocalKeygenResult, error) {
	if totalParties < 2 {
		return nil, ErrInvalidPartyCount
	}
	if threshold < 1 || threshold > totalParties {
		return nil, ErrInvalidThreshold
	}

	partyIDs := make([]*tss.PartyID, totalParties)
	for i := range partyIDs {
		name := fmt.Sprintf("party-%d", i)
		partyIDs[i] = tss.NewPartyID(name, name, big.NewInt(int64(i)+1))
	}
	sortedPartyIDs := tss.SortPartyIDs(partyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)

	outChs := make([]chan tss.Message, totalParties)
	endChs := make([]chan *keygen.LocalPartySaveData, totalParties)
	parties := make([]tss.Party, totalParties)
	for i := 0; i < totalParties; i++ {
		outChs[i] = make(chan tss.Message, totalParties*10)
		endChs[i] = make(chan *keygen.LocalPartySaveData, 1)
		params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], totalParties, threshold)
		parties[i] = keygen.NewLocalParty(params, outChs[i], endChs[i])
	}

	// errCh keeps the earliest errors. reportErr never blocks, so a party
	// that fails after this function has returned cannot leak a goroutine.
	errCh := make(chan error, totalParties)
	reportErr := func(err error) {
		select {
		case errCh <- err:
		default:
		}
	}

	for i := 0; i < totalParties; i++ {
		go func(idx int) {
			if err := parties[idx].Start(); err != nil {
				reportErr(err)
			}
		}(i)
	}

	doneCh := make(chan struct{})
	var routeWg sync.WaitGroup
	// Stop the routers on every return path and wait until they have exited.
	defer func() {
		close(doneCh)
		routeWg.Wait()
	}()

	for i := 0; i < totalParties; i++ {
		routeWg.Add(1)
		go func(idx int) {
			defer routeWg.Done()
			for {
				select {
				case <-doneCh:
					return
				case msg := <-outChs[idx]:
					if msg == nil {
						return
					}
					if msg.IsBroadcast() {
						for j := 0; j < totalParties; j++ {
							if j != idx {
								go updateParty(parties[j], msg, errCh)
							}
						}
						continue
					}
					for _, dest := range msg.GetTo() {
						for j := 0; j < totalParties; j++ {
							if sortedPartyIDs[j].Id == dest.Id {
								go updateParty(parties[j], msg, errCh)
								break
							}
						}
					}
				}
			}
		}(i)
	}

	deadline := time.NewTimer(5 * time.Minute)
	defer deadline.Stop()

	results := make([]*LocalKeygenResult, totalParties)
	for i := 0; i < totalParties; i++ {
		select {
		case saveData := <-endChs[i]:
			results[i] = &LocalKeygenResult{
				SaveData:   saveData,
				PublicKey:  saveData.ECDSAPub.ToECDSAPubKey(),
				PartyIndex: i,
			}
		case err := <-errCh:
			return nil, err
		case <-deadline.C:
			return nil, ErrKeygenTimeout
		}
	}
	return results, nil
}
// updateParty delivers msg, emitted by another in-process party, to party.
//
// The message is round-tripped through its wire encoding, so party receives a
// tss.ParsedMessage exactly as it would from a transport. Update errors that
// look like duplicate deliveries (see isDuplicateMessageError) are ignored.
//
// Every other failure is offered to errCh without blocking. If errCh is full,
// an error is already pending and the new one is dropped. This keeps these
// short-lived goroutines from leaking while they wait on a channel that nobody
// reads any more.
func updateParty(party tss.Party, msg tss.Message, errCh chan error) {
	report := func(err error) {
		select {
		case errCh <- err:
		default:
		}
	}

	wire, routing, err := msg.WireBytes()
	if err != nil {
		report(err)
		return
	}
	parsed, err := tss.ParseWireMessage(wire, msg.GetFrom(), routing.IsBroadcast)
	if err != nil {
		report(err)
		return
	}
	if _, updErr := party.Update(parsed); updErr != nil && !isDuplicateMessageError(updErr) {
		report(updErr)
	}
}
// isDuplicateMessageError reports whether err's message indicates that a
// message was delivered more than once. It matches the case-sensitive
// substrings "duplicate" and "already received". A nil error is never a
// duplicate.
func isDuplicateMessageError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range [...]string{"duplicate", "already received"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}