477 lines
13 KiB
Go
477 lines
13 KiB
Go
package tss
|
|
|
|
import (
	"context"
	stdecdsa "crypto/ecdsa"
	"crypto/sha256"
	"errors"
	"math/big"
	"testing"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcec/v2/ecdsa"
)
|
|
|
|
// TestRunLocalKeygen tests the local keygen functionality
|
|
func TestRunLocalKeygen(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
threshold int
|
|
totalParties int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "2-of-3 keygen",
|
|
threshold: 2,
|
|
totalParties: 3,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "2-of-2 keygen",
|
|
threshold: 2,
|
|
totalParties: 2,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid party count",
|
|
threshold: 2,
|
|
totalParties: 1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid threshold",
|
|
threshold: 0,
|
|
totalParties: 3,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "threshold greater than parties",
|
|
threshold: 4,
|
|
totalParties: 3,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
results, err := RunLocalKeygen(tt.threshold, tt.totalParties)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("RunLocalKeygen() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
|
|
// Verify results
|
|
if len(results) != tt.totalParties {
|
|
t.Errorf("Expected %d results, got %d", tt.totalParties, len(results))
|
|
return
|
|
}
|
|
|
|
// Verify all parties have the same public key
|
|
var firstPubKey *stdecdsa.PublicKey
|
|
for i, result := range results {
|
|
if result.SaveData == nil {
|
|
t.Errorf("Party %d has nil SaveData", i)
|
|
continue
|
|
}
|
|
if result.PublicKey == nil {
|
|
t.Errorf("Party %d has nil PublicKey", i)
|
|
continue
|
|
}
|
|
|
|
if firstPubKey == nil {
|
|
firstPubKey = result.PublicKey
|
|
} else {
|
|
// Compare public keys
|
|
if result.PublicKey.X.Cmp(firstPubKey.X) != 0 ||
|
|
result.PublicKey.Y.Cmp(firstPubKey.Y) != 0 {
|
|
t.Errorf("Party %d has different public key", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
t.Logf("Keygen successful: %d-of-%d, public key X: %s",
|
|
tt.threshold, tt.totalParties, firstPubKey.X.Text(16)[:16]+"...")
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRunLocalSigning tests the local signing functionality
|
|
func TestRunLocalSigning(t *testing.T) {
|
|
// First run keygen to get key shares
|
|
threshold := 2
|
|
totalParties := 3
|
|
|
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
// Create message hash
|
|
message := []byte("Hello, MPC signing!")
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
// Run signing
|
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
|
if err != nil {
|
|
t.Fatalf("Signing failed: %v", err)
|
|
}
|
|
|
|
// Verify signature
|
|
if signResult == nil {
|
|
t.Fatal("Sign result is nil")
|
|
}
|
|
|
|
if len(signResult.Signature) != 64 {
|
|
t.Errorf("Expected 64-byte signature, got %d bytes", len(signResult.Signature))
|
|
}
|
|
|
|
if signResult.R == nil || signResult.S == nil {
|
|
t.Error("R or S is nil")
|
|
}
|
|
|
|
// Verify signature using the public key
|
|
pubKey := keygenResults[0].PublicKey
|
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
|
if !valid {
|
|
t.Error("Signature verification failed")
|
|
}
|
|
|
|
t.Logf("Signing successful: R=%s..., S=%s...",
|
|
signResult.R.Text(16)[:16], signResult.S.Text(16)[:16])
|
|
}
|
|
|
|
// TestMultipleSigning tests signing multiple messages with the same keys
|
|
func TestMultipleSigning(t *testing.T) {
|
|
// Run keygen
|
|
threshold := 2
|
|
totalParties := 3
|
|
|
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
messages := []string{
|
|
"First message",
|
|
"Second message",
|
|
"Third message",
|
|
}
|
|
|
|
pubKey := keygenResults[0].PublicKey
|
|
|
|
for i, msg := range messages {
|
|
messageHash := sha256.Sum256([]byte(msg))
|
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
|
if err != nil {
|
|
t.Errorf("Signing message %d failed: %v", i, err)
|
|
continue
|
|
}
|
|
|
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
|
if !valid {
|
|
t.Errorf("Signature %d verification failed", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSigningWithSubsetOfParties tests signing with a subset of parties
|
|
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
|
// For a 2-of-3 scheme (2 signers needed), we use threshold=1 (1+1=2).
|
|
func TestSigningWithSubsetOfParties(t *testing.T) {
|
|
// For a 2-of-3 scheme in tss-lib:
|
|
// - totalParties (n) = 3
|
|
// - threshold (t) = 1 (meaning t+1=2 parties are required to sign)
|
|
threshold := 1
|
|
totalParties := 3
|
|
|
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
// Sign with only 2 parties (party 0 and party 1) - this should work with t=1
|
|
signers := []*LocalKeygenResult{
|
|
keygenResults[0],
|
|
keygenResults[1],
|
|
}
|
|
|
|
message := []byte("Threshold signing test")
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
|
if err != nil {
|
|
t.Fatalf("Signing with subset failed: %v", err)
|
|
}
|
|
|
|
// Verify signature
|
|
pubKey := keygenResults[0].PublicKey
|
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
|
if !valid {
|
|
t.Error("Signature verification failed for subset signing")
|
|
}
|
|
|
|
t.Log("Subset signing (2-of-3) successful with threshold=1")
|
|
}
|
|
|
|
// TestSigningWithDifferentSubsets tests signing with different party combinations
|
|
// In tss-lib, threshold `t` means `t+1` parties are needed to sign.
|
|
// For a 2-of-3 scheme (2 signers needed), we use threshold=1.
|
|
func TestSigningWithDifferentSubsets(t *testing.T) {
|
|
// For 2-of-3 in tss-lib terminology: threshold=1 means t+1=2 signers needed
|
|
threshold := 1
|
|
totalParties := 3
|
|
|
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
pubKey := keygenResults[0].PublicKey
|
|
|
|
// Test different combinations of 2 parties (the minimum required with t=1)
|
|
combinations := [][]*LocalKeygenResult{
|
|
{keygenResults[0], keygenResults[1]}, // parties 0,1
|
|
{keygenResults[0], keygenResults[2]}, // parties 0,2
|
|
{keygenResults[1], keygenResults[2]}, // parties 1,2
|
|
}
|
|
|
|
for i, signers := range combinations {
|
|
message := []byte("Test message " + string(rune('A'+i)))
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
signResult, err := RunLocalSigning(threshold, signers, messageHash[:])
|
|
if err != nil {
|
|
t.Errorf("Signing with combination %d failed: %v", i, err)
|
|
continue
|
|
}
|
|
|
|
valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S)
|
|
if !valid {
|
|
t.Errorf("Signature verification failed for combination %d", i)
|
|
}
|
|
}
|
|
|
|
t.Log("All subset combinations successful")
|
|
}
|
|
|
|
// TestKeygenResultConsistency tests that all parties produce consistent results
|
|
func TestKeygenResultConsistency(t *testing.T) {
|
|
threshold := 2
|
|
totalParties := 3
|
|
|
|
results, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
// All parties should have the same ECDSAPub
|
|
var refX, refY *big.Int
|
|
for i, result := range results {
|
|
if i == 0 {
|
|
refX = result.SaveData.ECDSAPub.X()
|
|
refY = result.SaveData.ECDSAPub.Y()
|
|
} else {
|
|
if result.SaveData.ECDSAPub.X().Cmp(refX) != 0 {
|
|
t.Errorf("Party %d X coordinate mismatch", i)
|
|
}
|
|
if result.SaveData.ECDSAPub.Y().Cmp(refY) != 0 {
|
|
t.Errorf("Party %d Y coordinate mismatch", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSignatureRecovery tests that the recovery ID allows public key recovery
|
|
func TestSignatureRecovery(t *testing.T) {
|
|
threshold := 2
|
|
totalParties := 3
|
|
|
|
keygenResults, err := RunLocalKeygen(threshold, totalParties)
|
|
if err != nil {
|
|
t.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
message := []byte("Recovery test message")
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
|
|
if err != nil {
|
|
t.Fatalf("Signing failed: %v", err)
|
|
}
|
|
|
|
// Verify the recovery ID is valid (0-3)
|
|
if signResult.RecoveryID < 0 || signResult.RecoveryID > 3 {
|
|
t.Errorf("Invalid recovery ID: %d", signResult.RecoveryID)
|
|
}
|
|
|
|
// Verify we can create a btcec signature and verify it
|
|
r := new(btcec.ModNScalar)
|
|
r.SetByteSlice(signResult.R.Bytes())
|
|
s := new(btcec.ModNScalar)
|
|
s.SetByteSlice(signResult.S.Bytes())
|
|
|
|
btcSig := ecdsa.NewSignature(r, s)
|
|
|
|
// Convert public key to btcec format
|
|
originalPub := keygenResults[0].PublicKey
|
|
btcPubKey, err := btcec.ParsePubKey(append([]byte{0x04}, append(originalPub.X.Bytes(), originalPub.Y.Bytes()...)...))
|
|
if err != nil {
|
|
t.Logf("Failed to parse public key: %v", err)
|
|
return
|
|
}
|
|
|
|
// Verify the signature
|
|
verified := btcSig.Verify(messageHash[:], btcPubKey)
|
|
if !verified {
|
|
t.Error("btcec signature verification failed")
|
|
} else {
|
|
t.Log("btcec signature verification successful")
|
|
}
|
|
}
|
|
|
|
// TestNewKeygenSession tests creating a new keygen session
|
|
func TestNewKeygenSession(t *testing.T) {
|
|
config := KeygenConfig{
|
|
Threshold: 2,
|
|
TotalParties: 3,
|
|
}
|
|
|
|
selfParty := KeygenParty{PartyID: "party-0", PartyIndex: 0}
|
|
allParties := []KeygenParty{
|
|
{PartyID: "party-0", PartyIndex: 0},
|
|
{PartyID: "party-1", PartyIndex: 1},
|
|
{PartyID: "party-2", PartyIndex: 2},
|
|
}
|
|
|
|
// Create a mock message handler
|
|
handler := &mockMessageHandler{
|
|
msgCh: make(chan *ReceivedMessage, 100),
|
|
}
|
|
|
|
session, err := NewKeygenSession(config, selfParty, allParties, handler)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create keygen session: %v", err)
|
|
}
|
|
|
|
if session == nil {
|
|
t.Fatal("Session is nil")
|
|
}
|
|
}
|
|
|
|
// TestNewKeygenSessionValidation tests validation in NewKeygenSession
|
|
func TestNewKeygenSessionValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config KeygenConfig
|
|
selfParty KeygenParty
|
|
allParties []KeygenParty
|
|
wantErr bool
|
|
expectedErr error
|
|
}{
|
|
{
|
|
name: "invalid party count",
|
|
config: KeygenConfig{
|
|
Threshold: 2,
|
|
TotalParties: 1,
|
|
},
|
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}},
|
|
wantErr: true,
|
|
expectedErr: ErrInvalidPartyCount,
|
|
},
|
|
{
|
|
name: "invalid threshold - zero",
|
|
config: KeygenConfig{
|
|
Threshold: 0,
|
|
TotalParties: 3,
|
|
},
|
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}},
|
|
wantErr: true,
|
|
expectedErr: ErrInvalidThreshold,
|
|
},
|
|
{
|
|
name: "mismatched party count",
|
|
config: KeygenConfig{
|
|
Threshold: 2,
|
|
TotalParties: 3,
|
|
},
|
|
selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0},
|
|
allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}},
|
|
wantErr: true,
|
|
expectedErr: ErrInvalidPartyCount,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
handler := &mockMessageHandler{msgCh: make(chan *ReceivedMessage)}
|
|
_, err := NewKeygenSession(tt.config, tt.selfParty, tt.allParties, handler)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewKeygenSession() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if tt.expectedErr != nil && err != tt.expectedErr {
|
|
t.Errorf("Expected error %v, got %v", tt.expectedErr, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockMessageHandler is a mock implementation of MessageHandler for testing
|
|
type mockMessageHandler struct {
|
|
msgCh chan *ReceivedMessage
|
|
sentMsgs []sentMessage
|
|
}
|
|
|
|
// sentMessage captures the arguments of one mockMessageHandler.SendMessage
// call.
type sentMessage struct {
	// isBroadcast is the isBroadcast flag passed to SendMessage.
	isBroadcast bool

	// toParties is the recipient list passed to SendMessage.
	toParties []string

	// msgBytes is the message payload passed to SendMessage.
	msgBytes []byte
}
|
|
|
|
func (m *mockMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
|
m.sentMsgs = append(m.sentMsgs, sentMessage{
|
|
isBroadcast: isBroadcast,
|
|
toParties: toParties,
|
|
msgBytes: msgBytes,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (m *mockMessageHandler) ReceiveMessages() <-chan *ReceivedMessage {
|
|
return m.msgCh
|
|
}
|
|
|
|
// BenchmarkKeygen benchmarks the keygen operation
|
|
func BenchmarkKeygen2of3(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := RunLocalKeygen(2, 3)
|
|
if err != nil {
|
|
b.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// BenchmarkSigning benchmarks the signing operation
|
|
func BenchmarkSigning2of3(b *testing.B) {
|
|
// Setup: run keygen once
|
|
keygenResults, err := RunLocalKeygen(2, 3)
|
|
if err != nil {
|
|
b.Fatalf("Keygen failed: %v", err)
|
|
}
|
|
|
|
message := []byte("Benchmark signing message")
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := RunLocalSigning(2, keygenResults, messageHash[:])
|
|
if err != nil {
|
|
b.Fatalf("Signing failed: %v", err)
|
|
}
|
|
}
|
|
}
|