rwadurian/backend/mpc-system/services/tss-wasm/main.go

591 lines
16 KiB
Go

//go:build js && wasm
// +build js,wasm
// Package main provides TSS (Threshold Signature Scheme) functionality for WebAssembly
// This module runs in the browser and communicates with JavaScript via callbacks
package main
import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math/big"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"syscall/js"

	"github.com/bnb-chain/tss-lib/v2/common"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
	"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
	"github.com/bnb-chain/tss-lib/v2/tss"
)
// roundRegex captures the decimal round number embedded in a tss-lib message
// type name, for example "binance.tsslib.ecdsa.keygen.KGRound1Message" or
// "binance.tsslib.ecdsa.signing.SignRound3Message". It is compiled once at
// package initialization so the per-message path never recompiles it.
var roundRegex = regexp.MustCompile(`Round(\d+)`)
// activeSessions is the registry of live sessions, keyed by session ID.
// Every access in this file is made while holding sessionMutex.
var activeSessions = map[string]*TSSSession{}

// sessionMutex guards activeSessions.
var sessionMutex sync.RWMutex
// TSSSession holds the state for an active TSS session
type TSSSession struct {
SessionID string
PartyID string
PartyIndex int
ThresholdT int
ThresholdN int
Participants []Participant
LocalParty tss.Party
OutCh chan tss.Message
EndChKeygen chan *keygen.LocalPartySaveData
EndChSign chan *common.SignatureData
ErrCh chan error
PartyIndexMap map[int]*tss.PartyID
Password string
IsKeygen bool
Done chan struct{}
OnMessage js.Value // JavaScript callback for outgoing messages
OnProgress js.Value // JavaScript callback for progress updates
OnComplete js.Value // JavaScript callback for completion
OnError js.Value // JavaScript callback for errors
}
// Participant describes one party as listed in the caller's participants JSON.
type Participant struct {
	PartyID    string `json:"partyId"`    // caller-assigned party identifier
	PartyIndex int    `json:"partyIndex"` // caller-assigned index; PartyIndex+1 becomes the tss-lib key
}
// JSMessage is a typed description of the JSON envelopes exchanged with
// JavaScript.
//
// Declaration order is significant: encoding/json emits fields in this order,
// and every field except Type is omitted when it holds its zero value.
//
// NOTE(review): nothing in this file constructs a JSMessage; the handlers
// build equivalent map[string]interface{} values directly.
type JSMessage struct {
	// Envelope discriminator.
	Type string `json:"type"`

	// Outgoing protocol message fields.
	IsBroadcast bool     `json:"isBroadcast,omitempty"`
	ToParties   []string `json:"toParties,omitempty"`
	Payload     string   `json:"payload,omitempty"`

	// Completion result fields.
	PublicKey      string `json:"publicKey,omitempty"`
	EncryptedShare string `json:"encryptedShare,omitempty"`
	Signature      string `json:"signature,omitempty"`
	PartyIndex     int    `json:"partyIndex,omitempty"`

	// Progress reporting.
	Round       int `json:"round,omitempty"`
	TotalRounds int `json:"totalRounds,omitempty"`

	// Sender index and error text.
	FromPartyIndex int    `json:"fromPartyIndex,omitempty"`
	Error          string `json:"error,omitempty"`
}
func main() {
// Register JavaScript functions
js.Global().Set("tssStartKeygen", js.FuncOf(startKeygen))
js.Global().Set("tssStartSigning", js.FuncOf(startSigning))
js.Global().Set("tssHandleMessage", js.FuncOf(handleMessage))
js.Global().Set("tssStopSession", js.FuncOf(stopSession))
js.Global().Set("tssGetVersion", js.FuncOf(getVersion))
// Keep the program running
select {}
}
// getVersion returns the TSS WASM version
func getVersion(this js.Value, args []js.Value) interface{} {
return "1.0.0"
}
// startKeygen initializes a keygen session
// Arguments: sessionId, partyId, partyIndex, thresholdT, thresholdN, participantsJSON, password, onMessage, onProgress, onComplete, onError
func startKeygen(this js.Value, args []js.Value) interface{} {
if len(args) < 11 {
return createErrorResult("Missing required arguments")
}
sessionID := args[0].String()
partyID := args[1].String()
partyIndex := args[2].Int()
thresholdT := args[3].Int()
thresholdN := args[4].Int()
participantsJSON := args[5].String()
password := args[6].String()
onMessage := args[7]
onProgress := args[8]
onComplete := args[9]
onError := args[10]
// Parse participants
var participants []Participant
if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err))
}
if len(participants) != thresholdN {
return createErrorResult(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), thresholdN))
}
// Create session
session := &TSSSession{
SessionID: sessionID,
PartyID: partyID,
PartyIndex: partyIndex,
ThresholdT: thresholdT,
ThresholdN: thresholdN,
Participants: participants,
OutCh: make(chan tss.Message, thresholdN*10),
EndChKeygen: make(chan *keygen.LocalPartySaveData, 1),
ErrCh: make(chan error, 1),
Password: password,
IsKeygen: true,
Done: make(chan struct{}),
OnMessage: onMessage,
OnProgress: onProgress,
OnComplete: onComplete,
OnError: onError,
}
// Create TSS party IDs
tssPartyIDs := make([]*tss.PartyID, len(participants))
var selfTSSID *tss.PartyID
for i, p := range participants {
partyKey := tss.NewPartyID(
p.PartyID,
fmt.Sprintf("party-%d", p.PartyIndex),
big.NewInt(int64(p.PartyIndex+1)),
)
tssPartyIDs[i] = partyKey
if p.PartyID == partyID {
selfTSSID = partyKey
}
}
if selfTSSID == nil {
return createErrorResult("Self party not found in participants")
}
// Sort party IDs
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
// Create peer context and parameters
// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
// User says "2-of-3" meaning 2 signers needed, so we pass (thresholdT-1) to TSS-lib
peerCtx := tss.NewPeerContext(sortedPartyIDs)
tssThreshold := thresholdT - 1
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
// Build party index map
session.PartyIndexMap = make(map[int]*tss.PartyID)
for _, p := range sortedPartyIDs {
for _, orig := range participants {
if orig.PartyID == p.Id {
session.PartyIndexMap[orig.PartyIndex] = p
break
}
}
}
// Create local party
session.LocalParty = keygen.NewLocalParty(params, session.OutCh, session.EndChKeygen)
// Store session
sessionMutex.Lock()
activeSessions[sessionID] = session
sessionMutex.Unlock()
// Start goroutines
go session.handleOutgoingMessages()
go session.waitForKeygenCompletion()
// Start the local party
go func() {
if err := session.LocalParty.Start(); err != nil {
session.ErrCh <- err
}
}()
return createSuccessResult(map[string]interface{}{
"sessionId": sessionID,
"started": true,
})
}
// startSigning initializes a signing session
// Arguments: sessionId, partyId, partyIndex, thresholdT, participantsJSON, saveDataJSON, messageHash, onMessage, onProgress, onComplete, onError
func startSigning(this js.Value, args []js.Value) interface{} {
if len(args) < 11 {
return createErrorResult("Missing required arguments")
}
sessionID := args[0].String()
partyID := args[1].String()
partyIndex := args[2].Int()
thresholdT := args[3].Int()
participantsJSON := args[4].String()
saveDataJSON := args[5].String()
messageHashB64 := args[6].String()
onMessage := args[7]
onProgress := args[8]
onComplete := args[9]
onError := args[10]
// Parse participants
var participants []Participant
if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err))
}
// Parse save data (keygen result)
var saveData keygen.LocalPartySaveData
if err := json.Unmarshal([]byte(saveDataJSON), &saveData); err != nil {
return createErrorResult(fmt.Sprintf("Failed to parse save data: %v", err))
}
// Decode message hash
messageHash, err := base64.StdEncoding.DecodeString(messageHashB64)
if err != nil {
return createErrorResult(fmt.Sprintf("Failed to decode message hash: %v", err))
}
thresholdN := len(participants)
// Create session
session := &TSSSession{
SessionID: sessionID,
PartyID: partyID,
PartyIndex: partyIndex,
ThresholdT: thresholdT,
ThresholdN: thresholdN,
Participants: participants,
OutCh: make(chan tss.Message, thresholdN*10),
EndChSign: make(chan *common.SignatureData, 1),
ErrCh: make(chan error, 1),
IsKeygen: false,
Done: make(chan struct{}),
OnMessage: onMessage,
OnProgress: onProgress,
OnComplete: onComplete,
OnError: onError,
}
// Create TSS party IDs
tssPartyIDs := make([]*tss.PartyID, len(participants))
var selfTSSID *tss.PartyID
for i, p := range participants {
partyKey := tss.NewPartyID(
p.PartyID,
fmt.Sprintf("party-%d", p.PartyIndex),
big.NewInt(int64(p.PartyIndex+1)),
)
tssPartyIDs[i] = partyKey
if p.PartyID == partyID {
selfTSSID = partyKey
}
}
if selfTSSID == nil {
return createErrorResult("Self party not found in participants")
}
// Sort party IDs
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
// Create peer context and parameters
// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
// This MUST match keygen exactly! Both use (thresholdT-1)
peerCtx := tss.NewPeerContext(sortedPartyIDs)
tssThreshold := thresholdT - 1
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
// Build party index map
session.PartyIndexMap = make(map[int]*tss.PartyID)
for _, p := range sortedPartyIDs {
for _, orig := range participants {
if orig.PartyID == p.Id {
session.PartyIndexMap[orig.PartyIndex] = p
break
}
}
}
// Create message hash as big.Int
msgHashBig := new(big.Int).SetBytes(messageHash)
// CRITICAL: Build a subset of the keygen save data for the current signing parties
// This is required when signing with a subset of the original keygen participants.
subsetSaveData := keygen.BuildLocalSaveDataSubset(saveData, sortedPartyIDs)
// Create local signing party with the SUBSET save data
session.LocalParty = signing.NewLocalParty(msgHashBig, params, subsetSaveData, session.OutCh, session.EndChSign)
// Store session
sessionMutex.Lock()
activeSessions[sessionID] = session
sessionMutex.Unlock()
// Start goroutines
go session.handleOutgoingMessages()
go session.waitForSigningCompletion()
// Start the local party
go func() {
if err := session.LocalParty.Start(); err != nil {
session.ErrCh <- err
}
}()
return createSuccessResult(map[string]interface{}{
"sessionId": sessionID,
"started": true,
})
}
// handleMessage processes an incoming TSS message
// Arguments: sessionId, fromPartyIndex, isBroadcast, payloadBase64
func handleMessage(this js.Value, args []js.Value) interface{} {
if len(args) < 4 {
return createErrorResult("Missing required arguments")
}
sessionID := args[0].String()
fromPartyIndex := args[1].Int()
isBroadcast := args[2].Bool()
payloadB64 := args[3].String()
sessionMutex.RLock()
session, exists := activeSessions[sessionID]
sessionMutex.RUnlock()
if !exists {
return createErrorResult("Session not found")
}
fromParty, ok := session.PartyIndexMap[fromPartyIndex]
if !ok {
return createErrorResult("Unknown party index")
}
payload, err := base64.StdEncoding.DecodeString(payloadB64)
if err != nil {
return createErrorResult(fmt.Sprintf("Failed to decode payload: %v", err))
}
parsedMsg, err := tss.ParseWireMessage(payload, fromParty, isBroadcast)
if err != nil {
return createErrorResult(fmt.Sprintf("Failed to parse message: %v", err))
}
go func() {
_, err := session.LocalParty.Update(parsedMsg)
if err != nil && !isDuplicateError(err) {
session.ErrCh <- err
}
}()
return createSuccessResult(nil)
}
// stopSession stops an active TSS session
func stopSession(this js.Value, args []js.Value) interface{} {
if len(args) < 1 {
return createErrorResult("Missing session ID")
}
sessionID := args[0].String()
sessionMutex.Lock()
session, exists := activeSessions[sessionID]
if exists {
close(session.Done)
delete(activeSessions, sessionID)
}
sessionMutex.Unlock()
if !exists {
return createErrorResult("Session not found")
}
return createSuccessResult(nil)
}
// extractRoundFromMessageType returns the round number embedded in a tss-lib
// message type name, e.g. "binance.tsslib.ecdsa.keygen.KGRound2Message1" -> 2.
//
// It returns 0 when no round number can be found or parsed. The value is used
// only for progress reporting, so 0 is a harmless fallback.
func extractRoundFromMessageType(msgType string) int {
	groups := roundRegex.FindStringSubmatch(msgType)
	if len(groups) < 2 {
		return 0
	}
	n, err := strconv.Atoi(groups[1])
	if err != nil {
		return 0
	}
	return n
}
// handleOutgoingMessages forwards every message tss-lib emits on s.OutCh to
// the JavaScript onMessage callback and reports progress via onProgress.
//
// It returns when s.OutCh is closed or when s.Done is closed (by stopSession,
// or by cleanup once the protocol has completed or failed). Before returning
// on s.Done, it flushes messages that are already buffered. tss-lib may queue
// the local party's final-round message and signal completion before this
// goroutine has forwarded it, and dropping that message would leave the other
// parties waiting forever.
func (s *TSSSession) handleOutgoingMessages() {
	totalRounds := 4 // GG20 keygen has 4 rounds
	if !s.IsKeygen {
		totalRounds = 9 // GG20 signing has 9 rounds (matching Electron and Android)
	}
	for {
		select {
		case <-s.Done:
			s.flushOutgoing(totalRounds)
			return
		case msg, ok := <-s.OutCh:
			if !ok {
				return
			}
			s.forwardOutgoing(msg, totalRounds)
		}
	}
}

