//go:build ignore
// +build ignore

// Test script to verify tss-party.exe keygen with 3 parties.
// Run with: go run test_keygen.go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Message is the JSON envelope used for IPC with a tss-party.exe process.
//
// This script reads one JSON object per line from a party's stdout and
// dispatches on Type ("outgoing", "progress", "result" or "error"); it writes
// "incoming" messages, one per line, to a party's stdin. Which of the other
// fields are populated depends on Type.
type Message struct {
	Type           string   `json:"type"`
	IsBroadcast    bool     `json:"isBroadcast,omitempty"`
	ToParties      []string `json:"toParties,omitempty"`
	Payload        string   `json:"payload,omitempty"`
	PublicKey      string   `json:"publicKey,omitempty"`
	EncryptedShare string   `json:"encryptedShare,omitempty"`
	PartyIndex     int      `json:"partyIndex,omitempty"`
	Round          int      `json:"round,omitempty"`
	TotalRounds    int      `json:"totalRounds,omitempty"`

	// FromPartyIndex identifies the sender of an "incoming" message. It is
	// deliberately encoded without omitempty: index 0 is a valid sender, and
	// omitempty would silently drop the field for messages from party 0.
	FromPartyIndex int `json:"fromPartyIndex"`

	Error string `json:"error,omitempty"`
}
|
|
|
|
// Participant describes one member of the keygen session. The list of
// participants is JSON-encoded and handed to every tss-party.exe process via
// its --participants flag.
type Participant struct {
	PartyID    string `json:"partyId"`
	PartyIndex int    `json:"partyIndex"`
}
|
|
|
|
type Party struct {
|
|
cmd *exec.Cmd
|
|
stdin io.WriteCloser
|
|
stdout io.ReadCloser
|
|
stderr io.ReadCloser
|
|
partyID string
|
|
partyIndex int
|
|
outMessages chan Message
|
|
result *Message
|
|
err error
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func main() {
|
|
fmt.Println("=== TSS Party Keygen Test ===")
|
|
fmt.Println("Testing 2-of-3 keygen with 3 tss-party.exe processes")
|
|
fmt.Println()
|
|
|
|
sessionID := "test-session-123"
|
|
participants := []Participant{
|
|
{PartyID: "party-0", PartyIndex: 0},
|
|
{PartyID: "party-1", PartyIndex: 1},
|
|
{PartyID: "party-2", PartyIndex: 2},
|
|
}
|
|
participantsJSON, _ := json.Marshal(participants)
|
|
|
|
// Create 3 parties
|
|
parties := make([]*Party, 3)
|
|
|
|
for i := 0; i < 3; i++ {
|
|
party := &Party{
|
|
partyID: fmt.Sprintf("party-%d", i),
|
|
partyIndex: i,
|
|
outMessages: make(chan Message, 100),
|
|
}
|
|
|
|
cmd := exec.Command("./tss-party.exe",
|
|
"keygen",
|
|
"--session-id", sessionID,
|
|
"--party-id", party.partyID,
|
|
"--party-index", fmt.Sprintf("%d", i),
|
|
"--threshold-t", "2",
|
|
"--threshold-n", "3",
|
|
"--participants", string(participantsJSON),
|
|
"--password", "test-password",
|
|
)
|
|
|
|
stdin, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
fmt.Printf("Failed to get stdin for party %d: %v\n", i, err)
|
|
return
|
|
}
|
|
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
fmt.Printf("Failed to get stdout for party %d: %v\n", i, err)
|
|
return
|
|
}
|
|
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
fmt.Printf("Failed to get stderr for party %d: %v\n", i, err)
|
|
return
|
|
}
|
|
|
|
party.cmd = cmd
|
|
party.stdin = stdin
|
|
party.stdout = stdout
|
|
party.stderr = stderr
|
|
parties[i] = party
|
|
}
|
|
|
|
// Start all processes
|
|
for i, party := range parties {
|
|
if err := party.cmd.Start(); err != nil {
|
|
fmt.Printf("Failed to start party %d: %v\n", i, err)
|
|
return
|
|
}
|
|
fmt.Printf("[Party %d] Started process (PID: %d)\n", i, party.cmd.Process.Pid)
|
|
}
|
|
|
|
// Read stderr in background
|
|
for i, party := range parties {
|
|
go func(idx int, p *Party) {
|
|
scanner := bufio.NewScanner(p.stderr)
|
|
for scanner.Scan() {
|
|
fmt.Printf("[Party %d STDERR] %s\n", idx, scanner.Text())
|
|
}
|
|
}(i, party)
|
|
}
|
|
|
|
// Read stdout and collect outgoing messages
|
|
for i, party := range parties {
|
|
go func(idx int, p *Party) {
|
|
scanner := bufio.NewScanner(p.stdout)
|
|
// Increase buffer size for large messages
|
|
buf := make([]byte, 1024*1024) // 1MB buffer
|
|
scanner.Buffer(buf, len(buf))
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
var msg Message
|
|
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
|
fmt.Printf("[Party %d] Non-JSON output: %s\n", idx, line[:min(100, len(line))])
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "outgoing":
|
|
fmt.Printf("[Party %d] Outgoing: broadcast=%v, toParties=%v, payloadLen=%d\n",
|
|
idx, msg.IsBroadcast, msg.ToParties, len(msg.Payload))
|
|
p.outMessages <- msg
|
|
case "progress":
|
|
fmt.Printf("[Party %d] Progress: round %d/%d\n", idx, msg.Round, msg.TotalRounds)
|
|
case "result":
|
|
fmt.Printf("[Party %d] Got result! PublicKey len=%d\n", idx, len(msg.PublicKey))
|
|
p.mu.Lock()
|
|
p.result = &msg
|
|
p.mu.Unlock()
|
|
case "error":
|
|
fmt.Printf("[Party %d] Error: %s\n", idx, msg.Error)
|
|
p.mu.Lock()
|
|
p.err = fmt.Errorf(msg.Error)
|
|
p.mu.Unlock()
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
fmt.Printf("[Party %d] Scanner error: %v\n", idx, err)
|
|
}
|
|
}(i, party)
|
|
}
|
|
|
|
// Message router - route messages between parties using separate goroutines per party
|
|
// Use a mutex for each party's stdin to prevent concurrent writes
|
|
stdinMutexes := make([]sync.Mutex, 3)
|
|
|
|
for senderIdx, sender := range parties {
|
|
go func(idx int, s *Party) {
|
|
for msg := range s.outMessages {
|
|
// Route message to recipients
|
|
if msg.IsBroadcast {
|
|
// Send to all except sender - use goroutines to avoid blocking
|
|
var wg sync.WaitGroup
|
|
for receiverIdx, receiver := range parties {
|
|
if receiverIdx != idx {
|
|
wg.Add(1)
|
|
go func(rIdx int, r *Party) {
|
|
defer wg.Done()
|
|
inMsg := Message{
|
|
Type: "incoming",
|
|
FromPartyIndex: idx,
|
|
IsBroadcast: true,
|
|
Payload: msg.Payload,
|
|
}
|
|
data, _ := json.Marshal(inMsg)
|
|
fmt.Printf("[Router] Broadcast %d -> %d (payload=%d)\n", idx, rIdx, len(msg.Payload))
|
|
stdinMutexes[rIdx].Lock()
|
|
_, err := r.stdin.Write(append(data, '\n'))
|
|
stdinMutexes[rIdx].Unlock()
|
|
if err != nil {
|
|
fmt.Printf("[Router] Error writing to party %d: %v\n", rIdx, err)
|
|
}
|
|
}(receiverIdx, receiver)
|
|
}
|
|
}
|
|
wg.Wait()
|
|
} else {
|
|
// Send to specific parties
|
|
for _, targetID := range msg.ToParties {
|
|
for receiverIdx, receiver := range parties {
|
|
if receiver.partyID == targetID {
|
|
inMsg := Message{
|
|
Type: "incoming",
|
|
FromPartyIndex: idx,
|
|
IsBroadcast: false,
|
|
Payload: msg.Payload,
|
|
}
|
|
data, _ := json.Marshal(inMsg)
|
|
fmt.Printf("[Router] P2P %d -> %d (%s, payload=%d)\n", idx, receiverIdx, targetID, len(msg.Payload))
|
|
stdinMutexes[receiverIdx].Lock()
|
|
_, err := receiver.stdin.Write(append(data, '\n'))
|
|
stdinMutexes[receiverIdx].Unlock()
|
|
if err != nil {
|
|
fmt.Printf("[Router] Error writing to party %d: %v\n", receiverIdx, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}(senderIdx, sender)
|
|
}
|
|
|
|
// Wait for completion with timeout
|
|
done := make(chan bool)
|
|
go func() {
|
|
for _, party := range parties {
|
|
party.cmd.Wait()
|
|
}
|
|
done <- true
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
fmt.Println("\n=== All processes completed ===")
|
|
case <-time.After(5 * time.Minute):
|
|
fmt.Println("\n=== TIMEOUT after 5 minutes ===")
|
|
for _, party := range parties {
|
|
party.cmd.Process.Kill()
|
|
}
|
|
}
|
|
|
|
// Check results
|
|
fmt.Println("\n=== Results ===")
|
|
success := true
|
|
var publicKeys []string
|
|
for i, party := range parties {
|
|
party.mu.Lock()
|
|
if party.err != nil {
|
|
fmt.Printf("[Party %d] FAILED: %v\n", i, party.err)
|
|
success = false
|
|
} else if party.result != nil {
|
|
pkLen := len(party.result.PublicKey)
|
|
if pkLen > 40 {
|
|
fmt.Printf("[Party %d] SUCCESS: PublicKey=%s...\n", i, party.result.PublicKey[:40])
|
|
} else {
|
|
fmt.Printf("[Party %d] SUCCESS: PublicKey=%s\n", i, party.result.PublicKey)
|
|
}
|
|
publicKeys = append(publicKeys, party.result.PublicKey)
|
|
} else {
|
|
fmt.Printf("[Party %d] NO RESULT\n", i)
|
|
success = false
|
|
}
|
|
party.mu.Unlock()
|
|
}
|
|
|
|
// Verify all public keys match
|
|
if len(publicKeys) == 3 {
|
|
if publicKeys[0] == publicKeys[1] && publicKeys[1] == publicKeys[2] {
|
|
fmt.Println("\nAll public keys match!")
|
|
} else {
|
|
fmt.Println("\nWARNING: Public keys don't match!")
|
|
success = false
|
|
}
|
|
}
|
|
|
|
if success {
|
|
fmt.Println("\n=== TEST PASSED ===")
|
|
} else {
|
|
fmt.Println("\n=== TEST FAILED ===")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|