675 lines
21 KiB
Go
675 lines
21 KiB
Go
// integration_test.go - Integration test for the complete co-sign flow
|
|
// Tests: session creation, joining, waiting, events, and signing
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
|
)
|
|
|
|
// Addresses of the backend services these integration tests talk to.
// The values come from docker-compose.windows.yml.
const (
	// accountServiceURL is the base URL of the account service HTTP API;
	// every REST call in this file is built by appending a path to it.
	accountServiceURL = "http://localhost:4000"

	// grpcRouterAddr is the host:port address of the gRPC router.
	grpcRouterAddr = "localhost:50051"
)
|
|
|
|
// API request/response types for co-managed keygen
|
|
// CreateCoManagedSessionRequest is the JSON body the tests POST to
// /api/v1/co-managed/sessions to open a co-managed keygen session.
type CreateCoManagedSessionRequest struct {
	// WalletName is the name given to the wallet being created.
	WalletName string `json:"wallet_name"`

	// ThresholdT is the t value; the tests use t=1 for a 2-of-3 scheme
	// (t+1 signers are needed).
	ThresholdT int `json:"threshold_t"`

	// ThresholdN is the total number of parties.
	ThresholdN int `json:"threshold_n"`

	// InitiatorPartyID is the party ID of the participant opening the session.
	InitiatorPartyID string `json:"initiator_party_id"`

	// InitiatorName is an optional display name; it is left out of the
	// JSON when empty.
	InitiatorName string `json:"initiator_name,omitempty"`

	// PersistentCount is sent as 0 by the tests, which annotate that value
	// as "no server parties".
	PersistentCount int `json:"persistent_count"`
}
|
|
|
|
// CreateCoManagedSessionResponse is the JSON document the tests decode from
// a successful POST to /api/v1/co-managed/sessions.
type CreateCoManagedSessionResponse struct {
	// SessionID identifies the new session; the tests use it to query the
	// session status and as the keygen_session_id of a sign session.
	SessionID string `json:"session_id"`

	// InviteCode is the code the tests use to look the session up again.
	InviteCode string `json:"invite_code"`

	// WalletName, ThresholdT and ThresholdN mirror the request fields of
	// the same names.
	WalletName string `json:"wallet_name"`
	ThresholdT int    `json:"threshold_t"`
	ThresholdN int    `json:"threshold_n"`

	// ExpiresAt is the expiry timestamp reported by the service
	// (unit not shown in this file).
	ExpiresAt int64 `json:"expires_at"`

	// JoinToken is the join token returned with the session.
	JoinToken string `json:"join_token"`

	// PartyID and PartyIndex are the party identity returned with the session.
	PartyID    string `json:"party_id"`
	PartyIndex int    `json:"party_index"`
}
|
|
|
|
// API request/response types for co-managed sign
|
|
type CreateSignSessionRequest struct {
|
|
KeygenSessionID string `json:"keygen_session_id"`
|
|
WalletName string `json:"wallet_name"`
|
|
MessageHash string `json:"message_hash"`
|
|
Parties []SignPartyInfo `json:"parties"`
|
|
ThresholdT int `json:"threshold_t"`
|
|
InitiatorName string `json:"initiator_name,omitempty"`
|
|
}
|
|
|
|
// CreateSignSessionResponse is the JSON document the tests decode from a
// successful POST to /api/v1/co-managed/sign.
type CreateSignSessionResponse struct {
	// SessionID identifies the new signing session.
	SessionID string `json:"session_id"`

	// InviteCode is the code the tests use to look the sign session up again.
	InviteCode string `json:"invite_code"`

	// KeygenSessionID, MessageHash and ThresholdT mirror the request fields
	// of the same names.
	KeygenSessionID string `json:"keygen_session_id"`
	MessageHash     string `json:"message_hash"`
	ThresholdT      int    `json:"threshold_t"`

	// ExpiresAt is the expiry timestamp reported by the service
	// (unit not shown in this file).
	ExpiresAt int64 `json:"expires_at"`

	// JoinToken is the join token returned with the session.
	JoinToken string `json:"join_token"`
}
|
|
|
|
type GetSignSessionResponse struct {
|
|
SessionID string `json:"session_id"`
|
|
KeygenSessionID string `json:"keygen_session_id"`
|
|
WalletName string `json:"wallet_name"`
|
|
MessageHash string `json:"message_hash"`
|
|
ThresholdT int `json:"threshold_t"`
|
|
Status string `json:"status"`
|
|
InviteCode string `json:"invite_code"`
|
|
ExpiresAt int64 `json:"expires_at"`
|
|
Parties []SignPartyInfo `json:"parties"`
|
|
JoinedCount int `json:"joined_count"`
|
|
JoinToken string `json:"join_token,omitempty"`
|
|
}
|
|
|
|
// SignPartyInfo names one party of a signing session, as used in the
// "parties" arrays of the sign-session request and response bodies.
type SignPartyInfo struct {
	// PartyID is the party's string identifier (e.g. "test-party-0" in the tests).
	PartyID string `json:"party_id"`

	// PartyIndex is the party's numeric index.
	PartyIndex int `json:"party_index"`
}
|
|
|
|
// TestAccountServiceHealth tests if account service is available
|
|
func TestAccountServiceHealth(t *testing.T) {
|
|
resp, err := http.Get(accountServiceURL + "/health")
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to account service: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("Account service unhealthy, status: %d", resp.StatusCode)
|
|
}
|
|
|
|
t.Log("Account service is healthy")
|
|
}
|
|
|
|
// TestCreateSignSession tests creating a new sign session
|
|
func TestCreateSignSession(t *testing.T) {
|
|
// This requires an existing keygen session ID
|
|
// For testing, we'll use a mock one
|
|
keygenSessionID := "test-keygen-session-" + fmt.Sprintf("%d", time.Now().UnixNano())
|
|
messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
|
|
|
reqBody := CreateSignSessionRequest{
|
|
KeygenSessionID: keygenSessionID,
|
|
MessageHash: messageHash,
|
|
InitiatorName: "test-initiator",
|
|
}
|
|
|
|
jsonBody, _ := json.Marshal(reqBody)
|
|
resp, err := http.Post(
|
|
accountServiceURL+"/api/v1/co-managed/sign",
|
|
"application/json",
|
|
bytes.NewBuffer(jsonBody),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create sign session: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body))
|
|
|
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
t.Logf("Note: This test requires an existing keygen session in the database")
|
|
t.Skipf("Sign session creation returned status %d (expected existing keygen session)", resp.StatusCode)
|
|
}
|
|
|
|
var result CreateSignSessionResponse
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
t.Fatalf("Failed to parse response: %v", err)
|
|
}
|
|
|
|
t.Logf("Session created: ID=%s, InviteCode=%s", result.SessionID, result.InviteCode)
|
|
}
|
|
|
|
// TestGetSessionByInviteCode tests retrieving session info by invite code
|
|
func TestGetSessionByInviteCode(t *testing.T) {
|
|
// This would need a valid invite code from a real session
|
|
inviteCode := "TEST123"
|
|
|
|
resp, err := http.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + inviteCode)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Logf("Get session response (status %d): %s", resp.StatusCode, string(body))
|
|
|
|
if resp.StatusCode == 404 {
|
|
t.Skip("No session found with invite code (expected for test)")
|
|
}
|
|
}
|
|
|
|
// TestCoManagedKeygenSessionFlow tests the keygen session creation and join flow
|
|
func TestCoManagedKeygenSessionFlow(t *testing.T) {
|
|
t.Log("=== Co-Managed Keygen Session Flow Test ===")
|
|
|
|
// Step 1: Create a keygen session
|
|
t.Log("Step 1: Creating keygen session...")
|
|
reqBody := CreateCoManagedSessionRequest{
|
|
WalletName: "test-wallet-" + fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
ThresholdT: 1, // 2-of-3: t=1 means t+1=2 signers needed
|
|
ThresholdN: 3,
|
|
InitiatorPartyID: "test-initiator-party",
|
|
InitiatorName: "Test User",
|
|
PersistentCount: 0, // No server parties for this test
|
|
}
|
|
|
|
jsonBody, _ := json.Marshal(reqBody)
|
|
resp, err := http.Post(
|
|
accountServiceURL+"/api/v1/co-managed/sessions",
|
|
"application/json",
|
|
bytes.NewBuffer(jsonBody),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body))
|
|
|
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode)
|
|
}
|
|
|
|
var createResp CreateCoManagedSessionResponse
|
|
if err := json.Unmarshal(body, &createResp); err != nil {
|
|
t.Fatalf("Failed to parse response: %v", err)
|
|
}
|
|
|
|
t.Logf("Session created: ID=%s, InviteCode=%s", createResp.SessionID, createResp.InviteCode)
|
|
|
|
// Step 2: Get session by invite code
|
|
t.Log("Step 2: Getting session by invite code...")
|
|
resp2, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/by-invite-code/" + createResp.InviteCode)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer resp2.Body.Close()
|
|
|
|
body2, _ := io.ReadAll(resp2.Body)
|
|
t.Logf("Get session response (status %d): %s", resp2.StatusCode, string(body2))
|
|
|
|
if resp2.StatusCode != 200 {
|
|
t.Fatalf("Failed to get session by invite code, status: %d", resp2.StatusCode)
|
|
}
|
|
|
|
// Step 3: Get session status
|
|
t.Log("Step 3: Getting session status...")
|
|
resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/" + createResp.SessionID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session status: %v", err)
|
|
}
|
|
defer resp3.Body.Close()
|
|
|
|
body3, _ := io.ReadAll(resp3.Body)
|
|
t.Logf("Session status response (status %d): %s", resp3.StatusCode, string(body3))
|
|
|
|
if resp3.StatusCode != 200 {
|
|
t.Fatalf("Failed to get session status, status: %d", resp3.StatusCode)
|
|
}
|
|
|
|
t.Log("Keygen session flow test passed!")
|
|
}
|
|
|
|
// TestCoManagedSignSessionFlow tests the full sign session flow
|
|
func TestCoManagedSignSessionFlow(t *testing.T) {
|
|
t.Log("=== Co-Managed Sign Session Flow Test ===")
|
|
|
|
// First, create a keygen session to get a valid keygen_session_id
|
|
t.Log("Step 1: Creating keygen session...")
|
|
keygenReq := CreateCoManagedSessionRequest{
|
|
WalletName: "test-wallet-for-sign-" + fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
ThresholdT: 1,
|
|
ThresholdN: 3,
|
|
InitiatorPartyID: "test-party-0",
|
|
InitiatorName: "Test User",
|
|
PersistentCount: 0,
|
|
}
|
|
|
|
jsonBody, _ := json.Marshal(keygenReq)
|
|
resp, err := http.Post(
|
|
accountServiceURL+"/api/v1/co-managed/sessions",
|
|
"application/json",
|
|
bytes.NewBuffer(jsonBody),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create keygen session: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Logf("Keygen session response (status %d): %s", resp.StatusCode, string(body))
|
|
|
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode)
|
|
}
|
|
|
|
var keygenResp CreateCoManagedSessionResponse
|
|
if err := json.Unmarshal(body, &keygenResp); err != nil {
|
|
t.Fatalf("Failed to parse keygen response: %v", err)
|
|
}
|
|
|
|
// Step 2: Create a sign session using the keygen session ID
|
|
// Note: For threshold T, we need T+1 parties to sign
|
|
// Backend validates that parties.length >= threshold_t + 1
|
|
t.Log("Step 2: Creating sign session...")
|
|
signReq := CreateSignSessionRequest{
|
|
KeygenSessionID: keygenResp.SessionID,
|
|
WalletName: keygenReq.WalletName,
|
|
MessageHash: "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2",
|
|
ThresholdT: 1, // Threshold T=1 means T+1=2 parties needed to sign
|
|
Parties: []SignPartyInfo{
|
|
{PartyID: "test-party-0", PartyIndex: 0},
|
|
{PartyID: "test-party-1", PartyIndex: 1},
|
|
},
|
|
InitiatorName: "Test User",
|
|
}
|
|
|
|
jsonBody2, _ := json.Marshal(signReq)
|
|
resp2, err := http.Post(
|
|
accountServiceURL+"/api/v1/co-managed/sign",
|
|
"application/json",
|
|
bytes.NewBuffer(jsonBody2),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create sign session: %v", err)
|
|
}
|
|
defer resp2.Body.Close()
|
|
|
|
body2, _ := io.ReadAll(resp2.Body)
|
|
t.Logf("Sign session response (status %d): %s", resp2.StatusCode, string(body2))
|
|
|
|
if resp2.StatusCode != 200 && resp2.StatusCode != 201 {
|
|
t.Fatalf("Failed to create sign session, status: %d", resp2.StatusCode)
|
|
}
|
|
|
|
var signResp CreateSignSessionResponse
|
|
if err := json.Unmarshal(body2, &signResp); err != nil {
|
|
t.Fatalf("Failed to parse sign response: %v", err)
|
|
}
|
|
|
|
t.Logf("Sign session created: ID=%s, InviteCode=%s", signResp.SessionID, signResp.InviteCode)
|
|
|
|
// Step 3: Get sign session by invite code
|
|
t.Log("Step 3: Getting sign session by invite code...")
|
|
resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + signResp.InviteCode)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get sign session: %v", err)
|
|
}
|
|
defer resp3.Body.Close()
|
|
|
|
body3, _ := io.ReadAll(resp3.Body)
|
|
t.Logf("Get sign session response (status %d): %s", resp3.StatusCode, string(body3))
|
|
|
|
if resp3.StatusCode != 200 {
|
|
t.Fatalf("Failed to get sign session by invite code, status: %d", resp3.StatusCode)
|
|
}
|
|
|
|
var getSignResp GetSignSessionResponse
|
|
if err := json.Unmarshal(body3, &getSignResp); err != nil {
|
|
t.Fatalf("Failed to parse get sign session response: %v", err)
|
|
}
|
|
|
|
t.Logf("Sign session status: %s, JoinedCount: %d, ThresholdT: %d",
|
|
getSignResp.Status, getSignResp.JoinedCount, getSignResp.ThresholdT)
|
|
|
|
t.Log("Sign session flow test passed!")
|
|
}
|
|
|
|
// TestFullSignFlow tests the complete signing flow with mock data
|
|
func TestFullSignFlow(t *testing.T) {
|
|
t.Log("=== Full Co-Sign Flow Test ===")
|
|
|
|
// Step 1: Generate mock key shares (simulating keygen result)
|
|
// NOTE: In tss-lib, threshold parameter means "t" where you need t+1 parties to sign.
|
|
// So for 2-of-3, we use threshold=1 (meaning 1+1=2 parties needed to sign)
|
|
t.Log("Step 1: Generating mock key shares for 2-of-3 scheme...")
|
|
thresholdT := 1 // t value: need t+1=2 parties to sign
|
|
totalN := 3 // total parties
|
|
shares, err := generateMockKeyShares(thresholdT, totalN)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate key shares: %v", err)
|
|
}
|
|
t.Logf("Generated %d key shares", len(shares))
|
|
|
|
// Step 2: Prepare signing parameters
|
|
messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
|
sessionID := fmt.Sprintf("test-sign-session-%d", time.Now().UnixNano())
|
|
password := "test-password"
|
|
|
|
// Step 3: Start signing with 2 parties (index 0 and 1)
|
|
// For tss-party.exe, we pass the signing threshold (t+1=2) and total (n=3)
|
|
t.Log("Step 2: Starting sign process with 2 parties...")
|
|
signingParties := []int{0, 1}
|
|
|
|
// Pass thresholdT+1=2 as the signing threshold to tss-party.exe
|
|
signature, err := runSigningProcess(shares, signingParties, messageHash, sessionID, password, thresholdT+1, totalN, t)
|
|
if err != nil {
|
|
t.Fatalf("Signing failed: %v", err)
|
|
}
|
|
|
|
t.Logf("Step 3: Signing complete!")
|
|
t.Logf("Signature (hex): %x", signature)
|
|
t.Logf("Signature (base64): %s", base64.StdEncoding.EncodeToString(signature))
|
|
}
|
|
|
|
// generateMockKeyShares generates key shares using tss-lib directly
|
|
func generateMockKeyShares(threshold, total int) ([]*keygen.LocalPartySaveData, error) {
|
|
fmt.Println("[Keygen] Starting key generation for", total, "parties with threshold", threshold)
|
|
|
|
partyIDs := make([]*tss.PartyID, total)
|
|
for i := 0; i < total; i++ {
|
|
partyIDs[i] = tss.NewPartyID(
|
|
fmt.Sprintf("party-%d", i),
|
|
fmt.Sprintf("party-%d", i),
|
|
big.NewInt(int64(i+1)),
|
|
)
|
|
}
|
|
|
|
sortedPartyIDs := tss.SortPartyIDs(partyIDs)
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
|
|
outChannels := make([]chan tss.Message, total)
|
|
endChannels := make([]chan *keygen.LocalPartySaveData, total)
|
|
parties := make([]tss.Party, total)
|
|
|
|
for i := 0; i < total; i++ {
|
|
outChannels[i] = make(chan tss.Message, total*20)
|
|
endChannels[i] = make(chan *keygen.LocalPartySaveData, 1)
|
|
params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], total, threshold)
|
|
parties[i] = keygen.NewLocalParty(params, outChannels[i], endChannels[i])
|
|
}
|
|
|
|
// Start all parties
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, total)
|
|
|
|
for i := 0; i < total; i++ {
|
|
wg.Add(1)
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
if err := parties[idx].Start(); err != nil {
|
|
errChan <- fmt.Errorf("party %d start error: %w", idx, err)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
results := make([]*keygen.LocalPartySaveData, total)
|
|
var resultsMu sync.Mutex
|
|
resultCount := 0
|
|
|
|
done := make(chan struct{})
|
|
|
|
// Message routing
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-done:
|
|
return
|
|
default:
|
|
for i := 0; i < total; i++ {
|
|
select {
|
|
case msg := <-outChannels[i]:
|
|
wire, _, _ := msg.WireBytes()
|
|
if msg.IsBroadcast() {
|
|
for j := 0; j < total; j++ {
|
|
if j != i {
|
|
go func(destIdx int) {
|
|
parsed, _ := tss.ParseWireMessage(wire, msg.GetFrom(), true)
|
|
parties[destIdx].Update(parsed)
|
|
}(j)
|
|
}
|
|
}
|
|
} else {
|
|
for _, to := range msg.GetTo() {
|
|
for j := 0; j < total; j++ {
|
|
if sortedPartyIDs[j].Id == to.Id {
|
|
go func(destIdx int) {
|
|
parsed, _ := tss.ParseWireMessage(wire, msg.GetFrom(), false)
|
|
parties[destIdx].Update(parsed)
|
|
}(j)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case result := <-endChannels[i]:
|
|
resultsMu.Lock()
|
|
results[i] = result
|
|
resultCount++
|
|
fmt.Printf("[Keygen] Party %d completed\n", i)
|
|
if resultCount == total {
|
|
close(done)
|
|
}
|
|
resultsMu.Unlock()
|
|
default:
|
|
}
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
select {
|
|
case <-done:
|
|
fmt.Println("[Keygen] All parties completed successfully")
|
|
case <-time.After(5 * time.Minute):
|
|
return nil, fmt.Errorf("keygen timeout")
|
|
}
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
return nil, err
|
|
default:
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// runSigningProcess runs the signing process using tss-party.exe
|
|
func runSigningProcess(
|
|
shares []*keygen.LocalPartySaveData,
|
|
signingPartyIndices []int,
|
|
messageHash, sessionID, password string,
|
|
thresholdT, thresholdN int,
|
|
t *testing.T,
|
|
) ([]byte, error) {
|
|
// Build participants
|
|
participants := make([]Participant, len(signingPartyIndices))
|
|
for i, idx := range signingPartyIndices {
|
|
participants[i] = Participant{
|
|
PartyID: fmt.Sprintf("party-%d", idx),
|
|
PartyIndex: idx,
|
|
}
|
|
}
|
|
participantsJSON, _ := json.Marshal(participants)
|
|
|
|
// Prepare encrypted shares
|
|
encryptedShares := make([]string, len(signingPartyIndices))
|
|
for i, idx := range signingPartyIndices {
|
|
shareBytes, _ := json.Marshal(shares[idx])
|
|
encrypted := encryptShare(shareBytes, password)
|
|
encryptedShares[i] = base64.StdEncoding.EncodeToString(encrypted)
|
|
}
|
|
|
|
// Find tss-party executable
|
|
exePath := "./tss-party.exe"
|
|
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
|
exePath = "./tss-party"
|
|
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("tss-party executable not found")
|
|
}
|
|
}
|
|
|
|
t.Logf("Using executable: %s", exePath)
|
|
|
|
// Start processes
|
|
processes := make([]*exec.Cmd, len(signingPartyIndices))
|
|
stdinPipes := make([]io.WriteCloser, len(signingPartyIndices))
|
|
stdoutPipes := make([]io.ReadCloser, len(signingPartyIndices))
|
|
|
|
for i, idx := range signingPartyIndices {
|
|
args := []string{
|
|
"sign",
|
|
"-session-id", sessionID,
|
|
"-party-id", fmt.Sprintf("party-%d", idx),
|
|
"-party-index", fmt.Sprintf("%d", idx),
|
|
"-threshold-t", fmt.Sprintf("%d", thresholdT),
|
|
"-threshold-n", fmt.Sprintf("%d", thresholdN),
|
|
"-participants", string(participantsJSON),
|
|
"-message-hash", messageHash,
|
|
"-share-data", encryptedShares[i],
|
|
"-password", password,
|
|
}
|
|
|
|
t.Logf("[Party %d] Starting with args: sign -session-id %s -party-id party-%d ...", idx, sessionID, idx)
|
|
|
|
cmd := exec.Command(exePath, args...)
|
|
stdin, _ := cmd.StdinPipe()
|
|
stdout, _ := cmd.StdoutPipe()
|
|
cmd.Stderr = os.Stderr
|
|
|
|
processes[i] = cmd
|
|
stdinPipes[i] = stdin
|
|
stdoutPipes[i] = stdout
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
return nil, fmt.Errorf("failed to start party %d: %w", idx, err)
|
|
}
|
|
}
|
|
|
|
// Message routing between parties
|
|
var wg sync.WaitGroup
|
|
results := make([][]byte, len(signingPartyIndices))
|
|
errors := make([]error, len(signingPartyIndices))
|
|
|
|
// Create a mutex for stdin writes
|
|
stdinMutexes := make([]*sync.Mutex, len(signingPartyIndices))
|
|
for i := range stdinMutexes {
|
|
stdinMutexes[i] = &sync.Mutex{}
|
|
}
|
|
|
|
for i := range signingPartyIndices {
|
|
wg.Add(1)
|
|
go func(partyIdx int) {
|
|
defer wg.Done()
|
|
scanner := bufio.NewScanner(stdoutPipes[partyIdx])
|
|
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
var msg Message
|
|
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
|
t.Logf("[Party %d] Invalid JSON: %s", signingPartyIndices[partyIdx], line)
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case "progress":
|
|
t.Logf("[Party %d] Progress: round %d/%d", signingPartyIndices[partyIdx], msg.Round, msg.TotalRounds)
|
|
|
|
case "outgoing":
|
|
// tss-party.exe outputs "outgoing" messages that need to be routed to other parties
|
|
t.Logf("[Party %d] Outgoing message (broadcast=%v, toParties=%v)", signingPartyIndices[partyIdx], msg.IsBroadcast, msg.ToParties)
|
|
// Route to other parties
|
|
for j := range signingPartyIndices {
|
|
if j != partyIdx {
|
|
// If not broadcast, check if this party is in the ToParties list
|
|
if !msg.IsBroadcast && len(msg.ToParties) > 0 {
|
|
targetPartyID := fmt.Sprintf("party-%d", signingPartyIndices[j])
|
|
found := false
|
|
for _, to := range msg.ToParties {
|
|
if to == targetPartyID {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// tss-party.exe expects "incoming" messages
|
|
msgToSend := Message{
|
|
Type: "incoming",
|
|
IsBroadcast: msg.IsBroadcast,
|
|
Payload: msg.Payload,
|
|
FromPartyIndex: signingPartyIndices[partyIdx],
|
|
}
|
|
data, _ := json.Marshal(msgToSend)
|
|
|
|
stdinMutexes[j].Lock()
|
|
stdinPipes[j].Write(append(data, '\n'))
|
|
stdinMutexes[j].Unlock()
|
|
}
|
|
}
|
|
|
|
case "result":
|
|
t.Logf("[Party %d] Got result!", signingPartyIndices[partyIdx])
|
|
signature, _ := base64.StdEncoding.DecodeString(msg.Payload)
|
|
results[partyIdx] = signature
|
|
|
|
case "error":
|
|
t.Logf("[Party %d] Error: %s", signingPartyIndices[partyIdx], msg.Error)
|
|
errors[partyIdx] = fmt.Errorf("party error: %s", msg.Error)
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
t.Logf("[Party %d] Scanner error: %v", signingPartyIndices[partyIdx], err)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Wait for processes to complete
|
|
for i, cmd := range processes {
|
|
if err := cmd.Wait(); err != nil {
|
|
t.Logf("[Party %d] Process exit error: %v", signingPartyIndices[i], err)
|
|
}
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Check for errors
|
|
for i, err := range errors {
|
|
if err != nil {
|
|
return nil, fmt.Errorf("party %d error: %w", signingPartyIndices[i], err)
|
|
}
|
|
}
|
|
|
|
// Return first result
|
|
for _, result := range results {
|
|
if result != nil {
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("no signature received")
|
|
}
|