// flushOutgoing forwards whatever is already buffered in s.OutCh without
// blocking, then returns.
func (s *TSSSession) flushOutgoing(totalRounds int) {
	for {
		select {
		case msg, ok := <-s.OutCh:
			if !ok {
				return
			}
			s.forwardOutgoing(msg, totalRounds)
		default:
			return
		}
	}
}

// forwardOutgoing serializes one tss-lib message, hands it to onMessage as a
// JSON string, then reports the message's round via onProgress.
//
// A serialization failure is reported as a session error instead of being
// dropped silently, because a missing message stalls every other party.
func (s *TSSSession) forwardOutgoing(msg tss.Message, totalRounds int) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		s.reportError(fmt.Errorf("failed to serialize outgoing %s message: %w", msg.Type(), err))
		return
	}

	// Broadcast messages carry no recipients, so toParties stays nil (JSON null).
	var toParties []string
	if !msg.IsBroadcast() {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	jsMsg := map[string]interface{}{
		"type":        "outgoing",
		"isBroadcast": msg.IsBroadcast(),
		"toParties":   toParties,
		"payload":     base64.StdEncoding.EncodeToString(msgBytes),
	}
	// Marshal cannot fail for a map of strings, bools and string slices.
	jsMsgJSON, _ := json.Marshal(jsMsg)
	s.OnMessage.Invoke(string(jsMsgJSON))

	currentRound := extractRoundFromMessageType(msg.Type())
	s.OnProgress.Invoke(currentRound, totalRounds)
}

