package tss import ( "context" stdecdsa "crypto/ecdsa" "crypto/sha256" "math/big" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" ) // TestRunLocalKeygen tests the local keygen functionality func TestRunLocalKeygen(t *testing.T) { tests := []struct { name string threshold int totalParties int wantErr bool }{ { name: "2-of-3 keygen", threshold: 2, totalParties: 3, wantErr: false, }, { name: "2-of-2 keygen", threshold: 2, totalParties: 2, wantErr: false, }, { name: "invalid party count", threshold: 2, totalParties: 1, wantErr: true, }, { name: "invalid threshold", threshold: 0, totalParties: 3, wantErr: true, }, { name: "threshold greater than parties", threshold: 4, totalParties: 3, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { results, err := RunLocalKeygen(tt.threshold, tt.totalParties) if (err != nil) != tt.wantErr { t.Errorf("RunLocalKeygen() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { return } // Verify results if len(results) != tt.totalParties { t.Errorf("Expected %d results, got %d", tt.totalParties, len(results)) return } // Verify all parties have the same public key var firstPubKey *stdecdsa.PublicKey for i, result := range results { if result.SaveData == nil { t.Errorf("Party %d has nil SaveData", i) continue } if result.PublicKey == nil { t.Errorf("Party %d has nil PublicKey", i) continue } if firstPubKey == nil { firstPubKey = result.PublicKey } else { // Compare public keys if result.PublicKey.X.Cmp(firstPubKey.X) != 0 || result.PublicKey.Y.Cmp(firstPubKey.Y) != 0 { t.Errorf("Party %d has different public key", i) } } } t.Logf("Keygen successful: %d-of-%d, public key X: %s", tt.threshold, tt.totalParties, firstPubKey.X.Text(16)[:16]+"...") }) } } // TestRunLocalSigning tests the local signing functionality func TestRunLocalSigning(t *testing.T) { // First run keygen to get key shares threshold := 2 totalParties := 3 keygenResults, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } // Create message hash message := []byte("Hello, MPC signing!") messageHash := sha256.Sum256(message) // Run signing signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:]) if err != nil { t.Fatalf("Signing failed: %v", err) } // Verify signature if signResult == nil { t.Fatal("Sign result is nil") } if len(signResult.Signature) != 64 { t.Errorf("Expected 64-byte signature, got %d bytes", len(signResult.Signature)) } if signResult.R == nil || signResult.S == nil { t.Error("R or S is nil") } // Verify signature using the public key pubKey := keygenResults[0].PublicKey valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S) if !valid { t.Error("Signature verification failed") } t.Logf("Signing successful: R=%s..., S=%s...", signResult.R.Text(16)[:16], signResult.S.Text(16)[:16]) } // TestMultipleSigning tests signing multiple messages with the same keys func TestMultipleSigning(t *testing.T) { // Run keygen threshold := 2 totalParties := 3 keygenResults, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } messages := []string{ "First message", "Second message", "Third message", } pubKey := keygenResults[0].PublicKey for i, msg := range messages { messageHash := sha256.Sum256([]byte(msg)) signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:]) if err != nil { t.Errorf("Signing message %d failed: %v", i, err) continue } valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S) if !valid { t.Errorf("Signature %d verification failed", i) } } } // TestSigningWithSubsetOfParties tests signing with a subset of parties // In tss-lib, threshold `t` means `t+1` parties are needed to sign. // For a 2-of-3 scheme (2 signers needed), we use threshold=1 (1+1=2). func TestSigningWithSubsetOfParties(t *testing.T) { // For a 2-of-3 scheme in tss-lib: // - totalParties (n) = 3 // - threshold (t) = 1 (meaning t+1=2 parties are required to sign) threshold := 1 totalParties := 3 keygenResults, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } // Sign with only 2 parties (party 0 and party 1) - this should work with t=1 signers := []*LocalKeygenResult{ keygenResults[0], keygenResults[1], } message := []byte("Threshold signing test") messageHash := sha256.Sum256(message) signResult, err := RunLocalSigning(threshold, signers, messageHash[:]) if err != nil { t.Fatalf("Signing with subset failed: %v", err) } // Verify signature pubKey := keygenResults[0].PublicKey valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S) if !valid { t.Error("Signature verification failed for subset signing") } t.Log("Subset signing (2-of-3) successful with threshold=1") } // TestSigningWithDifferentSubsets tests signing with different party combinations // In tss-lib, threshold `t` means `t+1` parties are needed to sign. // For a 2-of-3 scheme (2 signers needed), we use threshold=1. func TestSigningWithDifferentSubsets(t *testing.T) { // For 2-of-3 in tss-lib terminology: threshold=1 means t+1=2 signers needed threshold := 1 totalParties := 3 keygenResults, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } pubKey := keygenResults[0].PublicKey // Test different combinations of 2 parties (the minimum required with t=1) combinations := [][]*LocalKeygenResult{ {keygenResults[0], keygenResults[1]}, // parties 0,1 {keygenResults[0], keygenResults[2]}, // parties 0,2 {keygenResults[1], keygenResults[2]}, // parties 1,2 } for i, signers := range combinations { message := []byte("Test message " + string(rune('A'+i))) messageHash := sha256.Sum256(message) signResult, err := RunLocalSigning(threshold, signers, messageHash[:]) if err != nil { t.Errorf("Signing with combination %d failed: %v", i, err) continue } valid := stdecdsa.Verify(pubKey, messageHash[:], signResult.R, signResult.S) if !valid { t.Errorf("Signature verification failed for combination %d", i) } } t.Log("All subset combinations successful") } // TestKeygenResultConsistency tests that all parties produce consistent results func TestKeygenResultConsistency(t *testing.T) { threshold := 2 totalParties := 3 results, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } // All parties should have the same ECDSAPub var refX, refY *big.Int for i, result := range results { if i == 0 { refX = result.SaveData.ECDSAPub.X() refY = result.SaveData.ECDSAPub.Y() } else { if result.SaveData.ECDSAPub.X().Cmp(refX) != 0 { t.Errorf("Party %d X coordinate mismatch", i) } if result.SaveData.ECDSAPub.Y().Cmp(refY) != 0 { t.Errorf("Party %d Y coordinate mismatch", i) } } } } // TestSignatureRecovery tests that the recovery ID allows public key recovery func TestSignatureRecovery(t *testing.T) { threshold := 2 totalParties := 3 keygenResults, err := RunLocalKeygen(threshold, totalParties) if err != nil { t.Fatalf("Keygen failed: %v", err) } message := []byte("Recovery test message") messageHash := sha256.Sum256(message) signResult, err := RunLocalSigning(threshold, keygenResults, messageHash[:]) if err != nil { t.Fatalf("Signing failed: %v", err) } // Verify the recovery ID is valid (0-3) if signResult.RecoveryID < 0 || signResult.RecoveryID > 3 { t.Errorf("Invalid recovery ID: %d", signResult.RecoveryID) } // Verify we can create a btcec signature and verify it r := new(btcec.ModNScalar) r.SetByteSlice(signResult.R.Bytes()) s := new(btcec.ModNScalar) s.SetByteSlice(signResult.S.Bytes()) btcSig := ecdsa.NewSignature(r, s) // Convert public key to btcec format originalPub := keygenResults[0].PublicKey btcPubKey, err := btcec.ParsePubKey(append([]byte{0x04}, append(originalPub.X.Bytes(), originalPub.Y.Bytes()...)...)) if err != nil { t.Logf("Failed to parse public key: %v", err) return } // Verify the signature verified := btcSig.Verify(messageHash[:], btcPubKey) if !verified { t.Error("btcec signature verification failed") } else { t.Log("btcec signature verification successful") } } // TestNewKeygenSession tests creating a new keygen session func TestNewKeygenSession(t *testing.T) { config := KeygenConfig{ Threshold: 2, TotalParties: 3, } selfParty := KeygenParty{PartyID: "party-0", PartyIndex: 0} allParties := []KeygenParty{ {PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}, } // Create a mock message handler handler := &mockMessageHandler{ msgCh: make(chan *ReceivedMessage, 100), } session, err := NewKeygenSession(config, selfParty, allParties, handler) if err != nil { t.Fatalf("Failed to create keygen session: %v", err) } if session == nil { t.Fatal("Session is nil") } } // TestNewKeygenSessionValidation tests validation in NewKeygenSession func TestNewKeygenSessionValidation(t *testing.T) { tests := []struct { name string config KeygenConfig selfParty KeygenParty allParties []KeygenParty wantErr bool expectedErr error }{ { name: "invalid party count", config: KeygenConfig{ Threshold: 2, TotalParties: 1, }, selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0}, allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}}, wantErr: true, expectedErr: ErrInvalidPartyCount, }, { name: "invalid threshold - zero", config: KeygenConfig{ Threshold: 0, TotalParties: 3, }, selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0}, allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}, {PartyID: "party-2", PartyIndex: 2}}, wantErr: true, expectedErr: ErrInvalidThreshold, }, { name: "mismatched party count", config: KeygenConfig{ Threshold: 2, TotalParties: 3, }, selfParty: KeygenParty{PartyID: "party-0", PartyIndex: 0}, allParties: []KeygenParty{{PartyID: "party-0", PartyIndex: 0}, {PartyID: "party-1", PartyIndex: 1}}, wantErr: true, expectedErr: ErrInvalidPartyCount, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { handler := &mockMessageHandler{msgCh: make(chan *ReceivedMessage)} _, err := NewKeygenSession(tt.config, tt.selfParty, tt.allParties, handler) if (err != nil) != tt.wantErr { t.Errorf("NewKeygenSession() error = %v, wantErr %v", err, tt.wantErr) } if tt.expectedErr != nil && err != tt.expectedErr { t.Errorf("Expected error %v, got %v", tt.expectedErr, err) } }) } } // mockMessageHandler is a mock implementation of MessageHandler for testing type mockMessageHandler struct { msgCh chan *ReceivedMessage sentMsgs []sentMessage } type sentMessage struct { isBroadcast bool toParties []string msgBytes []byte } func (m *mockMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error { m.sentMsgs = append(m.sentMsgs, sentMessage{ isBroadcast: isBroadcast, toParties: toParties, msgBytes: msgBytes, }) return nil } func (m *mockMessageHandler) ReceiveMessages() <-chan *ReceivedMessage { return m.msgCh } // BenchmarkKeygen benchmarks the keygen operation func BenchmarkKeygen2of3(b *testing.B) { for i := 0; i < b.N; i++ { _, err := RunLocalKeygen(2, 3) if err != nil { b.Fatalf("Keygen failed: %v", err) } } } // BenchmarkSigning benchmarks the signing operation func BenchmarkSigning2of3(b *testing.B) { // Setup: run keygen once keygenResults, err := RunLocalKeygen(2, 3) if err != nil { b.Fatalf("Keygen failed: %v", err) } message := []byte("Benchmark signing message") messageHash := sha256.Sum256(message) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := RunLocalSigning(2, keygenResults, messageHash[:]) if err != nil { b.Fatalf("Signing failed: %v", err) } } }