//go:build js && wasm // +build js,wasm // Package main provides TSS (Threshold Signature Scheme) functionality for WebAssembly // This module runs in the browser and communicates with JavaScript via callbacks package main import ( "encoding/base64" "encoding/json" "fmt" "math/big" "regexp" "strconv" "sync" "syscall/js" "github.com/bnb-chain/tss-lib/v2/common" "github.com/bnb-chain/tss-lib/v2/ecdsa/keygen" "github.com/bnb-chain/tss-lib/v2/ecdsa/signing" "github.com/bnb-chain/tss-lib/v2/tss" ) // Regex to extract round number from tss-lib message type // Message types look like: "binance.tsslib.ecdsa.keygen.KGRound1Message" // or "binance.tsslib.ecdsa.signing.SignRound3Message" var roundRegex = regexp.MustCompile(`Round(\d+)`) // Global state for active sessions var ( activeSessions = make(map[string]*TSSSession) sessionMutex sync.RWMutex ) // TSSSession holds the state for an active TSS session type TSSSession struct { SessionID string PartyID string PartyIndex int ThresholdT int ThresholdN int Participants []Participant LocalParty tss.Party OutCh chan tss.Message EndChKeygen chan *keygen.LocalPartySaveData EndChSign chan *common.SignatureData ErrCh chan error PartyIndexMap map[int]*tss.PartyID Password string IsKeygen bool Done chan struct{} OnMessage js.Value // JavaScript callback for outgoing messages OnProgress js.Value // JavaScript callback for progress updates OnComplete js.Value // JavaScript callback for completion OnError js.Value // JavaScript callback for errors } // Participant info type Participant struct { PartyID string `json:"partyId"` PartyIndex int `json:"partyIndex"` } // Message types for JS communication type JSMessage struct { Type string `json:"type"` IsBroadcast bool `json:"isBroadcast,omitempty"` ToParties []string `json:"toParties,omitempty"` Payload string `json:"payload,omitempty"` PublicKey string `json:"publicKey,omitempty"` EncryptedShare string `json:"encryptedShare,omitempty"` Signature string `json:"signature,omitempty"` PartyIndex int `json:"partyIndex,omitempty"` Round int `json:"round,omitempty"` TotalRounds int `json:"totalRounds,omitempty"` FromPartyIndex int `json:"fromPartyIndex,omitempty"` Error string `json:"error,omitempty"` } func main() { // Register JavaScript functions js.Global().Set("tssStartKeygen", js.FuncOf(startKeygen)) js.Global().Set("tssStartSigning", js.FuncOf(startSigning)) js.Global().Set("tssHandleMessage", js.FuncOf(handleMessage)) js.Global().Set("tssStopSession", js.FuncOf(stopSession)) js.Global().Set("tssGetVersion", js.FuncOf(getVersion)) // Keep the program running select {} } // getVersion returns the TSS WASM version func getVersion(this js.Value, args []js.Value) interface{} { return "1.0.0" } // startKeygen initializes a keygen session // Arguments: sessionId, partyId, partyIndex, thresholdT, thresholdN, participantsJSON, password, onMessage, onProgress, onComplete, onError func startKeygen(this js.Value, args []js.Value) interface{} { if len(args) < 11 { return createErrorResult("Missing required arguments") } sessionID := args[0].String() partyID := args[1].String() partyIndex := args[2].Int() thresholdT := args[3].Int() thresholdN := args[4].Int() participantsJSON := args[5].String() password := args[6].String() onMessage := args[7] onProgress := args[8] onComplete := args[9] onError := args[10] // Parse participants var participants []Participant if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil { return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err)) } if len(participants) != thresholdN { return createErrorResult(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), thresholdN)) } // Create session session := &TSSSession{ SessionID: sessionID, PartyID: partyID, PartyIndex: partyIndex, ThresholdT: thresholdT, ThresholdN: thresholdN, Participants: participants, OutCh: make(chan tss.Message, thresholdN*10), EndChKeygen: make(chan *keygen.LocalPartySaveData, 1), ErrCh: make(chan error, 1), Password: password, IsKeygen: true, Done: make(chan struct{}), OnMessage: onMessage, OnProgress: onProgress, OnComplete: onComplete, OnError: onError, } // Create TSS party IDs tssPartyIDs := make([]*tss.PartyID, len(participants)) var selfTSSID *tss.PartyID for i, p := range participants { partyKey := tss.NewPartyID( p.PartyID, fmt.Sprintf("party-%d", p.PartyIndex), big.NewInt(int64(p.PartyIndex+1)), ) tssPartyIDs[i] = partyKey if p.PartyID == partyID { selfTSSID = partyKey } } if selfTSSID == nil { return createErrorResult("Self party not found in participants") } // Sort party IDs sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs) // Create peer context and parameters // IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required // User says "2-of-3" meaning 2 signers needed, so we pass (thresholdT-1) to TSS-lib peerCtx := tss.NewPeerContext(sortedPartyIDs) tssThreshold := thresholdT - 1 params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold) // Build party index map session.PartyIndexMap = make(map[int]*tss.PartyID) for _, p := range sortedPartyIDs { for _, orig := range participants { if orig.PartyID == p.Id { session.PartyIndexMap[orig.PartyIndex] = p break } } } // Create local party session.LocalParty = keygen.NewLocalParty(params, session.OutCh, session.EndChKeygen) // Store session sessionMutex.Lock() activeSessions[sessionID] = session sessionMutex.Unlock() // Start goroutines go session.handleOutgoingMessages() go session.waitForKeygenCompletion() // Start the local party go func() { if err := session.LocalParty.Start(); err != nil { session.ErrCh <- err } }() return createSuccessResult(map[string]interface{}{ "sessionId": sessionID, "started": true, }) } // startSigning initializes a signing session // Arguments: sessionId, partyId, partyIndex, thresholdT, participantsJSON, saveDataJSON, messageHash, onMessage, onProgress, onComplete, onError func startSigning(this js.Value, args []js.Value) interface{} { if len(args) < 11 { return createErrorResult("Missing required arguments") } sessionID := args[0].String() partyID := args[1].String() partyIndex := args[2].Int() thresholdT := args[3].Int() participantsJSON := args[4].String() saveDataJSON := args[5].String() messageHashB64 := args[6].String() onMessage := args[7] onProgress := args[8] onComplete := args[9] onError := args[10] // Parse participants var participants []Participant if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil { return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err)) } // Parse save data (keygen result) var saveData keygen.LocalPartySaveData if err := json.Unmarshal([]byte(saveDataJSON), &saveData); err != nil { return createErrorResult(fmt.Sprintf("Failed to parse save data: %v", err)) } // Decode message hash messageHash, err := base64.StdEncoding.DecodeString(messageHashB64) if err != nil { return createErrorResult(fmt.Sprintf("Failed to decode message hash: %v", err)) } thresholdN := len(participants) // Create session session := &TSSSession{ SessionID: sessionID, PartyID: partyID, PartyIndex: partyIndex, ThresholdT: thresholdT, ThresholdN: thresholdN, Participants: participants, OutCh: make(chan tss.Message, thresholdN*10), EndChSign: make(chan *common.SignatureData, 1), ErrCh: make(chan error, 1), IsKeygen: false, Done: make(chan struct{}), OnMessage: onMessage, OnProgress: onProgress, OnComplete: onComplete, OnError: onError, } // Create TSS party IDs tssPartyIDs := make([]*tss.PartyID, len(participants)) var selfTSSID *tss.PartyID for i, p := range participants { partyKey := tss.NewPartyID( p.PartyID, fmt.Sprintf("party-%d", p.PartyIndex), big.NewInt(int64(p.PartyIndex+1)), ) tssPartyIDs[i] = partyKey if p.PartyID == partyID { selfTSSID = partyKey } } if selfTSSID == nil { return createErrorResult("Self party not found in participants") } // Sort party IDs sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs) // Create peer context and parameters // IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required // This MUST match keygen exactly! Both use (thresholdT-1) peerCtx := tss.NewPeerContext(sortedPartyIDs) tssThreshold := thresholdT - 1 params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold) // Build party index map session.PartyIndexMap = make(map[int]*tss.PartyID) for _, p := range sortedPartyIDs { for _, orig := range participants { if orig.PartyID == p.Id { session.PartyIndexMap[orig.PartyIndex] = p break } } } // Create message hash as big.Int msgHashBig := new(big.Int).SetBytes(messageHash) // CRITICAL: Build a subset of the keygen save data for the current signing parties // This is required when signing with a subset of the original keygen participants. subsetSaveData := keygen.BuildLocalSaveDataSubset(saveData, sortedPartyIDs) // Create local signing party with the SUBSET save data session.LocalParty = signing.NewLocalParty(msgHashBig, params, subsetSaveData, session.OutCh, session.EndChSign) // Store session sessionMutex.Lock() activeSessions[sessionID] = session sessionMutex.Unlock() // Start goroutines go session.handleOutgoingMessages() go session.waitForSigningCompletion() // Start the local party go func() { if err := session.LocalParty.Start(); err != nil { session.ErrCh <- err } }() return createSuccessResult(map[string]interface{}{ "sessionId": sessionID, "started": true, }) } // handleMessage processes an incoming TSS message // Arguments: sessionId, fromPartyIndex, isBroadcast, payloadBase64 func handleMessage(this js.Value, args []js.Value) interface{} { if len(args) < 4 { return createErrorResult("Missing required arguments") } sessionID := args[0].String() fromPartyIndex := args[1].Int() isBroadcast := args[2].Bool() payloadB64 := args[3].String() sessionMutex.RLock() session, exists := activeSessions[sessionID] sessionMutex.RUnlock() if !exists { return createErrorResult("Session not found") } fromParty, ok := session.PartyIndexMap[fromPartyIndex] if !ok { return createErrorResult("Unknown party index") } payload, err := base64.StdEncoding.DecodeString(payloadB64) if err != nil { return createErrorResult(fmt.Sprintf("Failed to decode payload: %v", err)) } parsedMsg, err := tss.ParseWireMessage(payload, fromParty, isBroadcast) if err != nil { return createErrorResult(fmt.Sprintf("Failed to parse message: %v", err)) } go func() { _, err := session.LocalParty.Update(parsedMsg) if err != nil && !isDuplicateError(err) { session.ErrCh <- err } }() return createSuccessResult(nil) } // stopSession stops an active TSS session func stopSession(this js.Value, args []js.Value) interface{} { if len(args) < 1 { return createErrorResult("Missing session ID") } sessionID := args[0].String() sessionMutex.Lock() session, exists := activeSessions[sessionID] if exists { close(session.Done) delete(activeSessions, sessionID) } sessionMutex.Unlock() if !exists { return createErrorResult("Session not found") } return createSuccessResult(nil) } // extractRoundFromMessageType parses the round number from a tss-lib message type string. // Returns 0 if parsing fails (safe fallback). // Example: "binance.tsslib.ecdsa.keygen.KGRound2Message1" -> 2 func extractRoundFromMessageType(msgType string) int { matches := roundRegex.FindStringSubmatch(msgType) if len(matches) >= 2 { if round, err := strconv.Atoi(matches[1]); err == nil { return round } } return 0 // Safe fallback - doesn't affect protocol, just shows 0 in UI } // handleOutgoingMessages processes messages from the TSS protocol func (s *TSSSession) handleOutgoingMessages() { totalRounds := 4 // GG20 keygen has 4 rounds if !s.IsKeygen { totalRounds = 9 // GG20 signing has 9 rounds (matching Electron and Android) } for { select { case <-s.Done: return case msg, ok := <-s.OutCh: if !ok { return } msgBytes, _, err := msg.WireBytes() if err != nil { continue } var toParties []string if !msg.IsBroadcast() { for _, to := range msg.GetTo() { toParties = append(toParties, to.Id) } } // Call JavaScript callback jsMsg := map[string]interface{}{ "type": "outgoing", "isBroadcast": msg.IsBroadcast(), "toParties": toParties, "payload": base64.StdEncoding.EncodeToString(msgBytes), } jsMsgJSON, _ := json.Marshal(jsMsg) s.OnMessage.Invoke(string(jsMsgJSON)) // Extract current round from message type and send progress update currentRound := extractRoundFromMessageType(msg.Type()) s.OnProgress.Invoke(currentRound, totalRounds) } } } // waitForKeygenCompletion waits for keygen to complete func (s *TSSSession) waitForKeygenCompletion() { select { case <-s.Done: return case err := <-s.ErrCh: s.OnError.Invoke(err.Error()) s.cleanup() case saveData := <-s.EndChKeygen: // Get public key (33 bytes compressed) pubKey := saveData.ECDSAPub.ToECDSAPubKey() pubKeyBytes := make([]byte, 33) pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0)) xBytes := pubKey.X.Bytes() copy(pubKeyBytes[33-len(xBytes):], xBytes) // Serialize save data saveDataBytes, err := json.Marshal(saveData) if err != nil { s.OnError.Invoke(fmt.Sprintf("Failed to serialize save data: %v", err)) s.cleanup() return } // Encrypt with password encryptedShare := encryptShare(saveDataBytes, s.Password) // Call completion callback result := map[string]interface{}{ "type": "result", "publicKey": base64.StdEncoding.EncodeToString(pubKeyBytes), "encryptedShare": base64.StdEncoding.EncodeToString(encryptedShare), "partyIndex": s.PartyIndex, } resultJSON, _ := json.Marshal(result) s.OnComplete.Invoke(string(resultJSON)) s.cleanup() } } // waitForSigningCompletion waits for signing to complete func (s *TSSSession) waitForSigningCompletion() { select { case <-s.Done: return case err := <-s.ErrCh: s.OnError.Invoke(err.Error()) s.cleanup() case sigData := <-s.EndChSign: // Serialize signature (R, S format) sigBytes := append(sigData.GetR(), sigData.GetS()...) result := map[string]interface{}{ "type": "result", "signature": base64.StdEncoding.EncodeToString(sigBytes), "partyIndex": s.PartyIndex, } resultJSON, _ := json.Marshal(result) s.OnComplete.Invoke(string(resultJSON)) s.cleanup() } } func (s *TSSSession) cleanup() { sessionMutex.Lock() delete(activeSessions, s.SessionID) sessionMutex.Unlock() } func isDuplicateError(err error) bool { if err == nil { return false } errStr := err.Error() return contains(errStr, "duplicate") || contains(errStr, "already received") } func contains(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false } func encryptShare(data []byte, password string) []byte { // TODO: Use proper AES-256-GCM encryption // For now, just prepend a marker and the password hash result := make([]byte, len(data)+32) copy(result[:32], hashPassword(password)) copy(result[32:], data) return result } func hashPassword(password string) []byte { hash := make([]byte, 32) for i := 0; i < len(password) && i < 32; i++ { hash[i] = password[i] } return hash } func createSuccessResult(data interface{}) map[string]interface{} { result := map[string]interface{}{ "success": true, } if data != nil { result["data"] = data } return result } func createErrorResult(errMsg string) map[string]interface{} { return map[string]interface{}{ "success": false, "error": errMsg, } }