// reportError queues err for the completion goroutine without blocking.
// ErrCh holds a single error. If one is already pending, or nobody is
// listening any more, err is dropped rather than leaking a blocked goroutine.
func (s *TSSSession) reportError(err error) {
	select {
	case s.ErrCh <- err:
	default:
	}
}

// waitForKeygenCompletion waits for the keygen session to end and reports
// the outcome to JavaScript at most once:
//   - s.Done closed: the session was stopped; nothing is reported.
//   - error on s.ErrCh: onError(message), then cleanup.
//   - save data on s.EndChKeygen: onComplete(JSON with the base64 33-byte
//     compressed public key, the share as packaged by encryptShare, and
//     partyIndex), then cleanup.
func (s *TSSSession) waitForKeygenCompletion() {
	select {
	case <-s.Done:
		return
	case err := <-s.ErrCh:
		s.OnError.Invoke(err.Error())
		s.cleanup()
	case saveData := <-s.EndChKeygen:
		// Compressed SEC1 public key: 0x02/0x03 (parity of Y) followed by
		// X left-padded to 32 bytes.
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		// Serialize save data
		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			s.OnError.Invoke(fmt.Sprintf("Failed to serialize save data: %v", err))
			s.cleanup()
			return
		}

		// Encrypt with password
		encryptedShare := encryptShare(saveDataBytes, s.Password)

		result := map[string]interface{}{
			"type":           "result",
			"publicKey":      base64.StdEncoding.EncodeToString(pubKeyBytes),
			"encryptedShare": base64.StdEncoding.EncodeToString(encryptedShare),
			"partyIndex":     s.PartyIndex,
		}
		resultJSON, _ := json.Marshal(result)
		s.OnComplete.Invoke(string(resultJSON))
		s.cleanup()
	}
}

