// integration_test.go - Integration test for the complete co-sign flow // Tests: session creation, joining, waiting, events, and signing package main import ( "bufio" "bytes" "encoding/base64" "encoding/json" "fmt" "io" "math/big" "net/http" "os" "os/exec" "sync" "testing" "time" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/tss" ) const ( // Backend service URLs (from docker-compose.windows.yml) accountServiceURL = "http://localhost:4000" grpcRouterAddr = "localhost:50051" ) // API Response types for co-managed keygen type CreateCoManagedSessionRequest struct { WalletName string `json:"wallet_name"` ThresholdT int `json:"threshold_t"` ThresholdN int `json:"threshold_n"` InitiatorPartyID string `json:"initiator_party_id"` InitiatorName string `json:"initiator_name,omitempty"` PersistentCount int `json:"persistent_count"` } type CreateCoManagedSessionResponse struct { SessionID string `json:"session_id"` InviteCode string `json:"invite_code"` WalletName string `json:"wallet_name"` ThresholdT int `json:"threshold_t"` ThresholdN int `json:"threshold_n"` ExpiresAt int64 `json:"expires_at"` JoinToken string `json:"join_token"` PartyID string `json:"party_id"` PartyIndex int `json:"party_index"` } // API Response types for co-managed sign type CreateSignSessionRequest struct { KeygenSessionID string `json:"keygen_session_id"` WalletName string `json:"wallet_name"` MessageHash string `json:"message_hash"` Parties []SignPartyInfo `json:"parties"` ThresholdT int `json:"threshold_t"` InitiatorName string `json:"initiator_name,omitempty"` } type CreateSignSessionResponse struct { SessionID string `json:"session_id"` InviteCode string `json:"invite_code"` KeygenSessionID string `json:"keygen_session_id"` MessageHash string `json:"message_hash"` ThresholdT int `json:"threshold_t"` ExpiresAt int64 `json:"expires_at"` JoinToken string `json:"join_token"` } type GetSignSessionResponse struct { SessionID string `json:"session_id"` KeygenSessionID string `json:"keygen_session_id"` WalletName string `json:"wallet_name"` MessageHash string `json:"message_hash"` ThresholdT int `json:"threshold_t"` Status string `json:"status"` InviteCode string `json:"invite_code"` ExpiresAt int64 `json:"expires_at"` Parties []SignPartyInfo `json:"parties"` JoinedCount int `json:"joined_count"` JoinToken string `json:"join_token,omitempty"` } type SignPartyInfo struct { PartyID string `json:"party_id"` PartyIndex int `json:"party_index"` } // TestAccountServiceHealth tests if account service is available func TestAccountServiceHealth(t *testing.T) { resp, err := http.Get(accountServiceURL + "/health") if err != nil { t.Fatalf("Failed to connect to account service: %v", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("Account service unhealthy, status: %d", resp.StatusCode) } t.Log("Account service is healthy") } // TestCreateSignSession tests creating a new sign session func TestCreateSignSession(t *testing.T) { // This requires an existing keygen session ID // For testing, we'll use a mock one keygenSessionID := "test-keygen-session-" + fmt.Sprintf("%d", time.Now().UnixNano()) messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" reqBody := CreateSignSessionRequest{ KeygenSessionID: keygenSessionID, MessageHash: messageHash, InitiatorName: "test-initiator", } jsonBody, _ := json.Marshal(reqBody) resp, err := http.Post( accountServiceURL+"/api/v1/co-managed/sign", "application/json", bytes.NewBuffer(jsonBody), ) if err != nil { t.Fatalf("Failed to create sign session: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body)) if resp.StatusCode != 200 && resp.StatusCode != 201 { t.Logf("Note: This test requires an existing keygen session in the database") t.Skipf("Sign session creation returned status %d (expected existing keygen session)", resp.StatusCode) } var result CreateSignSessionResponse if err := json.Unmarshal(body, &result); err != nil { t.Fatalf("Failed to parse response: %v", err) } t.Logf("Session created: ID=%s, InviteCode=%s", result.SessionID, result.InviteCode) } // TestGetSessionByInviteCode tests retrieving session info by invite code func TestGetSessionByInviteCode(t *testing.T) { // This would need a valid invite code from a real session inviteCode := "TEST123" resp, err := http.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + inviteCode) if err != nil { t.Fatalf("Failed to get session: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("Get session response (status %d): %s", resp.StatusCode, string(body)) if resp.StatusCode == 404 { t.Skip("No session found with invite code (expected for test)") } } // TestCoManagedKeygenSessionFlow tests the keygen session creation and join flow func TestCoManagedKeygenSessionFlow(t *testing.T) { t.Log("=== Co-Managed Keygen Session Flow Test ===") // Step 1: Create a keygen session t.Log("Step 1: Creating keygen session...") reqBody := CreateCoManagedSessionRequest{ WalletName: "test-wallet-" + fmt.Sprintf("%d", time.Now().UnixNano()), ThresholdT: 1, // 2-of-3: t=1 means t+1=2 signers needed ThresholdN: 3, InitiatorPartyID: "test-initiator-party", InitiatorName: "Test User", PersistentCount: 0, // No server parties for this test } jsonBody, _ := json.Marshal(reqBody) resp, err := http.Post( accountServiceURL+"/api/v1/co-managed/sessions", "application/json", bytes.NewBuffer(jsonBody), ) if err != nil { t.Fatalf("Failed to create session: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("Create session response (status %d): %s", resp.StatusCode, string(body)) if resp.StatusCode != 200 && resp.StatusCode != 201 { t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode) } var createResp CreateCoManagedSessionResponse if err := json.Unmarshal(body, &createResp); err != nil { t.Fatalf("Failed to parse response: %v", err) } t.Logf("Session created: ID=%s, InviteCode=%s", createResp.SessionID, createResp.InviteCode) // Step 2: Get session by invite code t.Log("Step 2: Getting session by invite code...") resp2, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/by-invite-code/" + createResp.InviteCode) if err != nil { t.Fatalf("Failed to get session: %v", err) } defer resp2.Body.Close() body2, _ := io.ReadAll(resp2.Body) t.Logf("Get session response (status %d): %s", resp2.StatusCode, string(body2)) if resp2.StatusCode != 200 { t.Fatalf("Failed to get session by invite code, status: %d", resp2.StatusCode) } // Step 3: Get session status t.Log("Step 3: Getting session status...") resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sessions/" + createResp.SessionID) if err != nil { t.Fatalf("Failed to get session status: %v", err) } defer resp3.Body.Close() body3, _ := io.ReadAll(resp3.Body) t.Logf("Session status response (status %d): %s", resp3.StatusCode, string(body3)) if resp3.StatusCode != 200 { t.Fatalf("Failed to get session status, status: %d", resp3.StatusCode) } t.Log("Keygen session flow test passed!") } // TestCoManagedSignSessionFlow tests the full sign session flow func TestCoManagedSignSessionFlow(t *testing.T) { t.Log("=== Co-Managed Sign Session Flow Test ===") // First, create a keygen session to get a valid keygen_session_id t.Log("Step 1: Creating keygen session...") keygenReq := CreateCoManagedSessionRequest{ WalletName: "test-wallet-for-sign-" + fmt.Sprintf("%d", time.Now().UnixNano()), ThresholdT: 1, ThresholdN: 3, InitiatorPartyID: "test-party-0", InitiatorName: "Test User", PersistentCount: 0, } jsonBody, _ := json.Marshal(keygenReq) resp, err := http.Post( accountServiceURL+"/api/v1/co-managed/sessions", "application/json", bytes.NewBuffer(jsonBody), ) if err != nil { t.Fatalf("Failed to create keygen session: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("Keygen session response (status %d): %s", resp.StatusCode, string(body)) if resp.StatusCode != 200 && resp.StatusCode != 201 { t.Fatalf("Failed to create keygen session, status: %d", resp.StatusCode) } var keygenResp CreateCoManagedSessionResponse if err := json.Unmarshal(body, &keygenResp); err != nil { t.Fatalf("Failed to parse keygen response: %v", err) } // Step 2: Create a sign session using the keygen session ID // Note: For threshold T, we need T+1 parties to sign // Backend validates that parties.length >= threshold_t + 1 t.Log("Step 2: Creating sign session...") signReq := CreateSignSessionRequest{ KeygenSessionID: keygenResp.SessionID, WalletName: keygenReq.WalletName, MessageHash: "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", ThresholdT: 1, // Threshold T=1 means T+1=2 parties needed to sign Parties: []SignPartyInfo{ {PartyID: "test-party-0", PartyIndex: 0}, {PartyID: "test-party-1", PartyIndex: 1}, }, InitiatorName: "Test User", } jsonBody2, _ := json.Marshal(signReq) resp2, err := http.Post( accountServiceURL+"/api/v1/co-managed/sign", "application/json", bytes.NewBuffer(jsonBody2), ) if err != nil { t.Fatalf("Failed to create sign session: %v", err) } defer resp2.Body.Close() body2, _ := io.ReadAll(resp2.Body) t.Logf("Sign session response (status %d): %s", resp2.StatusCode, string(body2)) if resp2.StatusCode != 200 && resp2.StatusCode != 201 { t.Fatalf("Failed to create sign session, status: %d", resp2.StatusCode) } var signResp CreateSignSessionResponse if err := json.Unmarshal(body2, &signResp); err != nil { t.Fatalf("Failed to parse sign response: %v", err) } t.Logf("Sign session created: ID=%s, InviteCode=%s", signResp.SessionID, signResp.InviteCode) // Step 3: Get sign session by invite code t.Log("Step 3: Getting sign session by invite code...") resp3, err := http.Get(accountServiceURL + "/api/v1/co-managed/sign/by-invite-code/" + signResp.InviteCode) if err != nil { t.Fatalf("Failed to get sign session: %v", err) } defer resp3.Body.Close() body3, _ := io.ReadAll(resp3.Body) t.Logf("Get sign session response (status %d): %s", resp3.StatusCode, string(body3)) if resp3.StatusCode != 200 { t.Fatalf("Failed to get sign session by invite code, status: %d", resp3.StatusCode) } var getSignResp GetSignSessionResponse if err := json.Unmarshal(body3, &getSignResp); err != nil { t.Fatalf("Failed to parse get sign session response: %v", err) } t.Logf("Sign session status: %s, JoinedCount: %d, ThresholdT: %d", getSignResp.Status, getSignResp.JoinedCount, getSignResp.ThresholdT) t.Log("Sign session flow test passed!") } // TestFullSignFlow tests the complete signing flow with mock data func TestFullSignFlow(t *testing.T) { t.Log("=== Full Co-Sign Flow Test ===") // Step 1: Generate mock key shares (simulating keygen result) // NOTE: In tss-lib, threshold parameter means "t" where you need t+1 parties to sign. // So for 2-of-3, we use threshold=1 (meaning 1+1=2 parties needed to sign) t.Log("Step 1: Generating mock key shares for 2-of-3 scheme...") thresholdT := 1 // t value: need t+1=2 parties to sign totalN := 3 // total parties shares, err := generateMockKeyShares(thresholdT, totalN) if err != nil { t.Fatalf("Failed to generate key shares: %v", err) } t.Logf("Generated %d key shares", len(shares)) // Step 2: Prepare signing parameters messageHash := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" sessionID := fmt.Sprintf("test-sign-session-%d", time.Now().UnixNano()) password := "test-password" // Step 3: Start signing with 2 parties (index 0 and 1) // For tss-party.exe, we pass the signing threshold (t+1=2) and total (n=3) t.Log("Step 2: Starting sign process with 2 parties...") signingParties := []int{0, 1} // Pass thresholdT+1=2 as the signing threshold to tss-party.exe signature, err := runSigningProcess(shares, signingParties, messageHash, sessionID, password, thresholdT+1, totalN, t) if err != nil { t.Fatalf("Signing failed: %v", err) } t.Logf("Step 3: Signing complete!") t.Logf("Signature (hex): %x", signature) t.Logf("Signature (base64): %s", base64.StdEncoding.EncodeToString(signature)) } // generateMockKeyShares generates key shares using tss-lib directly func generateMockKeyShares(threshold, total int) ([]*keygen.LocalPartySaveData, error) { fmt.Println("[Keygen] Starting key generation for", total, "parties with threshold", threshold) partyIDs := make([]*tss.PartyID, total) for i := 0; i < total; i++ { partyIDs[i] = tss.NewPartyID( fmt.Sprintf("party-%d", i), fmt.Sprintf("party-%d", i), big.NewInt(int64(i+1)), ) } sortedPartyIDs := tss.SortPartyIDs(partyIDs) peerCtx := tss.NewPeerContext(sortedPartyIDs) outChannels := make([]chan tss.Message, total) endChannels := make([]chan *keygen.LocalPartySaveData, total) parties := make([]tss.Party, total) for i := 0; i < total; i++ { outChannels[i] = make(chan tss.Message, total*20) endChannels[i] = make(chan *keygen.LocalPartySaveData, 1) params := tss.NewParameters(tss.S256(), peerCtx, sortedPartyIDs[i], total, threshold) parties[i] = keygen.NewLocalParty(params, outChannels[i], endChannels[i]) } // Start all parties var wg sync.WaitGroup errChan := make(chan error, total) for i := 0; i < total; i++ { wg.Add(1) go func(idx int) { defer wg.Done() if err := parties[idx].Start(); err != nil { errChan <- fmt.Errorf("party %d start error: %w", idx, err) } }(i) } results := make([]*keygen.LocalPartySaveData, total) var resultsMu sync.Mutex resultCount := 0 done := make(chan struct{}) // Message routing go func() { for { select { case <-done: return default: for i := 0; i < total; i++ { select { case msg := <-outChannels[i]: wire, _, _ := msg.WireBytes() if msg.IsBroadcast() { for j := 0; j < total; j++ { if j != i { go func(destIdx int) { parsed, _ := tss.ParseWireMessage(wire, msg.GetFrom(), true) parties[destIdx].Update(parsed) }(j) } } } else { for _, to := range msg.GetTo() { for j := 0; j < total; j++ { if sortedPartyIDs[j].Id == to.Id { go func(destIdx int) { parsed, _ := tss.ParseWireMessage(wire, msg.GetFrom(), false) parties[destIdx].Update(parsed) }(j) break } } } } case result := <-endChannels[i]: resultsMu.Lock() results[i] = result resultCount++ fmt.Printf("[Keygen] Party %d completed\n", i) if resultCount == total { close(done) } resultsMu.Unlock() default: } } time.Sleep(5 * time.Millisecond) } } }() wg.Wait() select { case <-done: fmt.Println("[Keygen] All parties completed successfully") case <-time.After(5 * time.Minute): return nil, fmt.Errorf("keygen timeout") } select { case err := <-errChan: return nil, err default: } return results, nil } // runSigningProcess runs the signing process using tss-party.exe func runSigningProcess( shares []*keygen.LocalPartySaveData, signingPartyIndices []int, messageHash, sessionID, password string, thresholdT, thresholdN int, t *testing.T, ) ([]byte, error) { // Build participants participants := make([]Participant, len(signingPartyIndices)) for i, idx := range signingPartyIndices { participants[i] = Participant{ PartyID: fmt.Sprintf("party-%d", idx), PartyIndex: idx, } } participantsJSON, _ := json.Marshal(participants) // Prepare encrypted shares encryptedShares := make([]string, len(signingPartyIndices)) for i, idx := range signingPartyIndices { shareBytes, _ := json.Marshal(shares[idx]) encrypted := encryptShare(shareBytes, password) encryptedShares[i] = base64.StdEncoding.EncodeToString(encrypted) } // Find tss-party executable exePath := "./tss-party.exe" if _, err := os.Stat(exePath); os.IsNotExist(err) { exePath = "./tss-party" if _, err := os.Stat(exePath); os.IsNotExist(err) { return nil, fmt.Errorf("tss-party executable not found") } } t.Logf("Using executable: %s", exePath) // Start processes processes := make([]*exec.Cmd, len(signingPartyIndices)) stdinPipes := make([]io.WriteCloser, len(signingPartyIndices)) stdoutPipes := make([]io.ReadCloser, len(signingPartyIndices)) for i, idx := range signingPartyIndices { args := []string{ "sign", "-session-id", sessionID, "-party-id", fmt.Sprintf("party-%d", idx), "-party-index", fmt.Sprintf("%d", idx), "-threshold-t", fmt.Sprintf("%d", thresholdT), "-threshold-n", fmt.Sprintf("%d", thresholdN), "-participants", string(participantsJSON), "-message-hash", messageHash, "-share-data", encryptedShares[i], "-password", password, } t.Logf("[Party %d] Starting with args: sign -session-id %s -party-id party-%d ...", idx, sessionID, idx) cmd := exec.Command(exePath, args...) stdin, _ := cmd.StdinPipe() stdout, _ := cmd.StdoutPipe() cmd.Stderr = os.Stderr processes[i] = cmd stdinPipes[i] = stdin stdoutPipes[i] = stdout if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start party %d: %w", idx, err) } } // Message routing between parties var wg sync.WaitGroup results := make([][]byte, len(signingPartyIndices)) errors := make([]error, len(signingPartyIndices)) // Create a mutex for stdin writes stdinMutexes := make([]*sync.Mutex, len(signingPartyIndices)) for i := range stdinMutexes { stdinMutexes[i] = &sync.Mutex{} } for i := range signingPartyIndices { wg.Add(1) go func(partyIdx int) { defer wg.Done() scanner := bufio.NewScanner(stdoutPipes[partyIdx]) scanner.Buffer(make([]byte, 1024*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() var msg Message if err := json.Unmarshal([]byte(line), &msg); err != nil { t.Logf("[Party %d] Invalid JSON: %s", signingPartyIndices[partyIdx], line) continue } switch msg.Type { case "progress": t.Logf("[Party %d] Progress: round %d/%d", signingPartyIndices[partyIdx], msg.Round, msg.TotalRounds) case "outgoing": // tss-party.exe outputs "outgoing" messages that need to be routed to other parties t.Logf("[Party %d] Outgoing message (broadcast=%v, toParties=%v)", signingPartyIndices[partyIdx], msg.IsBroadcast, msg.ToParties) // Route to other parties for j := range signingPartyIndices { if j != partyIdx { // If not broadcast, check if this party is in the ToParties list if !msg.IsBroadcast && len(msg.ToParties) > 0 { targetPartyID := fmt.Sprintf("party-%d", signingPartyIndices[j]) found := false for _, to := range msg.ToParties { if to == targetPartyID { found = true break } } if !found { continue } } // tss-party.exe expects "incoming" messages msgToSend := Message{ Type: "incoming", IsBroadcast: msg.IsBroadcast, Payload: msg.Payload, FromPartyIndex: signingPartyIndices[partyIdx], } data, _ := json.Marshal(msgToSend) stdinMutexes[j].Lock() stdinPipes[j].Write(append(data, '\n')) stdinMutexes[j].Unlock() } } case "result": t.Logf("[Party %d] Got result!", signingPartyIndices[partyIdx]) signature, _ := base64.StdEncoding.DecodeString(msg.Payload) results[partyIdx] = signature case "error": t.Logf("[Party %d] Error: %s", signingPartyIndices[partyIdx], msg.Error) errors[partyIdx] = fmt.Errorf("party error: %s", msg.Error) } } if err := scanner.Err(); err != nil { t.Logf("[Party %d] Scanner error: %v", signingPartyIndices[partyIdx], err) } }(i) } // Wait for processes to complete for i, cmd := range processes { if err := cmd.Wait(); err != nil { t.Logf("[Party %d] Process exit error: %v", signingPartyIndices[i], err) } } wg.Wait() // Check for errors for i, err := range errors { if err != nil { return nil, fmt.Errorf("party %d error: %w", signingPartyIndices[i], err) } } // Return first result for _, result := range results { if result != nil { return result, nil } } return nil, fmt.Errorf("no signature received") }