// Source: rwadurian/backend/mpc-system/pkg/tss/tss_test.go
// (477 lines, 13 KiB, Go)

package tss
import (
	"context"
	stdecdsa "crypto/ecdsa"
	"crypto/sha256"
	"errors"
	"math/big"
	"testing"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcec/v2/ecdsa"
)
// TestRunLocalKeygen exercises RunLocalKeygen over valid and invalid
// threshold/party combinations.
//
// Each error case only checks that an error is returned. Each success case
// checks three things:
//   - exactly totalParties results come back;
//   - every party has non-nil SaveData and PublicKey;
//   - all parties agree on the same public key.
func TestRunLocalKeygen(t *testing.T) {
	tests := []struct {
		name         string
		threshold    int
		totalParties int
		wantErr      bool
	}{
		{name: "2-of-3 keygen", threshold: 2, totalParties: 3, wantErr: false},
		{name: "2-of-2 keygen", threshold: 2, totalParties: 2, wantErr: false},
		{name: "invalid party count", threshold: 2, totalParties: 1, wantErr: true},
		{name: "invalid threshold", threshold: 0, totalParties: 3, wantErr: true},
		{name: "threshold greater than parties", threshold: 4, totalParties: 3, wantErr: true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results, err := RunLocalKeygen(tt.threshold, tt.totalParties)
			if (err != nil) != tt.wantErr {
				t.Errorf("RunLocalKeygen() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				return
			}

			if len(results) != tt.totalParties {
				t.Errorf("Expected %d results, got %d", tt.totalParties, len(results))
				return
			}

			// Every party must report the same group public key; the first
			// non-nil key is the reference.
			var firstPubKey *stdecdsa.PublicKey
			for i, result := range results {
				if result.SaveData == nil {
					t.Errorf("Party %d has nil SaveData", i)
					continue
				}
				if result.PublicKey == nil {
					t.Errorf("Party %d has nil PublicKey", i)
					continue
				}
				if firstPubKey == nil {
					firstPubKey = result.PublicKey
					continue
				}
				if result.PublicKey.X.Cmp(firstPubKey.X) != 0 ||
					result.PublicKey.Y.Cmp(firstPubKey.Y) != 0 {
					t.Errorf("Party %d has different public key", i)
				}
			}

			// If every party was skipped above, there is no key to log and
			// dereferencing firstPubKey would panic.
			if firstPubKey == nil {
				t.Fatal("No party produced a public key")
			}
			// %.16s truncates without panicking when the hex string is shorter
			// than 16 characters (unlike slicing with [:16]).
			t.Logf("Keygen successful: %d-of-%d, public key X: %.16s...",
				tt.threshold, tt.totalParties, firstPubKey.X.Text(16))
		})
	}
}
// TestRunLocalSigning runs a local keygen (threshold=2, 3 parties) and then
// signs with every party's key share.
//
// It checks that the result is non-nil, that Signature is 64 bytes and that
// R and S are present. It then verifies (R, S) against party 0's public key
// with crypto/ecdsa.
func TestRunLocalSigning(t *testing.T) {
	threshold := 2
	totalParties := 3
	keygenResults, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}

	message := []byte("Hello, MPC signing!")
	messageHash := sha256.Sum256(message)

	// All keygen results take part in signing.
	signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
	if err != nil {
		t.Fatalf("Signing failed: %v", err)
	}
	if signResult == nil {
		t.Fatal("Sign result is nil")
	}
	if len(signResult.Signature) != 64 {
		t.Errorf("Expected 64-byte signature, got %d bytes", len(signResult.Signature))
	}
	// Fatal rather than Error: Verify and the log below dereference R and S.
	if signResult.R == nil || signResult.S == nil {
		t.Fatal("R or S is nil")
	}

	pubKey := keygenResults[0].PublicKey
	if pubKey == nil {
		t.Fatal("Party 0 has nil PublicKey")
	}
	if !stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S) {
		t.Error("Signature verification failed")
	}

	// %.16s truncates safely even if a hex string is shorter than 16 chars.
	t.Logf("Signing successful: R=%.16s..., S=%.16s...",
		signResult.R.Text(16), signResult.S.Text(16))
}
// TestMultipleSigning reuses one set of key shares to sign three distinct
// messages and verifies each signature against party 0's public key.
//
// A failure on one message is reported and the loop moves on to the next.
func TestMultipleSigning(t *testing.T) {
	const (
		threshold    = 2
		totalParties = 3
	)
	shares, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}

	texts := []string{"First message", "Second message", "Third message"}
	groupKey := shares[0].PublicKey

	for idx, text := range texts {
		digest := sha256.Sum256([]byte(text))
		sig, signErr := RunLocalSigning(threshold, shares, digest[:])
		if signErr != nil {
			t.Errorf("Signing message %d failed: %v", idx, signErr)
			continue
		}
		if ok := stdecdsa.Verify(groupKey, digest[:], sig.R, sig.S); !ok {
			t.Errorf("Signature %d verification failed", idx)
		}
	}
}
// TestSigningWithSubsetOfParties checks that a quorum smaller than the full
// party set can produce a valid signature.
//
// In tss-lib a threshold t means t+1 parties must take part in signing.
// A 2-of-3 scheme (any 2 of 3 parties can sign) is therefore generated with
// t = 1 and n = 3.
func TestSigningWithSubsetOfParties(t *testing.T) {
	const (
		threshold    = 1 // t+1 = 2 signers required
		totalParties = 3
	)
	keygenResults, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}

	// Parties 0 and 1 sign; party 2 sits out.
	quorum := []*LocalKeygenResult{keygenResults[0], keygenResults[1]}
	digest := sha256.Sum256([]byte("Threshold signing test"))

	signResult, err := RunLocalSigning(threshold, quorum, digest[:])
	if err != nil {
		t.Fatalf("Signing with subset failed: %v", err)
	}

	groupKey := keygenResults[0].PublicKey
	if !stdecdsa.Verify(groupKey, digest[:], signResult.R, signResult.S) {
		t.Error("Signature verification failed for subset signing")
	}
	t.Log("Subset signing (2-of-3) successful with threshold=1")
}
// TestSigningWithDifferentSubsets signs with each of the three possible
// two-party quorums from a 3-party keygen. Each signature must verify against
// party 0's public key.
//
// In tss-lib a threshold t requires t+1 signers, so threshold=1 yields a
// 2-of-3 scheme.
func TestSigningWithDifferentSubsets(t *testing.T) {
	const (
		threshold    = 1 // t+1 = 2 signers per quorum
		totalParties = 3
	)
	shares, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}
	groupKey := shares[0].PublicKey

	// Every pair of parties: {0,1}, {0,2}, {1,2}.
	quorums := [][]*LocalKeygenResult{
		{shares[0], shares[1]},
		{shares[0], shares[2]},
		{shares[1], shares[2]},
	}

	for idx := range quorums {
		// A distinct message per quorum: "Test message A", "... B", "... C".
		label := string(rune('A' + idx))
		digest := sha256.Sum256([]byte("Test message " + label))

		sig, signErr := RunLocalSigning(threshold, quorums[idx], digest[:])
		if signErr != nil {
			t.Errorf("Signing with combination %d failed: %v", idx, signErr)
			continue
		}
		if !stdecdsa.Verify(groupKey, digest[:], sig.R, sig.S) {
			t.Errorf("Signature verification failed for combination %d", idx)
		}
	}
	t.Log("All subset combinations successful")
}
// TestKeygenResultConsistency checks that every party's SaveData records the
// same group public key (ECDSAPub) after a local keygen with threshold=2 and
// 3 parties.
func TestKeygenResultConsistency(t *testing.T) {
	threshold := 2
	totalParties := 3
	results, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}
	// Without this check an empty or short result set would pass vacuously.
	if len(results) != totalParties {
		t.Fatalf("Expected %d results, got %d", totalParties, len(results))
	}

	// Party 0's coordinates are the reference the others are compared to.
	var refX, refY *big.Int
	for i, result := range results {
		// Guard the field chain so a missing value fails cleanly, not via panic.
		if result == nil || result.SaveData == nil || result.SaveData.ECDSAPub == nil {
			t.Fatalf("Party %d has no ECDSAPub in SaveData", i)
		}
		x, y := result.SaveData.ECDSAPub.X(), result.SaveData.ECDSAPub.Y()
		if i == 0 {
			refX, refY = x, y
			continue
		}
		if x.Cmp(refX) != 0 {
			t.Errorf("Party %d X coordinate mismatch", i)
		}
		if y.Cmp(refY) != 0 {
			t.Errorf("Party %d Y coordinate mismatch", i)
		}
	}
}
// TestSignatureRecovery signs a message and cross-checks the result with
// btcec. It checks that RecoveryID is in [0, 3] and that (R, S) verifies
// under party 0's public key, parsed as an uncompressed secp256k1 point.
//
// NOTE(review): despite its name, this test does not recover a public key
// from (R, S, RecoveryID); it only range-checks RecoveryID.
func TestSignatureRecovery(t *testing.T) {
	threshold := 2
	totalParties := 3
	keygenResults, err := RunLocalKeygen(threshold, totalParties)
	if err != nil {
		t.Fatalf("Keygen failed: %v", err)
	}

	message := []byte("Recovery test message")
	messageHash := sha256.Sum256(message)
	signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:])
	if err != nil {
		t.Fatalf("Signing failed: %v", err)
	}
	if signResult == nil || signResult.R == nil || signResult.S == nil {
		t.Fatal("Sign result, R or S is nil")
	}

	if signResult.RecoveryID < 0 || signResult.RecoveryID > 3 {
		t.Errorf("Invalid recovery ID: %d", signResult.RecoveryID)
	}

	// SetByteSlice reports whether the value is >= the group order N. Valid
	// ECDSA components are < N, so overflow indicates a malformed signature.
	r := new(btcec.ModNScalar)
	if r.SetByteSlice(signResult.R.Bytes()) {
		t.Fatal("R overflows the secp256k1 group order")
	}
	s := new(btcec.ModNScalar)
	if s.SetByteSlice(signResult.S.Bytes()) {
		t.Fatal("S overflows the secp256k1 group order")
	}
	btcSig := ecdsa.NewSignature(r, s)

	originalPub := keygenResults[0].PublicKey
	if originalPub == nil {
		t.Fatal("Party 0 has nil PublicKey")
	}
	// SEC1 uncompressed encoding: 0x04 || X || Y, each coordinate exactly
	// 32 bytes. big.Int.Bytes() drops leading zero bytes, which would produce
	// a short, unparseable key whenever a coordinate's top byte is zero.
	// FillBytes left-pads to the full width.
	uncompressed := make([]byte, 65)
	uncompressed[0] = 0x04
	originalPub.X.FillBytes(uncompressed[1:33])
	originalPub.Y.FillBytes(uncompressed[33:65])

	btcPubKey, err := btcec.ParsePubKey(uncompressed)
	if err != nil {
		// Fail instead of returning early: a parse error must not silently
		// skip the verification below.
		t.Fatalf("Failed to parse public key: %v", err)
	}

	if !btcSig.Verify(messageHash[:], btcPubKey) {
		t.Error("btcec signature verification failed")
	} else {
		t.Log("btcec signature verification successful")
	}
}
// TestNewKeygenSession tests creating a new keygen session
func TestNewKeygenSession(t *testing.T) {
config := KeygenConfig{
Threshold: 2,
TotalParties: 3,
}
selfParty := KeygenParty{PartyID: "party-0", PartyIndex: 0}
allParties := []KeygenParty{
{PartyID: "party-0", PartyIndex: 0},
{PartyID: "party-1", PartyIndex: 1},
{PartyID: "party-2", PartyIndex: 2},
}
// Create a mock message handler
handler := &mockMessageHandler{
msgCh: make(chan *ReceivedMessage, 100),
}
session, err := NewKeygenSession(config, selfParty, allParties, handler)
if err != nil {
t.Fatalf("Failed to create keygen session: %v", err)
}
if session == nil {
t.Fatal("Session is nil")
}
}
// TestNewKeygenSessionValidation checks that NewKeygenSession rejects invalid
// configurations and returns the expected sentinel error.
//
// errors.Is is used rather than == so that a sentinel wrapped with extra
// context (for example fmt.Errorf("...: %w", ErrInvalidThreshold)) still
// matches.
func TestNewKeygenSessionValidation(t *testing.T) {
	tests := []struct {
		name        string
		config      KeygenConfig
		selfParty   KeygenParty
		allParties  []KeygenParty
		wantErr     bool
		expectedErr error
	}{
		{
			name: "invalid party count",
			config: KeygenConfig{
				Threshold:    2,
				TotalParties: 1,
			},
			selfParty:   KeygenParty{PartyID: "party-0", PartyIndex: 0},
			allParties:  []KeygenParty{{PartyID: "party-0", PartyIndex: 0}},
			wantErr:     true,
			expectedErr: ErrInvalidPartyCount,
		},
		{
			name: "invalid threshold - zero",
			config: KeygenConfig{
				Threshold:    0,
				TotalParties: 3,
			},
			selfParty:   KeygenParty{PartyID: "party-0", PartyIndex: 0},
			allParties:  []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}},
			wantErr:     true,
			expectedErr: ErrInvalidThreshold,
		},
		{
			// TotalParties says 3 but only 2 parties are supplied.
			name: "mismatched party count",
			config: KeygenConfig{
				Threshold:    2,
				TotalParties: 3,
			},
			selfParty:   KeygenParty{PartyID: "party-0", PartyIndex: 0},
			allParties:  []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}},
			wantErr:     true,
			expectedErr: ErrInvalidPartyCount,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := &mockMessageHandler{msgCh: make(chan *ReceivedMessage)}
			_, err := NewKeygenSession(tt.config, tt.selfParty, tt.allParties, handler)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewKeygenSession() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.expectedErr != nil && !errors.Is(err, tt.expectedErr) {
				t.Errorf("Expected error %v, got %v", tt.expectedErr, err)
			}
		})
	}
}
// mockMessageHandler is a test double for MessageHandler.
// It records every outgoing message in sentMsgs and serves incoming messages
// from msgCh.
type mockMessageHandler struct {
	msgCh    chan *ReceivedMessage
	sentMsgs []sentMessage
}