// waitForSigningCompletion waits for the signing session to end and reports
// the outcome to JavaScript at most once:
//   - s.Done closed: the session was stopped; nothing is reported.
//   - error on s.ErrCh: onError(message), then cleanup.
//   - signature on s.EndChSign: onComplete(JSON with the base64 64-byte R||S
//     signature and partyIndex), then cleanup.
func (s *TSSSession) waitForSigningCompletion() {
	select {
	case <-s.Done:
		return
	case err := <-s.ErrCh:
		s.OnError.Invoke(err.Error())
		s.cleanup()
	case sigData := <-s.EndChSign:
		// Serialize as fixed-width R||S, each component left-padded to 32
		// bytes, so the output is always 64 bytes even if a component is
		// shorter. NOTE(review): tss-lib may already pad R and S; for 32-byte
		// components this matches plain concatenation. Writing into a fresh
		// buffer also avoids appending into GetR()'s backing array.
		r, sComp := sigData.GetR(), sigData.GetS()
		if len(r) > 32 || len(sComp) > 32 {
			s.OnError.Invoke(fmt.Sprintf("Unexpected signature component length: r=%d, s=%d", len(r), len(sComp)))
			s.cleanup()
			return
		}
		sigBytes := make([]byte, 64)
		copy(sigBytes[32-len(r):32], r)
		copy(sigBytes[64-len(sComp):], sComp)

		result := map[string]interface{}{
			"type":       "result",
			"signature":  base64.StdEncoding.EncodeToString(sigBytes),
			"partyIndex": s.PartyIndex,
		}
		resultJSON, _ := json.Marshal(result)
		s.OnComplete.Invoke(string(resultJSON))
		s.cleanup()
	}
}

