// rwadurian/backend/mpc-system/services/service-party-app/tss-party/integration_test.go
//
// 675 lines
// 21 KiB
// Go

// integration_test.go - Integration tests for the complete co-sign flow.
//
// The HTTP tests create and look up co-managed keygen and sign sessions via
// the account service at accountServiceURL, so the backend services (see
// docker-compose.windows.yml) must be running. TestFullSignFlow runs an
// in-process tss-lib keygen and then drives signing through a tss-party
// (or tss-party.exe) executable expected in the working directory.
package main
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"os"
"os/exec"
"sync"
"testing"
"time"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/tss"
)
const (
	// accountServiceURL is the base URL of the account service HTTP API as
	// exposed by docker-compose.windows.yml.
	accountServiceURL = "http://localhost:4000"

	// grpcRouterAddr is the host:port of the gRPC router from the same
	// compose file.
	grpcRouterAddr = "localhost:50051"
)
// CreateCoManagedSessionRequest is the JSON body POSTed to
// /api/v1/co-managed/sessions to create a co-managed keygen session.
type CreateCoManagedSessionRequest struct {
// WalletName labels the wallet; the tests make it unique with a timestamp suffix.
WalletName string `json:"wallet_name"`
// ThresholdT is the tss-lib threshold t: t+1 parties are needed to sign.
ThresholdT int `json:"threshold_t"`
// ThresholdN is the total number of parties.
ThresholdN int `json:"threshold_n"`
// InitiatorPartyID is the party ID of the session creator.
InitiatorPartyID string `json:"initiator_party_id"`
// InitiatorName is an optional display name; omitted from the JSON when empty.
InitiatorName string `json:"initiator_name,omitempty"`
// PersistentCount is the number of server-side parties; the tests send 0
// for "no server parties" (per their comments - confirm against the API).
PersistentCount int `json:"persistent_count"`
}
// CreateCoManagedSessionResponse is the JSON body returned by a successful
// POST to /api/v1/co-managed/sessions. The tests use SessionID (for
// GET /api/v1/co-managed/sessions/{id}) and InviteCode (for
// GET /api/v1/co-managed/sessions/by-invite-code/{code}).
type CreateCoManagedSessionResponse struct {
SessionID string `json:"session_id"`
InviteCode string `json:"invite_code"`
WalletName string `json:"wallet_name"`
ThresholdT int `json:"threshold_t"`
ThresholdN int `json:"threshold_n"`
// ExpiresAt is an integer timestamp. NOTE(review): the unit (seconds vs.
// milliseconds since epoch) is not visible in this file - confirm.
ExpiresAt int64 `json:"expires_at"`
// JoinToken, PartyID and PartyIndex are decoded but not read in this file;
// presumably they describe the initiator's own slot - confirm with the API.
JoinToken string `json:"join_token"`
PartyID string `json:"party_id"`
PartyIndex int `json:"party_index"`
}
// CreateSignSessionRequest is the JSON body POSTed to /api/v1/co-managed/sign
// to open a co-managed signing session for an existing keygen session.
type CreateSignSessionRequest struct {
KeygenSessionID string `json:"keygen_session_id"`
WalletName string `json:"wallet_name"`
// MessageHash is the digest to sign; the tests send a 64-character hex string.
MessageHash string `json:"message_hash"`
// Parties lists the signers. The backend reportedly requires
// len(Parties) >= ThresholdT+1.
Parties []SignPartyInfo `json:"parties"`
// ThresholdT is the tss-lib threshold t (t+1 signers are needed).
ThresholdT int `json:"threshold_t"`
// InitiatorName is an optional display name; omitted from the JSON when empty.
InitiatorName string `json:"initiator_name,omitempty"`
}
// CreateSignSessionResponse is the JSON body returned by a successful POST to
// /api/v1/co-managed/sign. The tests use SessionID and InviteCode; the
// invite code is later resolved via
// GET /api/v1/co-managed/sign/by-invite-code/{code}.
type CreateSignSessionResponse struct {
SessionID string `json:"session_id"`
InviteCode string `json:"invite_code"`
KeygenSessionID string `json:"keygen_session_id"`
MessageHash string `json:"message_hash"`
ThresholdT int `json:"threshold_t"`
// ExpiresAt is an integer timestamp. NOTE(review): unit not visible here - confirm.
ExpiresAt int64 `json:"expires_at"`
JoinToken string `json:"join_token"`
}
// GetSignSessionResponse is the JSON body returned by
// GET /api/v1/co-managed/sign/by-invite-code/{code}.
type GetSignSessionResponse struct {
SessionID string `json:"session_id"`
KeygenSessionID string `json:"keygen_session_id"`
WalletName string `json:"wallet_name"`
MessageHash string `json:"message_hash"`
ThresholdT int `json:"threshold_t"`
// Status is the session state as a string; its possible values are not
// visible in this file.
Status string `json:"status"`
InviteCode string `json:"invite_code"`
// ExpiresAt is an integer timestamp. NOTE(review): unit not visible here - confirm.
ExpiresAt int64 `json:"expires_at"`
Parties []SignPartyInfo `json:"parties"`
// JoinedCount is presumably the number of parties that have joined so far
// (inferred from the name - confirm with the API).
JoinedCount int `json:"joined_count"`
// JoinToken is omitted from the JSON when empty.
JoinToken string `json:"join_token,omitempty"`
}
// SignPartyInfo identifies one signer in sign-session requests and responses.
// The tests use IDs of the form "test-party-<n>" with 0-based PartyIndex values.
type SignPartyInfo struct {
PartyID string `json:"party_id"`
PartyIndex int `json:"party_index"`
}
// TestAccountServiceHealth checks that the account service answers
// GET /health with 200 OK.
func TestAccountServiceHealth(t *testing.T) {
	// http.Get uses a client without a timeout; bound the request so a hung
	// backend fails this test instead of stalling the whole run.
	client := &http.Client{Timeout: 30 * time.Second}

	resp, err := client.Get(accountServiceURL + "/health")
	if err != nil {
		t.Fatalf("Failed to connect to account service: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Account service unhealthy, status: %d", resp.StatusCode)
	}
	t.Log("Account service is healthy")
}
// TestCreateSignSession posts a sign-session request that references a
// freshly invented (non-existent) keygen session ID. It skips when the
// backend rejects the request, which is the expected outcome without a real
// keygen session. On success it checks that the response decodes and carries
// a session ID.
func TestCreateSignSession(t *testing.T) {
	// No real keygen session exists for this ID; see the skip below.
	keygenSessionID := fmt.Sprintf("test-keygen-session-%d", time.Now().UnixNano())
	messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"

	reqBody := CreateSignSessionRequest{
		KeygenSessionID: keygenSessionID,
		MessageHash:     messageHash,
		// The backend rejects requests with fewer than ThresholdT+1 parties.
		// Without these fields the request fails that validation and never
		// exercises the keygen-session lookup the skip message refers to.
		ThresholdT: 1,
		Parties: []SignPartyInfo{
			{PartyID: "test-party-0", PartyIndex: 0},
			{PartyID: "test-party-1", PartyIndex: 1},
		},
		InitiatorName: "test-initiator",
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		t.Fatalf("Failed to encode request: %v", err)
	}

	// Bounded client: a hung backend must not stall the whole test run.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Post(
		accountServiceURL+"/api/v1/co-managed/sign",
		"application/json",
		bytes.NewReader(jsonBody),
	)
	if err != nil {
		t.Fatalf("Failed to create sign session: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body))

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		t.Logf("Note: This test requires an existing keygen session in the database")
		t.Skipf("Sign session creation returned status %d (expected existing keygen session)", resp.StatusCode)
	}

	var result CreateSignSessionResponse
	if err := json.Unmarshal(body, &result); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if result.SessionID == "" {
		t.Fatalf("Create session response is missing session_id: %s", string(body))
	}
	t.Logf("Session created: ID=%s, InviteCode=%s", result.SessionID, result.InviteCode)
}
// TestGetSessionByInviteCode looks up a sign session by a placeholder invite
// code. A 404 (no such session) is the expected outcome and skips the test.
// A 200 must decode as a GetSignSessionResponse. Any other status is a
// failure, instead of passing silently as before.
func TestGetSessionByInviteCode(t *testing.T) {
	// Placeholder: a meaningful run needs an invite code from a live session.
	inviteCode := "TEST123"

	// Bounded client: a hung backend must not stall the whole test run.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + inviteCode)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	t.Logf("Get session response (status %d): %s", resp.StatusCode, string(body))

	switch resp.StatusCode {
	case http.StatusNotFound:
		t.Skip("No session found with invite code (expected for test)")
	case http.StatusOK:
		var session GetSignSessionResponse
		if err := json.Unmarshal(body, &session); err != nil {
			t.Fatalf("Failed to parse response: %v", err)
		}
		t.Logf("Found session: ID=%s, Status=%s", session.SessionID, session.Status)
	default:
		t.Fatalf("Unexpected status %d looking up invite code %q", resp.StatusCode, inviteCode)
	}
}
// TestCoManagedKeygenSessionFlow tests the keygen session creation and join flow
func TestCoManagedKeygenSessionFlow(t *testing.T) {
t.Log("=== Co-Managed Keygen Session Flow Test ===")
// Step 1: Create a keygen session
t.Log("Step 1: Creating keygen session...")
reqBody := CreateCoManagedSessionRequest{
WalletName: "test-wallet-" + fmt.Sprintf("%d", time.Now().UnixNano()),
ThresholdT: 1, // 2-of-3: t=1 means t+1=2 signers needed
ThresholdN: 3,
InitiatorPartyID: "test-initiator-party",
InitiatorName: "Test User",
PersistentCount: 0, // No server parties for this test
}
jsonBody, _ := json.Marshal(reqBody)
resp, err := http.Post(
accountServiceURL+"/api/v1/co-managed/sessions",
"application/json",
bytes.NewBuffer(jsonBody),
)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode != 200 && resp.StatusCode != 201 {
t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode)
}
var createResp CreateCoManagedSessionResponse
if err := json.Unmarshal(body, &createResp); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
t.Logf("Session created: ID=%s, InviteCode=%s", createResp.SessionID, createResp.InviteCode)
// Step 2: Get session by invite code
t.Log("Step 2: Getting session by invite code...")
resp2, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/by-invite-code/" + createResp.InviteCode)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer resp2.Body.Close()
body2, _ := io.ReadAll(resp2.Body)
t.Logf("Get session response (status %d): %s", resp2.StatusCode, string(body2))
if resp2.StatusCode != 200 {
t.Fatalf("Failed to get session by invite code, status: %d", resp2.StatusCode)
}
// Step 3: Get session status
t.Log("Step 3: Getting session status...")
resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/" + createResp.SessionID)
if err != nil {
t.Fatalf("Failed to get session status: %v", err)
}
defer resp3.Body.Close()
body3, _ := io.ReadAll(resp3.Body)
t.Logf("Session status response (status %d): %s", resp3.StatusCode, string(body3))
if resp3.StatusCode != 200 {
t.Fatalf("Failed to get session status, status: %d", resp3.StatusCode)
}
t.Log("Keygen session flow test passed!")
}
// TestCoManagedSignSessionFlow tests the full sign session flow
func TestCoManagedSignSessionFlow(t *testing.T) {
t.Log("=== Co-Managed Sign Session Flow Test ===")
// First, create a keygen session to get a valid keygen_session_id
t.Log("Step 1: Creating keygen session...")
keygenReq := CreateCoManagedSessionRequest{
WalletName: "test-wallet-for-sign-" + fmt.Sprintf("%d", time.Now().UnixNano()),
ThresholdT: 1,
ThresholdN: 3,
InitiatorPartyID: "test-party-0",
InitiatorName: "Test User",
PersistentCount: 0,
}
jsonBody, _ := json.Marshal(keygenReq)
resp, err := http.Post(
accountServiceURL+"/api/v1/co-managed/sessions",
"application/json",
bytes.NewBuffer(jsonBody),
)
if err != nil {
t.Fatalf("Failed to create keygen session: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
t.Logf("Keygen session response (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode != 200 && resp.StatusCode != 201 {
t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode)
}
var keygenResp CreateCoManagedSessionResponse
if err := json.Unmarshal(body, &keygenResp); err != nil {
t.Fatalf("Failed to parse keygen response: %v", err)
}
// Step 2: Create a sign session using the keygen session ID
// Note: For threshold T, we need T+1 parties to sign
// Backend validates that parties.length >= threshold_t + 1
t.Log("Step 2: Creating sign session...")
signReq := CreateSignSessionRequest{
KeygenSessionID: keygenResp.SessionID,
WalletName: keygenReq.WalletName,
MessageHash: "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2",
ThresholdT: 1, // Threshold T=1 means T+1=2 parties needed to sign
Parties: []SignPartyInfo{
{PartyID: "test-party-0", PartyIndex: 0},
{PartyID: "test-party-1", PartyIndex: 1},
},
InitiatorName: "Test User",
}
jsonBody2, _ := json.Marshal(signReq)
resp2, err := http.Post(
accountServiceURL+"/api/v1/co-managed/sign",
"application/json",
bytes.NewBuffer(jsonBody2),
)
if err != nil {
t.Fatalf("Failed to create sign session: %v", err)
}
defer resp2.Body.Close()
body2, _ := io.ReadAll(resp2.Body)
t.Logf("Sign session response (status %d): %s", resp2.StatusCode, string(body2))
if resp2.StatusCode != 200 && resp2.StatusCode != 201 {
t.Fatalf("Failed to create sign session, status: %d", resp2.StatusCode)
}
var signResp CreateSignSessionResponse
if err := json.Unmarshal(body2, &signResp); err != nil {
t.Fatalf("Failed to parse sign response: %v", err)
}
t.Logf("Sign session created: ID=%s, InviteCode=%s", signResp.SessionID, signResp.InviteCode)
// Step 3: Get sign session by invite code
t.Log("Step 3: Getting sign session by invite code...")
resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + signResp.InviteCode)
if err != nil {
t.Fatalf("Failed to get sign session: %v", err)
}
defer resp3.Body.Close()
body3, _ := io.ReadAll(resp3.Body)
t.Logf("Get sign session response (status %d): %s", resp3.StatusCode, string(body3))
if resp3.StatusCode != 200 {
t.Fatalf("Failed to get sign session by invite code, status: %d", resp3.StatusCode)
}
var getSignResp GetSignSessionResponse
if err := json.Unmarshal(body3, &getSignResp); err != nil {
t.Fatalf("Failed to parse get sign session response: %v", err)
}
t.Logf("Sign session status: %s, JoinedCount: %d, ThresholdT: %d",
getSignResp.Status, getSignResp.JoinedCount, getSignResp.ThresholdT)
t.Log("Sign session flow test passed!")
}
// TestFullSignFlow runs an in-process 2-of-3 tss-lib keygen and then signs a
// fixed message hash with parties 0 and 1 via the tss-party executable.
// It is skipped when neither ./tss-party.exe nor ./tss-party exists, and it
// fails if signing yields an empty signature.
func TestFullSignFlow(t *testing.T) {
	t.Log("=== Full Co-Sign Flow Test ===")

	// Check for the signing binary before keygen. generateMockKeyShares may
	// run for up to five minutes, and a missing build is a missing
	// prerequisite (skip, like the other tests) rather than a late failure.
	// The lookup order matches runSigningProcess.
	haveExe := false
	for _, candidate := range []string{"./tss-party.exe", "./tss-party"} {
		if _, err := os.Stat(candidate); !os.IsNotExist(err) {
			haveExe = true
			break
		}
	}
	if !haveExe {
		t.Skip("tss-party executable not found in the working directory; build it to run the full sign flow")
	}

	// Step 1: generate key shares (simulating a keygen result).
	// In tss-lib the threshold parameter is t, and t+1 parties are needed to
	// sign, so 2-of-3 uses t=1.
	t.Log("Step 1: Generating mock key shares for 2-of-3 scheme...")
	thresholdT := 1 // t value: need t+1=2 parties to sign
	totalN := 3     // total parties
	shares, err := generateMockKeyShares(thresholdT, totalN)
	if err != nil {
		t.Fatalf("Failed to generate key shares: %v", err)
	}
	t.Logf("Generated %d key shares", len(shares))

	// Signing parameters shared by every signer.
	messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
	sessionID := fmt.Sprintf("test-sign-session-%d", time.Now().UnixNano())
	password := "test-password"

	// Step 2: sign with parties 0 and 1. tss-party takes the signing
	// threshold t+1=2 and the total n=3.
	t.Log("Step 2: Starting sign process with 2 parties...")
	signingParties := []int{0, 1}
	signature, err := runSigningProcess(shares, signingParties, messageHash, sessionID, password, thresholdT+1, totalN, t)
	if err != nil {
		t.Fatalf("Signing failed: %v", err)
	}
	if len(signature) == 0 {
		t.Fatal("Signing returned an empty signature")
	}

	// Step 3: report the signature.
	t.Logf("Step 3: Signing complete!")
	t.Logf("Signature (hex): %x", signature)
	t.Logf("Signature (base64): %s", base64.StdEncoding.EncodeToString(signature))
}
// generateMockKeyShares runs a complete in-process tss-lib ECDSA (secp256k1)
// keygen for total parties with tss-lib threshold t = threshold (t+1 parties
// are needed to sign). It returns each party's save data, indexed in the
// order of the sorted party IDs "party-0" ... "party-<total-1>" (keys 1..total).
//
// Messages are routed between parties in memory. It returns an error when
// the parameters are invalid, when a party fails to start, when a message
// cannot be encoded, parsed or applied, or when keygen has not finished five
// minutes after every party's Start returned. The routing goroutines it
// starts are stopped before it returns.
func generateMockKeyShares(threshold, total int) ([]*keygen.LocalPartySaveData, error) {
	if total < 1 || threshold < 0 || threshold >= total {
		return nil, fmt.Errorf("invalid keygen parameters: threshold=%d total=%d (need 0 <= threshold < total)", threshold, total)
	}
	fmt.Println("[Keygen] Starting key generation for", total, "parties with threshold", threshold)

	partyIDs := make([]*tss.PartyID, total)
	for i := 0; i < total; i++ {
		partyIDs[i] = tss.NewPartyID(
			fmt.Sprintf("party-%d", i),
			fmt.Sprintf("party-%d", i),
			big.NewInt(int64(i+1)),
		)
	}
	sortedPartyIDs := tss.SortPartyIDs(partyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)

	outChannels := make([]chan tss.Message, total)
	endChannels := make([]chan *keygen.LocalPartySaveData, total)
	parties := make([]tss.Party, total)
	for i := 0; i < total; i++ {
		outChannels[i] = make(chan tss.Message, total*20)
		endChannels[i] = make(chan *keygen.LocalPartySaveData, 1)
		params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], total, threshold)
		parties[i] = keygen.NewLocalParty(params, outChannels[i], endChannels[i])
	}

	// stop is closed on every return path so all goroutines below exit. The
	// previous polling loop spun forever after a timeout.
	stop := make(chan struct{})
	defer close(stop)

	// errCh carries the first failure to the wait loop. reportErr gives up
	// once stop is closed, so late reporters never block forever.
	errCh := make(chan error, 1)
	reportErr := func(err error) {
		select {
		case errCh <- err:
		case <-stop:
		}
	}

	type partyResult struct {
		index int
		save  *keygen.LocalPartySaveData
	}
	resultCh := make(chan partyResult)

	// deliver parses a wire message and applies it to party dest. Callers run
	// it in its own goroutine: Update may make dest emit further messages, and
	// routing must not block on that.
	deliver := func(dest int, wire []byte, from *tss.PartyID, isBroadcast bool) {
		parsed, err := tss.ParseWireMessage(wire, from, isBroadcast)
		if err != nil {
			reportErr(fmt.Errorf("party %d: parse message from %v: %w", dest, from, err))
			return
		}
		if _, tssErr := parties[dest].Update(parsed); tssErr != nil {
			reportErr(fmt.Errorf("party %d: apply message from %v: %w", dest, from, tssErr))
		}
	}

	// One router per party forwards its outgoing messages and hands its final
	// save data to the wait loop.
	for i := 0; i < total; i++ {
		go func(src int) {
			for {
				select {
				case <-stop:
					return
				case msg := <-outChannels[src]:
					wire, _, err := msg.WireBytes()
					if err != nil {
						reportErr(fmt.Errorf("party %d: encode outgoing message: %w", src, err))
						continue
					}
					from := msg.GetFrom()
					if msg.IsBroadcast() {
						for dest := 0; dest < total; dest++ {
							if dest != src {
								go deliver(dest, wire, from, true)
							}
						}
						continue
					}
					for _, to := range msg.GetTo() {
						for dest := 0; dest < total; dest++ {
							if sortedPartyIDs[dest].Id == to.Id {
								go deliver(dest, wire, from, false)
								break
							}
						}
					}
				case save := <-endChannels[src]:
					select {
					case resultCh <- partyResult{index: src, save: save}:
					case <-stop:
						return
					}
				}
			}
		}(i)
	}

	// Start every party. Start errors now end the wait immediately instead of
	// surfacing only after the timeout.
	var startWG sync.WaitGroup
	for i := 0; i < total; i++ {
		startWG.Add(1)
		go func(idx int) {
			defer startWG.Done()
			if err := parties[idx].Start(); err != nil {
				reportErr(fmt.Errorf("party %d start error: %w", idx, err))
			}
		}(i)
	}
	allStarted := make(chan struct{})
	go func() {
		startWG.Wait()
		close(allStarted)
	}()

	results := make([]*keygen.LocalPartySaveData, total)
	started := allStarted
	// As before, the five-minute budget starts once every Start has returned;
	// deadline stays nil (never ready) until then.
	var deadline <-chan time.Time
	for completed := 0; completed < total; {
		select {
		case <-started:
			started = nil
			deadline = time.After(5 * time.Minute)
		case r := <-resultCh:
			if results[r.index] == nil {
				completed++
			}
			results[r.index] = r.save
			fmt.Printf("[Keygen] Party %d completed\n", r.index)
		case err := <-errCh:
			return nil, err
		case <-deadline:
			return nil, fmt.Errorf("keygen timeout")
		}
	}

	fmt.Println("[Keygen] All parties completed successfully")
	return results, nil
}
// runSigningProcess runs one tss-party "sign" subprocess per entry in
// signingPartyIndices and relays protocol messages between them.
//
// Each subprocess receives the shared session parameters (thresholdT and
// thresholdN are passed through unchanged as -threshold-t / -threshold-n)
// plus its own key share, JSON-encoded, passed through encryptShare with
// password, and base64-encoded. Subprocesses write newline-delimited JSON
// Message values on stdout. "outgoing" messages are forwarded to the other
// processes' stdin as "incoming" messages, restricted to ToParties for
// point-to-point messages.
//
// It returns the first non-empty decoded "result" payload. It returns an
// error if an index has no share, a process cannot be started, a party
// reports an "error" message or an undecodable result, no signature arrives,
// or the run exceeds five minutes (all processes are then killed).
func runSigningProcess(
	shares []*keygen.LocalPartySaveData,
	signingPartyIndices []int,
	messageHash, sessionID, password string,
	thresholdT, thresholdN int,
	t *testing.T,
) ([]byte, error) {
	// Upper bound for the whole run. If one party dies, the others can block
	// forever waiting for its messages; the watchdog below kills them.
	const signTimeout = 5 * time.Minute

	n := len(signingPartyIndices)
	if n == 0 {
		return nil, fmt.Errorf("no signing parties given")
	}
	for _, idx := range signingPartyIndices {
		if idx < 0 || idx >= len(shares) || shares[idx] == nil {
			return nil, fmt.Errorf("no key share for party index %d", idx)
		}
	}

	// The participant list is identical for every process.
	participants := make([]Participant, n)
	for i, idx := range signingPartyIndices {
		participants[i] = Participant{
			PartyID:    fmt.Sprintf("party-%d", idx),
			PartyIndex: idx,
		}
	}
	participantsJSON, err := json.Marshal(participants)
	if err != nil {
		return nil, fmt.Errorf("encode participants: %w", err)
	}

	// Per-process share: JSON -> encryptShare -> base64.
	encryptedShares := make([]string, n)
	for i, idx := range signingPartyIndices {
		shareBytes, err := json.Marshal(shares[idx])
		if err != nil {
			return nil, fmt.Errorf("encode key share for party %d: %w", idx, err)
		}
		encryptedShares[i] = base64.StdEncoding.EncodeToString(encryptShare(shareBytes, password))
	}

	// Find tss-party executable
	exePath := "./tss-party.exe"
	if _, err := os.Stat(exePath); os.IsNotExist(err) {
		exePath = "./tss-party"
		if _, err := os.Stat(exePath); os.IsNotExist(err) {
			return nil, fmt.Errorf("tss-party executable not found")
		}
	}
	t.Logf("Using executable: %s", exePath)

	processes := make([]*exec.Cmd, n) // set only after Start succeeds
	stdinPipes := make([]io.WriteCloser, n)
	stdoutPipes := make([]io.ReadCloser, n)

	// killAll terminates every started process. Best effort: a process may
	// already have exited, so Kill errors are ignored.
	killAll := func() {
		for _, cmd := range processes {
			if cmd != nil {
				_ = cmd.Process.Kill()
			}
		}
	}
	// abort kills and reaps the processes started so far. It is used when
	// setup fails part-way, so those processes are not leaked.
	abort := func() {
		killAll()
		for _, cmd := range processes {
			if cmd != nil {
				_ = cmd.Wait() // an error is expected after Kill
			}
		}
	}

	for i, idx := range signingPartyIndices {
		// NOTE(review): the share and password travel on the command line,
		// where other local users can read them from the process list. That is
		// acceptable for a test harness, but not a pattern to copy elsewhere.
		args := []string{
			"sign",
			"-session-id", sessionID,
			"-party-id", fmt.Sprintf("party-%d", idx),
			"-party-index", fmt.Sprintf("%d", idx),
			"-threshold-t", fmt.Sprintf("%d", thresholdT),
			"-threshold-n", fmt.Sprintf("%d", thresholdN),
			"-participants", string(participantsJSON),
			"-message-hash", messageHash,
			"-share-data", encryptedShares[i],
			"-password", password,
		}
		t.Logf("[Party %d] Starting with args: sign -session-id %s -party-id party-%d ...", idx, sessionID, idx)

		cmd := exec.Command(exePath, args...)
		stdin, err := cmd.StdinPipe()
		if err != nil {
			abort()
			return nil, fmt.Errorf("stdin pipe for party %d: %w", idx, err)
		}
		stdout, err := cmd.StdoutPipe()
		if err != nil {
			stdin.Close()
			abort()
			return nil, fmt.Errorf("stdout pipe for party %d: %w", idx, err)
		}
		cmd.Stderr = os.Stderr
		if err := cmd.Start(); err != nil {
			abort()
			return nil, fmt.Errorf("failed to start party %d: %w", idx, err)
		}
		processes[i] = cmd
		stdinPipes[i] = stdin
		stdoutPipes[i] = stdout
	}

	// The watchdog kills every process after signTimeout. Their stdout then
	// reaches EOF, so the readers and Wait calls below always return.
	expired := make(chan struct{})
	watchdog := time.AfterFunc(signTimeout, func() {
		close(expired)
		killAll()
	})
	defer watchdog.Stop()

	// One mutex per stdin: several readers may forward to the same process.
	stdinMu := make([]sync.Mutex, n)

	// forward relays an "outgoing" message from process src, as an "incoming"
	// message, to every other process it is addressed to.
	forward := func(src int, msg Message) {
		data, err := json.Marshal(Message{
			Type:           "incoming",
			IsBroadcast:    msg.IsBroadcast,
			Payload:        msg.Payload,
			FromPartyIndex: signingPartyIndices[src],
		})
		if err != nil {
			t.Logf("[Party %d] Failed to encode forwarded message: %v", signingPartyIndices[src], err)
			return
		}
		data = append(data, '\n')

		for dest, destIdx := range signingPartyIndices {
			if dest == src {
				continue
			}
			// A non-broadcast message with an explicit recipient list goes only
			// to those recipients; otherwise it goes to everyone else.
			if !msg.IsBroadcast && len(msg.ToParties) > 0 {
				target := fmt.Sprintf("party-%d", destIdx)
				addressed := false
				for _, to := range msg.ToParties {
					if to == target {
						addressed = true
						break
					}
				}
				if !addressed {
					continue
				}
			}
			stdinMu[dest].Lock()
			_, err := stdinPipes[dest].Write(data)
			stdinMu[dest].Unlock()
			if err != nil {
				t.Logf("[Party %d] Failed to forward message to party %d: %v", signingPartyIndices[src], destIdx, err)
			}
		}
	}

	// Each reader writes only its own slot, and the slots are read only after
	// wg.Wait, so no extra locking is needed.
	results := make([][]byte, n)
	errs := make([]error, n)
	var wg sync.WaitGroup
	for i := range signingPartyIndices {
		wg.Add(1)
		go func(src int) {
			defer wg.Done()
			partyIdx := signingPartyIndices[src]
			scanner := bufio.NewScanner(stdoutPipes[src])
			scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
			for scanner.Scan() {
				var msg Message
				if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
					t.Logf("[Party %d] Invalid JSON: %s", partyIdx, scanner.Text())
					continue
				}
				switch msg.Type {
				case "progress":
					t.Logf("[Party %d] Progress: round %d/%d", partyIdx, msg.Round, msg.TotalRounds)
				case "outgoing":
					t.Logf("[Party %d] Outgoing message (broadcast=%v, toParties=%v)", partyIdx, msg.IsBroadcast, msg.ToParties)
					forward(src, msg)
				case "result":
					t.Logf("[Party %d] Got result!", partyIdx)
					signature, err := base64.StdEncoding.DecodeString(msg.Payload)
					if err != nil {
						errs[src] = fmt.Errorf("decode signature: %w", err)
						continue
					}
					results[src] = signature
				case "error":
					t.Logf("[Party %d] Error: %s", partyIdx, msg.Error)
					errs[src] = fmt.Errorf("party error: %s", msg.Error)
				}
			}
			if err := scanner.Err(); err != nil {
				t.Logf("[Party %d] Scanner error: %v", partyIdx, err)
				// Keep draining, or the process could block writing stdout and
				// never exit.
				_, _ = io.Copy(io.Discard, stdoutPipes[src])
			}
		}(i)
	}

	// Finish reading stdout before calling Wait. Cmd.Wait closes the
	// StdoutPipe, and calling it while a reader is still scanning can drop
	// the last lines, including the "result".
	wg.Wait()
	for i, cmd := range processes {
		if err := cmd.Wait(); err != nil {
			t.Logf("[Party %d] Process exit error: %v", signingPartyIndices[i], err)
		}
	}

	for i, err := range errs {
		if err != nil {
			return nil, fmt.Errorf("party %d error: %w", signingPartyIndices[i], err)
		}
	}
	for _, signature := range results {
		if len(signature) > 0 {
			return signature, nil
		}
	}
	select {
	case <-expired:
		return nil, fmt.Errorf("signing timed out after %v", signTimeout)
	default:
	}
	return nil, fmt.Errorf("no signature received")
}