// sentMessage captures the arguments of a single SendMessage call.
type sentMessage struct {
	isBroadcast bool
	toParties   []string
	msgBytes    []byte
}

// SendMessage appends the call's arguments to sentMsgs and always returns nil.
// The context is ignored. The slices are stored as-is (not copied).
// NOTE(review): not goroutine-safe; concurrent senders would race on sentMsgs.
func (h *mockMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	record := sentMessage{
		isBroadcast: isBroadcast,
		toParties:   toParties,
		msgBytes:    msgBytes,
	}
	h.sentMsgs = append(h.sentMsgs, record)
	return nil
}

// ReceiveMessages exposes msgCh as a receive-only channel.
func (h *mockMessageHandler) ReceiveMessages() <-chan *ReceivedMessage {
	return h.msgCh
}
// BenchmarkKeygen2of3 measures one full local keygen (RunLocalKeygen) with
// threshold=2 and 3 parties per iteration. It also reports allocations.
//
// NOTE(review): under tss-lib's convention that a threshold t requires t+1
// signers, threshold=2 with 3 parties needs all 3 parties to sign. The "2of3"
// name follows the threshold value, not the signer count. Confirm intent.
func BenchmarkKeygen2of3(b *testing.B) {
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		if _, err := RunLocalKeygen(2, 3); err != nil {
			b.Fatalf("Keygen failed: %v", err)
		}
	}
}
// BenchmarkSigning2of3 measures one RunLocalSigning call per iteration over a
// fixed SHA-256 message hash. All key shares from a single keygen
// (threshold=2, 3 parties) take part in signing. Keygen runs once, before the
// timer is reset, so it is not included in the measurement. Allocations are
// reported.
//
// NOTE(review): with tss-lib's t+1 signer convention, threshold=2 over
// 3 parties means this measures 3-party signing, despite the "2of3" name.
// Confirm intent.
func BenchmarkSigning2of3(b *testing.B) {
	keygenResults, err := RunLocalKeygen(2, 3)
	if err != nil {
		b.Fatalf("Keygen failed: %v", err)
	}
	message := []byte("Benchmark signing message")
	messageHash := sha256.Sum256(message)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := RunLocalSigning(2, keygenResults, messageHash[:]); err != nil {
			b.Fatalf("Signing failed: %v", err)
		}
	}
}