464 lines
11 KiB
Go
464 lines
11 KiB
Go
// Package main provides E2E tests for tss-party.exe
|
|
//
|
|
// This test simulates the full flow of keygen and signing using tss-party.exe
|
|
// by spawning multiple processes and coordinating message passing between them.
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/btcsuite/btcd/btcec/v2"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestTSSPartyE2E tests the full keygen and signing flow using tss-party.exe
|
|
func TestTSSPartyE2E(t *testing.T) {
|
|
// Find the tss-party.exe
|
|
exePath, err := filepath.Abs("tss-party.exe")
|
|
if err != nil {
|
|
t.Fatalf("Failed to get exe path: %v", err)
|
|
}
|
|
|
|
// Check if exe exists
|
|
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
|
t.Skipf("tss-party.exe not found at %s, run 'go build' first", exePath)
|
|
}
|
|
|
|
// Test parameters
|
|
sessionID := "test-session-001"
|
|
// In tss-lib: threshold=t means t+1 signers required
|
|
// For 2-of-3: we want 2 signers, so t=1 (1+1=2)
|
|
thresholdT := 1 // t=1 means 2 signers required (t+1=2)
|
|
thresholdN := 3
|
|
password := "test-password-123"
|
|
|
|
participants := []Participant{
|
|
{PartyID: "party-0", PartyIndex: 0},
|
|
{PartyID: "party-1", PartyIndex: 1},
|
|
{PartyID: "party-2", PartyIndex: 2},
|
|
}
|
|
|
|
// ============================================
|
|
// Step 1: Run Keygen
|
|
// ============================================
|
|
t.Log("========================================")
|
|
t.Log(" Step 1: Running Keygen")
|
|
t.Log("========================================")
|
|
|
|
keygenResults := runKeygenE2E(t, exePath, sessionID, thresholdT, thresholdN, participants, password)
|
|
require.Len(t, keygenResults, 3, "Should have 3 keygen results")
|
|
|
|
// Verify all parties have the same public key
|
|
pubKey0 := keygenResults[0].PublicKey
|
|
for i, r := range keygenResults {
|
|
require.Equal(t, pubKey0, r.PublicKey, "Party %d should have same public key", i)
|
|
}
|
|
|
|
t.Logf("Keygen completed! Public key: %s", hex.EncodeToString(pubKey0))
|
|
|
|
// ============================================
|
|
// Step 2: Run Signing with 2 parties
|
|
// ============================================
|
|
t.Log("========================================")
|
|
t.Log(" Step 2: Running Signing (2-of-3)")
|
|
t.Log("========================================")
|
|
|
|
message := []byte("Hello MPC World!")
|
|
messageHash := sha256.Sum256(message)
|
|
|
|
// Sign with parties 0 and 1
|
|
signers := []Participant{
|
|
participants[0],
|
|
participants[1],
|
|
}
|
|
|
|
signerShares := [][]byte{
|
|
keygenResults[0].EncryptedShare,
|
|
keygenResults[1].EncryptedShare,
|
|
}
|
|
|
|
signResult := runSignE2E(t, exePath, sessionID+"-sign", thresholdT, thresholdN, signers, signerShares, messageHash[:], password)
|
|
|
|
t.Logf("Signing completed!")
|
|
t.Logf(" R: %s", signResult.R)
|
|
t.Logf(" S: %s", signResult.S)
|
|
t.Logf(" RecoveryID: %d", signResult.RecoveryID)
|
|
|
|
// ============================================
|
|
// Step 3: Verify Signature
|
|
// ============================================
|
|
t.Log("========================================")
|
|
t.Log(" Step 3: Verifying Signature")
|
|
t.Log("========================================")
|
|
|
|
// Parse public key
|
|
pubKeyECDSA := parseCompressedPublicKey(t, pubKey0)
|
|
|
|
// Parse R and S
|
|
rBigInt, ok := new(big.Int).SetString(signResult.R, 16)
|
|
require.True(t, ok, "Failed to parse R")
|
|
sBigInt, ok := new(big.Int).SetString(signResult.S, 16)
|
|
require.True(t, ok, "Failed to parse S")
|
|
|
|
// Verify
|
|
valid := ecdsa.Verify(pubKeyECDSA, messageHash[:], rBigInt, sBigInt)
|
|
require.True(t, valid, "Signature verification should pass")
|
|
|
|
t.Log("✓ Signature verified successfully!")
|
|
t.Log("========================================")
|
|
t.Log(" E2E Test PASSED!")
|
|
t.Log("========================================")
|
|
}
|
|
|
|
// keygenE2EResult is what one party reported in its keygen "result" message,
// with both base64 fields already decoded to raw bytes.
type keygenE2EResult struct {
	// PublicKey is the group public key as serialized by tss-party; the E2E
	// test parses it with btcec.ParsePubKey.
	PublicKey []byte

	// EncryptedShare is the party's opaque key-share blob, later handed
	// (minus its 32-byte prefix) to the sign command.
	EncryptedShare []byte
}
|
|
|
|
// signE2EResult is the first "sign_result" message observed during an E2E
// signing run.
type signE2EResult struct {
	// Signature is the base64-decoded signature bytes from tss-party.
	Signature []byte

	// R and S are the ECDSA signature components as hexadecimal strings.
	R, S string

	// RecoveryID is the recovery identifier reported alongside R and S.
	RecoveryID int
}
|
|
|
|
func runKeygenE2E(t *testing.T, exePath, sessionID string, thresholdT, thresholdN int, participants []Participant, password string) []*keygenE2EResult {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
defer cancel()
|
|
|
|
participantsJSON, _ := json.Marshal(participants)
|
|
|
|
// Create processes for all parties
|
|
type partyProc struct {
|
|
cmd *exec.Cmd
|
|
stdin io.WriteCloser
|
|
stdout io.ReadCloser
|
|
party Participant
|
|
}
|
|
|
|
procs := make([]*partyProc, len(participants))
|
|
|
|
for i, p := range participants {
|
|
args := []string{
|
|
"keygen",
|
|
"-session-id", sessionID,
|
|
"-party-id", p.PartyID,
|
|
"-party-index", fmt.Sprintf("%d", p.PartyIndex),
|
|
"-threshold-t", fmt.Sprintf("%d", thresholdT),
|
|
"-threshold-n", fmt.Sprintf("%d", thresholdN),
|
|
"-participants", string(participantsJSON),
|
|
"-password", password,
|
|
}
|
|
|
|
cmd := exec.CommandContext(ctx, exePath, args...)
|
|
|
|
stdin, err := cmd.StdinPipe()
|
|
require.NoError(t, err)
|
|
|
|
stdout, err := cmd.StdoutPipe()
|
|
require.NoError(t, err)
|
|
|
|
cmd.Stderr = os.Stderr
|
|
|
|
err = cmd.Start()
|
|
require.NoError(t, err)
|
|
|
|
procs[i] = &partyProc{
|
|
cmd: cmd,
|
|
stdin: stdin,
|
|
stdout: stdout,
|
|
party: p,
|
|
}
|
|
}
|
|
|
|
// Message router
|
|
type routedMsg struct {
|
|
fromIndex int
|
|
isBroadcast bool
|
|
toParties []string
|
|
payload string
|
|
}
|
|
msgChan := make(chan routedMsg, 100)
|
|
|
|
// Result channel
|
|
results := make([]*keygenE2EResult, len(participants))
|
|
var resultsMu sync.Mutex
|
|
|
|
// Read output from all processes
|
|
var wg sync.WaitGroup
|
|
for i, proc := range procs {
|
|
wg.Add(1)
|
|
go func(idx int, p *partyProc) {
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(p.stdout)
|
|
buf := make([]byte, 1024*1024)
|
|
scanner.Buffer(buf, len(buf))
|
|
|
|
for scanner.Scan() {
|
|
var msg Message
|
|
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "outgoing":
|
|
msgChan <- routedMsg{
|
|
fromIndex: p.party.PartyIndex,
|
|
isBroadcast: msg.IsBroadcast,
|
|
toParties: msg.ToParties,
|
|
payload: msg.Payload,
|
|
}
|
|
case "result":
|
|
pubKey, _ := base64.StdEncoding.DecodeString(msg.PublicKey)
|
|
share, _ := base64.StdEncoding.DecodeString(msg.EncryptedShare)
|
|
resultsMu.Lock()
|
|
results[idx] = &keygenE2EResult{
|
|
PublicKey: pubKey,
|
|
EncryptedShare: share,
|
|
}
|
|
resultsMu.Unlock()
|
|
case "progress":
|
|
t.Logf("Party %d: Round %d/%d", idx, msg.Round, msg.TotalRounds)
|
|
case "error":
|
|
t.Errorf("Party %d error: %s", idx, msg.Error)
|
|
}
|
|
}
|
|
}(i, proc)
|
|
}
|
|
|
|
// Route messages between processes
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case m := <-msgChan:
|
|
incomingMsg := Message{
|
|
Type: "incoming",
|
|
IsBroadcast: m.isBroadcast,
|
|
Payload: m.payload,
|
|
FromPartyIndex: m.fromIndex,
|
|
}
|
|
data, _ := json.Marshal(incomingMsg)
|
|
dataLine := string(data) + "\n"
|
|
|
|
if m.isBroadcast {
|
|
// Send to all other parties
|
|
for idx, proc := range procs {
|
|
if idx != m.fromIndex {
|
|
proc.stdin.Write([]byte(dataLine))
|
|
}
|
|
}
|
|
} else {
|
|
// Send to specific parties
|
|
for _, toPartyID := range m.toParties {
|
|
for _, proc := range procs {
|
|
if proc.party.PartyID == toPartyID {
|
|
proc.stdin.Write([]byte(dataLine))
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Wait for all processes to complete
|
|
for _, proc := range procs {
|
|
proc.cmd.Wait()
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
func runSignE2E(t *testing.T, exePath, sessionID string, thresholdT, thresholdN int, participants []Participant, shares [][]byte, messageHash []byte, password string) *signE2EResult {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
defer cancel()
|
|
|
|
participantsJSON, _ := json.Marshal(participants)
|
|
messageHashB64 := base64.StdEncoding.EncodeToString(messageHash)
|
|
|
|
// Create processes for all signers
|
|
type partyProc struct {
|
|
cmd *exec.Cmd
|
|
stdin io.WriteCloser
|
|
stdout io.ReadCloser
|
|
party Participant
|
|
}
|
|
|
|
procs := make([]*partyProc, len(participants))
|
|
|
|
for i, p := range participants {
|
|
// Decrypt share data (remove password hash prefix)
|
|
shareData := shares[i][32:] // Skip the 32-byte password hash
|
|
shareDataB64 := base64.StdEncoding.EncodeToString(shareData)
|
|
|
|
args := []string{
|
|
"sign",
|
|
"-session-id", sessionID,
|
|
"-party-id", p.PartyID,
|
|
"-party-index", fmt.Sprintf("%d", p.PartyIndex),
|
|
"-threshold-t", fmt.Sprintf("%d", thresholdT),
|
|
"-threshold-n", fmt.Sprintf("%d", thresholdN),
|
|
"-participants", string(participantsJSON),
|
|
"-message-hash", messageHashB64,
|
|
"-share-data", shareDataB64,
|
|
}
|
|
|
|
cmd := exec.CommandContext(ctx, exePath, args...)
|
|
|
|
stdin, err := cmd.StdinPipe()
|
|
require.NoError(t, err)
|
|
|
|
stdout, err := cmd.StdoutPipe()
|
|
require.NoError(t, err)
|
|
|
|
cmd.Stderr = os.Stderr
|
|
|
|
err = cmd.Start()
|
|
require.NoError(t, err)
|
|
|
|
procs[i] = &partyProc{
|
|
cmd: cmd,
|
|
stdin: stdin,
|
|
stdout: stdout,
|
|
party: p,
|
|
}
|
|
}
|
|
|
|
// Message router
|
|
type routedMsg struct {
|
|
fromIndex int
|
|
isBroadcast bool
|
|
toParties []string
|
|
payload string
|
|
}
|
|
msgChan := make(chan routedMsg, 100)
|
|
|
|
// Result channel
|
|
var result *signE2EResult
|
|
var resultMu sync.Mutex
|
|
resultChan := make(chan struct{})
|
|
|
|
// Read output from all processes
|
|
var wg sync.WaitGroup
|
|
for i, proc := range procs {
|
|
wg.Add(1)
|
|
go func(idx int, p *partyProc) {
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(p.stdout)
|
|
buf := make([]byte, 1024*1024)
|
|
scanner.Buffer(buf, len(buf))
|
|
|
|
for scanner.Scan() {
|
|
var msg Message
|
|
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "outgoing":
|
|
msgChan <- routedMsg{
|
|
fromIndex: p.party.PartyIndex,
|
|
isBroadcast: msg.IsBroadcast,
|
|
toParties: msg.ToParties,
|
|
payload: msg.Payload,
|
|
}
|
|
case "sign_result":
|
|
sig, _ := base64.StdEncoding.DecodeString(msg.Signature)
|
|
resultMu.Lock()
|
|
if result == nil {
|
|
result = &signE2EResult{
|
|
Signature: sig,
|
|
R: msg.R,
|
|
S: msg.S,
|
|
RecoveryID: msg.RecoveryID,
|
|
}
|
|
close(resultChan)
|
|
}
|
|
resultMu.Unlock()
|
|
case "progress":
|
|
t.Logf("Party %d: Round %d/%d", idx, msg.Round, msg.TotalRounds)
|
|
case "error":
|
|
t.Errorf("Party %d error: %s", idx, msg.Error)
|
|
}
|
|
}
|
|
}(i, proc)
|
|
}
|
|
|
|
// Route messages between processes
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case m := <-msgChan:
|
|
incomingMsg := Message{
|
|
Type: "incoming",
|
|
IsBroadcast: m.isBroadcast,
|
|
Payload: m.payload,
|
|
FromPartyIndex: m.fromIndex,
|
|
}
|
|
data, _ := json.Marshal(incomingMsg)
|
|
dataLine := string(data) + "\n"
|
|
|
|
if m.isBroadcast {
|
|
// Send to all other parties
|
|
for idx, proc := range procs {
|
|
if idx != m.fromIndex {
|
|
proc.stdin.Write([]byte(dataLine))
|
|
}
|
|
}
|
|
} else {
|
|
// Send to specific parties
|
|
for _, toPartyID := range m.toParties {
|
|
for _, proc := range procs {
|
|
if proc.party.PartyID == toPartyID {
|
|
proc.stdin.Write([]byte(dataLine))
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Wait for result or timeout
|
|
select {
|
|
case <-resultChan:
|
|
case <-ctx.Done():
|
|
t.Fatal("Signing timed out")
|
|
}
|
|
|
|
// Wait for all processes to complete
|
|
for _, proc := range procs {
|
|
proc.cmd.Wait()
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
func parseCompressedPublicKey(t *testing.T, compressed []byte) *ecdsa.PublicKey {
|
|
pubKey, err := btcec.ParsePubKey(compressed)
|
|
require.NoError(t, err)
|
|
|
|
return pubKey.ToECDSA()
|
|
}
|