// cleanup removes s from the session registry and closes s.Done. Closing Done
// lets the forwarding goroutine flush and exit instead of blocking forever on
// s.OutCh while keeping the session's state reachable.
//
// It acts only if the registry still maps s.SessionID to s itself. If
// stopSession has already closed Done and removed the entry, or a newer session
// has been registered under the same ID, the entry is left alone and Done is
// not closed a second time.
func (s *TSSSession) cleanup() {
	sessionMutex.Lock()
	defer sessionMutex.Unlock()
	if current, ok := activeSessions[s.SessionID]; ok && current == s {
		delete(activeSessions, s.SessionID)
		close(s.Done)
	}
}
// isDuplicateError reports whether err's message mentions a duplicate or
// already-received message, using case-sensitive substring matching.
// handleMessage ignores such errors instead of failing the session.
// A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "duplicate") || strings.Contains(msg, "already received")
}

// contains reports whether substr occurs within s.
// It is kept as a thin wrapper around strings.Contains for existing callers.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
// encryptShare packages a serialized keygen share for return to JavaScript.
// The output is the 32-byte block produced by hashPassword, followed by data.
//
// SECURITY(review): despite its name, this does NOT encrypt anything. The
// share is returned in the clear, and the first 32 bytes are the password's
// raw bytes (see hashPassword). It must be replaced with real authenticated
// encryption (e.g. AES-256-GCM under a key from a proper password KDF) before
// shares are stored anywhere untrusted. Changing it also requires updating
// whatever consumer decodes this layout.
func encryptShare(data []byte, password string) []byte {
	// hashPassword returns a freshly allocated slice, so appending to it
	// cannot clobber memory shared with anyone else.
	return append(hashPassword(password), data...)
}

// hashPassword returns a 32-byte block containing the first 32 bytes of
// password, zero-padded on the right.
//
// SECURITY(review): this is not a hash. The password bytes are copied
// verbatim, and anything beyond 32 bytes is ignored.
func hashPassword(password string) []byte {
	block := make([]byte, 32)
	copy(block, password)
	return block
}
// createSuccessResult builds the {success: true} envelope returned to
// JavaScript by the entry points. The "data" key is present only when data is
// non-nil.
func createSuccessResult(data interface{}) map[string]interface{} {
	envelope := map[string]interface{}{"success": true}
	if data == nil {
		return envelope
	}
	envelope["data"] = data
	return envelope
}
// createErrorResult builds the {success: false, error: errMsg} envelope
// returned to JavaScript when an entry point rejects a call.
func createErrorResult(errMsg string) map[string]interface{} {
	out := make(map[string]interface{}, 2)
	out["success"] = false
	out["error"] = errMsg
	return out
}