rwadurian/backend/mpc-system/services/service-party-app/tss-party/test_keygen.go

299 lines
7.8 KiB
Go

// +build ignore
// Test script to verify tss-party.exe keygen with 3 parties
// Run with: go run test_keygen.go
package main
import (
"bufio"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"sync"
"time"
)
// IPC types exchanged with tss-party.exe, which speaks one JSON object per
// line on its stdin/stdout, plus the harness's per-process bookkeeping.
type (
	// Message is a single IPC frame. The harness reads frames of type
	// "outgoing", "progress", "result" and "error" from a party and writes
	// frames of type "incoming" to a party.
	Message struct {
		Type           string   `json:"type"`
		IsBroadcast    bool     `json:"isBroadcast,omitempty"`
		ToParties      []string `json:"toParties,omitempty"`
		Payload        string   `json:"payload,omitempty"`
		PublicKey      string   `json:"publicKey,omitempty"`
		EncryptedShare string   `json:"encryptedShare,omitempty"`
		PartyIndex     int      `json:"partyIndex,omitempty"`
		Round          int      `json:"round,omitempty"`
		TotalRounds    int      `json:"totalRounds,omitempty"`
		FromPartyIndex int      `json:"fromPartyIndex,omitempty"`
		Error          string   `json:"error,omitempty"`
	}

	// Participant identifies one keygen party; the full list is passed to
	// each process JSON-encoded via --participants.
	Participant struct {
		PartyID    string `json:"partyId"`
		PartyIndex int    `json:"partyIndex"`
	}

	// Party is the harness's handle on one running tss-party.exe process.
	Party struct {
		// Process handle and its three standard-stream pipes.
		cmd    *exec.Cmd
		stdin  io.WriteCloser
		stdout io.ReadCloser
		stderr io.ReadCloser

		// Identity of this party within the session.
		partyID    string
		partyIndex int

		// outMessages queues "outgoing" frames parsed from stdout until the
		// router forwards them to their recipients.
		outMessages chan Message

		// mu guards result and err, which are set from the stdout reader
		// goroutine and read when reporting results.
		mu     sync.Mutex
		result *Message
		err    error
	}
)
func main() {
fmt.Println("=== TSS Party Keygen Test ===")
fmt.Println("Testing 2-of-3 keygen with 3 tss-party.exe processes")
fmt.Println()
sessionID := "test-session-123"
participants := []Participant{
{PartyID: "party-0", PartyIndex: 0},
{PartyID: "party-1", PartyIndex: 1},
{PartyID: "party-2", PartyIndex: 2},
}
participantsJSON, _ := json.Marshal(participants)
// Create 3 parties
parties := make([]*Party, 3)
for i := 0; i < 3; i++ {
party := &Party{
partyID: fmt.Sprintf("party-%d", i),
partyIndex: i,
outMessages: make(chan Message, 100),
}
cmd := exec.Command("./tss-party.exe",
"keygen",
"--session-id", sessionID,
"--party-id", party.partyID,
"--party-index", fmt.Sprintf("%d", i),
"--threshold-t", "2",
"--threshold-n", "3",
"--participants", string(participantsJSON),
"--password", "test-password",
)
stdin, err := cmd.StdinPipe()
if err != nil {
fmt.Printf("Failed to get stdin for party %d: %v\n", i, err)
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
fmt.Printf("Failed to get stdout for party %d: %v\n", i, err)
return
}
stderr, err := cmd.StderrPipe()
if err != nil {
fmt.Printf("Failed to get stderr for party %d: %v\n", i, err)
return
}
party.cmd = cmd
party.stdin = stdin
party.stdout = stdout
party.stderr = stderr
parties[i] = party
}
// Start all processes
for i, party := range parties {
if err := party.cmd.Start(); err != nil {
fmt.Printf("Failed to start party %d: %v\n", i, err)
return
}
fmt.Printf("[Party %d] Started process (PID: %d)\n", i, party.cmd.Process.Pid)
}
// Read stderr in background
for i, party := range parties {
go func(idx int, p *Party) {
scanner := bufio.NewScanner(p.stderr)
for scanner.Scan() {
fmt.Printf("[Party %d STDERR] %s\n", idx, scanner.Text())
}
}(i, party)
}
// Read stdout and collect outgoing messages
for i, party := range parties {
go func(idx int, p *Party) {
scanner := bufio.NewScanner(p.stdout)
// Increase buffer size for large messages
buf := make([]byte, 1024*1024) // 1MB buffer
scanner.Buffer(buf, len(buf))
for scanner.Scan() {
line := scanner.Text()
var msg Message
if err := json.Unmarshal([]byte(line), &msg); err != nil {
fmt.Printf("[Party %d] Non-JSON output: %s\n", idx, line[:min(100, len(line))])
continue
}
switch msg.Type {
case "outgoing":
fmt.Printf("[Party %d] Outgoing: broadcast=%v, toParties=%v, payloadLen=%d\n",
idx, msg.IsBroadcast, msg.ToParties, len(msg.Payload))
p.outMessages <- msg
case "progress":
fmt.Printf("[Party %d] Progress: round %d/%d\n", idx, msg.Round, msg.TotalRounds)
case "result":
fmt.Printf("[Party %d] Got result! PublicKey len=%d\n", idx, len(msg.PublicKey))
p.mu.Lock()
p.result = &msg
p.mu.Unlock()
case "error":
fmt.Printf("[Party %d] Error: %s\n", idx, msg.Error)
p.mu.Lock()
p.err = fmt.Errorf(msg.Error)
p.mu.Unlock()
}
}
if err := scanner.Err(); err != nil {
fmt.Printf("[Party %d] Scanner error: %v\n", idx, err)
}
}(i, party)
}
// Message router - route messages between parties using separate goroutines per party
// Use a mutex for each party's stdin to prevent concurrent writes
stdinMutexes := make([]sync.Mutex, 3)
for senderIdx, sender := range parties {
go func(idx int, s *Party) {
for msg := range s.outMessages {
// Route message to recipients
if msg.IsBroadcast {
// Send to all except sender - use goroutines to avoid blocking
var wg sync.WaitGroup
for receiverIdx, receiver := range parties {
if receiverIdx != idx {
wg.Add(1)
go func(rIdx int, r *Party) {
defer wg.Done()
inMsg := Message{
Type: "incoming",
FromPartyIndex: idx,
IsBroadcast: true,
Payload: msg.Payload,
}
data, _ := json.Marshal(inMsg)
fmt.Printf("[Router] Broadcast %d -> %d (payload=%d)\n", idx, rIdx, len(msg.Payload))
stdinMutexes[rIdx].Lock()
_, err := r.stdin.Write(append(data, '\n'))
stdinMutexes[rIdx].Unlock()
if err != nil {
fmt.Printf("[Router] Error writing to party %d: %v\n", rIdx, err)
}
}(receiverIdx, receiver)
}
}
wg.Wait()
} else {
// Send to specific parties
for _, targetID := range msg.ToParties {
for receiverIdx, receiver := range parties {
if receiver.partyID == targetID {
inMsg := Message{
Type: "incoming",
FromPartyIndex: idx,
IsBroadcast: false,
Payload: msg.Payload,
}
data, _ := json.Marshal(inMsg)
fmt.Printf("[Router] P2P %d -> %d (%s, payload=%d)\n", idx, receiverIdx, targetID, len(msg.Payload))
stdinMutexes[receiverIdx].Lock()
_, err := receiver.stdin.Write(append(data, '\n'))
stdinMutexes[receiverIdx].Unlock()
if err != nil {
fmt.Printf("[Router] Error writing to party %d: %v\n", receiverIdx, err)
}
}
}
}
}
}
}(senderIdx, sender)
}
// Wait for completion with timeout
done := make(chan bool)
go func() {
for _, party := range parties {
party.cmd.Wait()
}
done <- true
}()
select {
case <-done:
fmt.Println("\n=== All processes completed ===")
case <-time.After(5 * time.Minute):
fmt.Println("\n=== TIMEOUT after 5 minutes ===")
for _, party := range parties {
party.cmd.Process.Kill()
}
}
// Check results
fmt.Println("\n=== Results ===")
success := true
var publicKeys []string
for i, party := range parties {
party.mu.Lock()
if party.err != nil {
fmt.Printf("[Party %d] FAILED: %v\n", i, party.err)
success = false
} else if party.result != nil {
pkLen := len(party.result.PublicKey)
if pkLen > 40 {
fmt.Printf("[Party %d] SUCCESS: PublicKey=%s...\n", i, party.result.PublicKey[:40])
} else {
fmt.Printf("[Party %d] SUCCESS: PublicKey=%s\n", i, party.result.PublicKey)
}
publicKeys = append(publicKeys, party.result.PublicKey)
} else {
fmt.Printf("[Party %d] NO RESULT\n", i)
success = false
}
party.mu.Unlock()
}
// Verify all public keys match
if len(publicKeys) == 3 {
if publicKeys[0] == publicKeys[1] && publicKeys[1] == publicKeys[2] {
fmt.Println("\nAll public keys match!")
} else {
fmt.Println("\nWARNING: Public keys don't match!")
success = false
}
}
if success {
fmt.Println("\n=== TEST PASSED ===")
} else {
fmt.Println("\n=== TEST FAILED ===")
os.Exit(1)
}
}
// min returns whichever of a and b is smaller (either one when they are equal).
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}