514 lines
16 KiB
Go
514 lines
16 KiB
Go
package tss
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bnb-chain/tss-lib/v2/common"
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
|
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
|
)
|
|
|
|
// Sentinel errors returned by the signing API. Some call sites wrap them with
// extra context via %w, so compare with errors.Is rather than ==.
var (
	// ErrSigningTimeout is returned when signing does not complete before the
	// session timeout elapses.
	ErrSigningTimeout = errors.New("signing timeout")

	// ErrSigningFailed wraps an error reported while running the signing protocol.
	ErrSigningFailed = errors.New("signing failed")

	// ErrInvalidSignerCount reports a signer set that does not fit the
	// threshold configuration.
	ErrInvalidSignerCount = errors.New("invalid signer count")

	// ErrInvalidShareData reports key-share (keygen save) data that could not
	// be decoded.
	ErrInvalidShareData = errors.New("invalid share data")
)
|
|
|
|
// SigningResult is what a successful signing session produces.
type SigningResult struct {
	// Signature is R followed by S, each left-padded with zeros to 32 bytes
	// when built by the session (64 bytes in total).
	Signature []byte

	// R is the signature's R value as an unsigned integer.
	R *big.Int

	// S is the signature's S value as an unsigned integer.
	S *big.Int

	// RecoveryID is the public-key recovery identifier (as used by
	// ecrecover), taken from the first byte of tss-lib's SignatureRecovery.
	RecoveryID int
}
|
|
|
|
// SigningParty identifies one participant of a signing session.
type SigningParty struct {
	// PartyID is the participant's identifier. It becomes the tss.PartyID Id
	// and is the recipient name used for point-to-point messages.
	PartyID string

	// PartyIndex is the participant's index from keygen; its tss-lib key is
	// PartyIndex+1.
	PartyIndex int
}
|
|
|
|
// SigningConfig holds the parameters of a signing session.
type SigningConfig struct {
	// Threshold is t from the t-of-n keygen. NewSigningSession hands
	// Threshold-1 to tss-lib, whose threshold means "t+1 signers required".
	Threshold int

	// TotalParties is n from the t-of-n keygen: the full keygen party count,
	// not the number of parties signing now.
	TotalParties int

	// TotalSigners is how many parties take part in this signing session.
	TotalSigners int

	// Timeout bounds how long Start waits; zero means five minutes.
	Timeout time.Duration
}
|
|
|
|
// SigningSession manages a signing session for a single party
|
|
type SigningSession struct {
|
|
config SigningConfig
|
|
selfParty SigningParty
|
|
allParties []SigningParty
|
|
messageHash *big.Int
|
|
saveData *keygen.LocalPartySaveData
|
|
tssPartyIDs []*tss.PartyID
|
|
selfTSSID *tss.PartyID
|
|
params *tss.Parameters
|
|
localParty tss.Party
|
|
outCh chan tss.Message
|
|
endCh chan *common.SignatureData
|
|
errCh chan error
|
|
msgHandler MessageHandler
|
|
mu sync.Mutex
|
|
started bool
|
|
// keygenIndexToSortedIndex maps keygen party index to sorted array index
|
|
// This is needed because TSS messages use keygen index, but tssPartyIDs is sorted
|
|
keygenIndexToSortedIndex map[int]int
|
|
}
|
|
|
|
// NewSigningSession creates a new signing session
|
|
func NewSigningSession(
|
|
config SigningConfig,
|
|
selfParty SigningParty,
|
|
allParties []SigningParty,
|
|
messageHash []byte,
|
|
saveDataBytes []byte,
|
|
msgHandler MessageHandler,
|
|
) (*SigningSession, error) {
|
|
if config.TotalSigners < config.Threshold {
|
|
return nil, ErrInvalidSignerCount
|
|
}
|
|
if len(allParties) != config.TotalSigners {
|
|
return nil, ErrInvalidSignerCount
|
|
}
|
|
|
|
// Deserialize save data
|
|
var saveData keygen.LocalPartySaveData
|
|
if err := json.Unmarshal(saveDataBytes, &saveData); err != nil {
|
|
return nil, fmt.Errorf("%w: %v", ErrInvalidShareData, err)
|
|
}
|
|
|
|
// Create TSS party IDs for signers
|
|
tssPartyIDs := make([]*tss.PartyID, len(allParties))
|
|
var selfTSSID *tss.PartyID
|
|
for i, p := range allParties {
|
|
partyID := tss.NewPartyID(
|
|
p.PartyID,
|
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
|
big.NewInt(int64(p.PartyIndex+1)),
|
|
)
|
|
tssPartyIDs[i] = partyID
|
|
if p.PartyID == selfParty.PartyID {
|
|
selfTSSID = partyID
|
|
}
|
|
}
|
|
|
|
if selfTSSID == nil {
|
|
return nil, errors.New("self party not found in all parties")
|
|
}
|
|
|
|
// Sort party IDs
|
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
|
|
|
// Build mapping from keygen index to sorted array index
|
|
// The sorted array is ordered by big.Int key (PartyIndex+1)
|
|
keygenIndexToSortedIndex := make(map[int]int)
|
|
for sortedIdx, partyID := range sortedPartyIDs {
|
|
// Find the original keygen index for this party
|
|
for _, p := range allParties {
|
|
if p.PartyID == partyID.Id {
|
|
keygenIndexToSortedIndex[p.PartyIndex] = sortedIdx
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
fmt.Printf("[TSS-SIGN] Built keygen index to sorted index mapping: %v party_id=%s\n",
|
|
keygenIndexToSortedIndex, selfParty.PartyID)
|
|
|
|
// Create peer context and parameters
|
|
// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
|
|
// This MUST match keygen exactly! Both use (Threshold-1)
|
|
// The BuildLocalSaveDataSubset call in Start() will filter the save data to match
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
tssThreshold := config.Threshold - 1
|
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
|
|
|
|
fmt.Printf("[TSS-SIGN] NewParameters: partyCount=%d, tssThreshold=%d (from config.Threshold=%d) party_id=%s\n",
|
|
len(sortedPartyIDs), tssThreshold, config.Threshold, selfParty.PartyID)
|
|
|
|
// Convert message hash to big.Int
|
|
msgHash := new(big.Int).SetBytes(messageHash)
|
|
|
|
return &SigningSession{
|
|
config: config,
|
|
selfParty: selfParty,
|
|
allParties: allParties,
|
|
messageHash: msgHash,
|
|
saveData: &saveData,
|
|
tssPartyIDs: sortedPartyIDs,
|
|
selfTSSID: selfTSSID,
|
|
params: params,
|
|
outCh: make(chan tss.Message, config.TotalSigners*10),
|
|
endCh: make(chan *common.SignatureData, 1),
|
|
errCh: make(chan error, 1),
|
|
msgHandler: msgHandler,
|
|
keygenIndexToSortedIndex: keygenIndexToSortedIndex,
|
|
}, nil
|
|
}
|
|
|
|
// Start begins the signing protocol
|
|
func (s *SigningSession) Start(ctx context.Context) (*SigningResult, error) {
|
|
s.mu.Lock()
|
|
if s.started {
|
|
s.mu.Unlock()
|
|
return nil, errors.New("session already started")
|
|
}
|
|
s.started = true
|
|
s.mu.Unlock()
|
|
|
|
// CRITICAL: Build a subset of the save data for the current signing parties
|
|
// When signing with fewer parties than keygen (e.g., 2-of-3 signing with only 2 parties),
|
|
// we must filter the save data to only include the participating parties' data.
|
|
// This ensures TSS-lib's internal indices match the actual signers.
|
|
subsetSaveData := keygen.BuildLocalSaveDataSubset(*s.saveData, s.tssPartyIDs)
|
|
|
|
fmt.Printf("[TSS-SIGN] Built save data subset for %d signing parties (original keygen had %d parties) party_id=%s\n",
|
|
len(s.tssPartyIDs), len(s.saveData.Ks), s.selfParty.PartyID)
|
|
|
|
// Create local party for signing with the SUBSET save data
|
|
s.localParty = signing.NewLocalParty(s.messageHash, s.params, subsetSaveData, s.outCh, s.endCh)
|
|
|
|
// Start the local party
|
|
go func() {
|
|
if err := s.localParty.Start(); err != nil {
|
|
s.errCh <- err
|
|
}
|
|
}()
|
|
|
|
// Handle outgoing messages
|
|
go s.handleOutgoingMessages(ctx)
|
|
|
|
// Handle incoming messages
|
|
go s.handleIncomingMessages(ctx)
|
|
|
|
// Wait for completion or timeout
|
|
timeout := s.config.Timeout
|
|
if timeout == 0 {
|
|
timeout = 5 * time.Minute
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(timeout):
|
|
return nil, ErrSigningTimeout
|
|
case tssErr := <-s.errCh:
|
|
return nil, fmt.Errorf("%w: %v", ErrSigningFailed, tssErr)
|
|
case signData := <-s.endCh:
|
|
return s.buildResult(signData)
|
|
}
|
|
}
|
|
|
|
func (s *SigningSession) handleOutgoingMessages(ctx context.Context) {
|
|
fmt.Printf("[TSS-SIGN] handleOutgoingMessages started party_id=%s\n", s.selfParty.PartyID)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
fmt.Printf("[TSS-SIGN] handleOutgoingMessages context cancelled party_id=%s\n", s.selfParty.PartyID)
|
|
return
|
|
case msg := <-s.outCh:
|
|
if msg == nil {
|
|
fmt.Printf("[TSS-SIGN] handleOutgoingMessages received nil message, stopping party_id=%s\n", s.selfParty.PartyID)
|
|
return
|
|
}
|
|
msgBytes, _, err := msg.WireBytes()
|
|
if err != nil {
|
|
fmt.Printf("[TSS-SIGN] Failed to get wire bytes party_id=%s error=%v\n", s.selfParty.PartyID, err)
|
|
continue
|
|
}
|
|
|
|
var toParties []string
|
|
isBroadcast := msg.IsBroadcast()
|
|
if !isBroadcast {
|
|
for _, to := range msg.GetTo() {
|
|
toParties = append(toParties, to.Id)
|
|
}
|
|
}
|
|
|
|
fmt.Printf("[TSS-SIGN] sending outgoing message party_id=%s is_broadcast=%v to_parties=%v msg_type=%s\n",
|
|
s.selfParty.PartyID, isBroadcast, toParties, msg.Type())
|
|
|
|
if err := s.msgHandler.SendMessage(ctx, isBroadcast, toParties, msgBytes); err != nil {
|
|
fmt.Printf("[TSS-SIGN] Failed to send message party_id=%s error=%v\n", s.selfParty.PartyID, err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
|
|
fmt.Printf("[TSS-SIGN] handleIncomingMessages started party_id=%s\n", s.selfParty.PartyID)
|
|
msgCh := s.msgHandler.ReceiveMessages()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
fmt.Printf("[TSS-SIGN] handleIncomingMessages context cancelled party_id=%s\n", s.selfParty.PartyID)
|
|
return
|
|
case msg, ok := <-msgCh:
|
|
if !ok {
|
|
fmt.Printf("[TSS-SIGN] handleIncomingMessages channel closed party_id=%s\n", s.selfParty.PartyID)
|
|
return
|
|
}
|
|
|
|
fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_keygen_index=%d is_broadcast=%v msg_len=%d\n",
|
|
s.selfParty.PartyID, msg.FromPartyIndex, msg.IsBroadcast, len(msg.MsgBytes))
|
|
|
|
// Map keygen index to sorted array index
|
|
// msg.FromPartyIndex is the original keygen party index (e.g., 0, 1, 2)
|
|
// We need the sorted array index (e.g., 0, 1 for a 2-party signing session)
|
|
sortedIndex, exists := s.keygenIndexToSortedIndex[msg.FromPartyIndex]
|
|
if !exists {
|
|
fmt.Printf("[TSS-SIGN] ERROR: unknown keygen index=%d, mapping=%v party_id=%s\n",
|
|
msg.FromPartyIndex, s.keygenIndexToSortedIndex, s.selfParty.PartyID)
|
|
continue
|
|
}
|
|
|
|
fmt.Printf("[TSS-SIGN] mapped keygen_index=%d to sorted_index=%d party_id=%s\n",
|
|
msg.FromPartyIndex, sortedIndex, s.selfParty.PartyID)
|
|
|
|
// Check if sorted index is valid
|
|
if sortedIndex < 0 || sortedIndex >= len(s.tssPartyIDs) {
|
|
fmt.Printf("[TSS-SIGN] ERROR: invalid sortedIndex=%d, len(tssPartyIDs)=%d party_id=%s\n",
|
|
sortedIndex, len(s.tssPartyIDs), s.selfParty.PartyID)
|
|
continue
|
|
}
|
|
|
|
// Parse the message using the sorted index
|
|
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[sortedIndex], msg.IsBroadcast)
|
|
if err != nil {
|
|
fmt.Printf("[TSS-SIGN] ERROR: failed to parse wire message party_id=%s from_index=%d error=%v\n",
|
|
s.selfParty.PartyID, msg.FromPartyIndex, err)
|
|
continue
|
|
}
|
|
|
|
fmt.Printf("[TSS-SIGN] parsed message successfully party_id=%s msg_type=%s\n",
|
|
s.selfParty.PartyID, parsedMsg.Type())
|
|
|
|
// Update the party
|
|
go func() {
|
|
ok, err := s.localParty.Update(parsedMsg)
|
|
if err != nil {
|
|
fmt.Printf("[TSS-SIGN] ERROR: party update failed party_id=%s error=%v\n", s.selfParty.PartyID, err)
|
|
s.errCh <- err
|
|
} else {
|
|
fmt.Printf("[TSS-SIGN] party update succeeded party_id=%s ok=%v\n", s.selfParty.PartyID, ok)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SigningSession) buildResult(signData *common.SignatureData) (*SigningResult, error) {
|
|
// Get R and S as big.Int
|
|
r := new(big.Int).SetBytes(signData.R)
|
|
rS := new(big.Int).SetBytes(signData.S)
|
|
|
|
// Build full signature (R || S)
|
|
signature := make([]byte, 64)
|
|
rBytes := signData.R
|
|
sBytes := signData.S
|
|
|
|
// Pad to 32 bytes each
|
|
copy(signature[32-len(rBytes):32], rBytes)
|
|
copy(signature[64-len(sBytes):64], sBytes)
|
|
|
|
// Calculate recovery ID
|
|
recoveryID := int(signData.SignatureRecovery[0])
|
|
|
|
return &SigningResult{
|
|
Signature: signature,
|
|
R: r,
|
|
S: rS,
|
|
RecoveryID: recoveryID,
|
|
}, nil
|
|
}
|
|
|
|
// LocalSigningResult is the signature produced by RunLocalSigning, the
// in-process signing helper used for standalone testing.
type LocalSigningResult struct {
	// Signature is R || S, each left-padded with zeros to 32 bytes.
	Signature []byte

	// R is the signature's R value.
	R *big.Int

	// S is the signature's S value.
	S *big.Int

	// RecoveryID is the first byte of tss-lib's SignatureRecovery.
	RecoveryID int
}
|
|
|
|
// RunLocalSigning runs signing locally with all parties in the same process (for testing)
|
|
func RunLocalSigning(
|
|
threshold int,
|
|
keygenResults []*LocalKeygenResult,
|
|
messageHash []byte,
|
|
) (*LocalSigningResult, error) {
|
|
signerCount := len(keygenResults)
|
|
if signerCount < threshold {
|
|
return nil, ErrInvalidSignerCount
|
|
}
|
|
|
|
// Create party IDs for signers using their ORIGINAL party indices from keygen
|
|
// This is critical for subset signing - party IDs must match the original keygen party IDs
|
|
partyIDs := make([]*tss.PartyID, signerCount)
|
|
for i, result := range keygenResults {
|
|
idx := result.PartyIndex
|
|
partyIDs[i] = tss.NewPartyID(
|
|
fmt.Sprintf("party-%d", idx),
|
|
fmt.Sprintf("party-%d", idx),
|
|
big.NewInt(int64(idx+1)),
|
|
)
|
|
}
|
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
|
|
// Convert message hash to big.Int
|
|
msgHash := new(big.Int).SetBytes(messageHash)
|
|
|
|
// Create channels for each party
|
|
outChs := make([]chan tss.Message, signerCount)
|
|
endChs := make([]chan *common.SignatureData, signerCount)
|
|
parties := make([]tss.Party, signerCount)
|
|
|
|
// Map sorted party IDs back to keygen results
|
|
sortedKeygenResults := make([]*LocalKeygenResult, signerCount)
|
|
for i, pid := range sortedPartyIDs {
|
|
for _, result := range keygenResults {
|
|
expectedID := fmt.Sprintf("party-%d", result.PartyIndex)
|
|
if pid.Id == expectedID {
|
|
sortedKeygenResults[i] = result
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
for i := 0; i < signerCount; i++ {
|
|
outChs[i] = make(chan tss.Message, signerCount*10)
|
|
endChs[i] = make(chan *common.SignatureData, 1)
|
|
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], signerCount, threshold)
|
|
parties[i] = signing.NewLocalParty(msgHash, params, *sortedKeygenResults[i].SaveData, outChs[i], endChs[i])
|
|
}
|
|
|
|
// Start all parties
|
|
var wg sync.WaitGroup
|
|
errCh := make(chan error, signerCount)
|
|
|
|
for i := 0; i < signerCount; i++ {
|
|
wg.Add(1)
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
if err := parties[idx].Start(); err != nil {
|
|
errCh <- err
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Route messages between parties
|
|
var routeWg sync.WaitGroup
|
|
doneCh := make(chan struct{})
|
|
|
|
for i := 0; i < signerCount; i++ {
|
|
routeWg.Add(1)
|
|
go func(idx int) {
|
|
defer routeWg.Done()
|
|
for {
|
|
select {
|
|
case <-doneCh:
|
|
return
|
|
case msg := <-outChs[idx]:
|
|
if msg == nil {
|
|
return
|
|
}
|
|
dest := msg.GetTo()
|
|
if msg.IsBroadcast() {
|
|
for j := 0; j < signerCount; j++ {
|
|
if j != idx {
|
|
go updateSignParty(parties[j], msg, errCh)
|
|
}
|
|
}
|
|
} else {
|
|
for _, d := range dest {
|
|
for j := 0; j < signerCount; j++ {
|
|
if sortedPartyIDs[j].Id == d.Id {
|
|
go updateSignParty(parties[j], msg, errCh)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Collect first result (all parties should produce same signature)
|
|
var result *LocalSigningResult
|
|
for i := 0; i < signerCount; i++ {
|
|
select {
|
|
case signData := <-endChs[i]:
|
|
if result == nil {
|
|
r := new(big.Int).SetBytes(signData.R)
|
|
rS := new(big.Int).SetBytes(signData.S)
|
|
|
|
signature := make([]byte, 64)
|
|
copy(signature[32-len(signData.R):32], signData.R)
|
|
copy(signature[64-len(signData.S):64], signData.S)
|
|
|
|
result = &LocalSigningResult{
|
|
Signature: signature,
|
|
R: r,
|
|
S: rS,
|
|
RecoveryID: int(signData.SignatureRecovery[0]),
|
|
}
|
|
}
|
|
case err := <-errCh:
|
|
close(doneCh)
|
|
return nil, err
|
|
case <-time.After(5 * time.Minute):
|
|
close(doneCh)
|
|
return nil, ErrSigningTimeout
|
|
}
|
|
}
|
|
|
|
close(doneCh)
|
|
return result, nil
|
|
}
|
|
|
|
func updateSignParty(party tss.Party, msg tss.Message, errCh chan error) {
|
|
bytes, routing, err := msg.WireBytes()
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
parsedMsg, err := tss.ParseWireMessage(bytes, msg.GetFrom(), routing.IsBroadcast)
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
if _, err := party.Update(parsedMsg); err != nil {
|
|
// Only send error if it's not a duplicate message error
|
|
if err.Error() != "" && !isSignDuplicateMessageError(err) {
|
|
errCh <- err
|
|
}
|
|
}
|
|
}
|
|
|
|
// isSignDuplicateMessageError reports whether err looks like a duplicate-message
// error. It does a case-sensitive substring match on the error text for
// "duplicate" or "already received". A nil error is never a duplicate.
func isSignDuplicateMessageError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range [...]string{"duplicate", "already received"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|