591 lines
16 KiB
Go
591 lines
16 KiB
Go
//go:build js && wasm
|
|
// +build js,wasm
|
|
|
|
// Package main provides TSS (Threshold Signature Scheme) functionality for WebAssembly.
// This module runs in the browser and communicates with JavaScript via callbacks.
|
|
package main
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"regexp"
|
|
"strconv"
|
|
"sync"
|
|
"syscall/js"
|
|
|
|
"github.com/bnb-chain/tss-lib/v2/common"
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
|
|
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
|
|
"github.com/bnb-chain/tss-lib/v2/tss"
|
|
)
|
|
|
|
// roundRegex captures the round number embedded in a tss-lib message type
// name, for example "binance.tsslib.ecdsa.keygen.KGRound1Message" or
// "binance.tsslib.ecdsa.signing.SignRound3Message". It is compiled once at
// package initialisation so per-message parsing stays cheap.
var (
	roundRegex = regexp.MustCompile(`Round(\d+)`)
)
|
|
|
|
// Global state for active sessions
|
|
var (
|
|
activeSessions = make(map[string]*TSSSession)
|
|
sessionMutex sync.RWMutex
|
|
)
|
|
|
|
// TSSSession holds the state for an active TSS session
|
|
type TSSSession struct {
|
|
SessionID string
|
|
PartyID string
|
|
PartyIndex int
|
|
ThresholdT int
|
|
ThresholdN int
|
|
Participants []Participant
|
|
LocalParty tss.Party
|
|
OutCh chan tss.Message
|
|
EndChKeygen chan *keygen.LocalPartySaveData
|
|
EndChSign chan *common.SignatureData
|
|
ErrCh chan error
|
|
PartyIndexMap map[int]*tss.PartyID
|
|
Password string
|
|
IsKeygen bool
|
|
Done chan struct{}
|
|
OnMessage js.Value // JavaScript callback for outgoing messages
|
|
OnProgress js.Value // JavaScript callback for progress updates
|
|
OnComplete js.Value // JavaScript callback for completion
|
|
OnError js.Value // JavaScript callback for errors
|
|
}
|
|
|
|
// Participant is one committee member as supplied by the JavaScript caller
// in the participants JSON array.
type Participant struct {
	// PartyID is compared against the local partyId to find this party.
	PartyID string `json:"partyId"`
	// PartyIndex is the caller-assigned index; PartyIndex+1 becomes the
	// tss-lib party key.
	PartyIndex int `json:"partyIndex"`
}
|
|
|
|
// JSMessage is a JSON envelope for messages exchanged with JavaScript. Every
// field except Type is left out of the encoding when it holds its zero value.
//
// NOTE(review): because of omitempty, PartyIndex, Round and FromPartyIndex
// equal to 0 vanish from the JSON, so party index 0 (which appears to be a
// valid index, since party keys are PartyIndex+1) is indistinguishable from
// "absent". The visible code builds its JSON from maps rather than this
// type; confirm whether it is used elsewhere before relying on it.
type JSMessage struct {
	Type string `json:"type"`

	// Routing of outgoing protocol messages.
	IsBroadcast bool     `json:"isBroadcast,omitempty"`
	ToParties   []string `json:"toParties,omitempty"`
	Payload     string   `json:"payload,omitempty"`

	// Keygen / signing results.
	PublicKey      string `json:"publicKey,omitempty"`
	EncryptedShare string `json:"encryptedShare,omitempty"`
	Signature      string `json:"signature,omitempty"`
	PartyIndex     int    `json:"partyIndex,omitempty"`

	// Progress and diagnostics.
	Round          int    `json:"round,omitempty"`
	TotalRounds    int    `json:"totalRounds,omitempty"`
	FromPartyIndex int    `json:"fromPartyIndex,omitempty"`
	Error          string `json:"error,omitempty"`
}
|
|
|
|
func main() {
|
|
// Register JavaScript functions
|
|
js.Global().Set("tssStartKeygen", js.FuncOf(startKeygen))
|
|
js.Global().Set("tssStartSigning", js.FuncOf(startSigning))
|
|
js.Global().Set("tssHandleMessage", js.FuncOf(handleMessage))
|
|
js.Global().Set("tssStopSession", js.FuncOf(stopSession))
|
|
js.Global().Set("tssGetVersion", js.FuncOf(getVersion))
|
|
|
|
// Keep the program running
|
|
select {}
|
|
}
|
|
|
|
// getVersion returns the TSS WASM version
|
|
func getVersion(this js.Value, args []js.Value) interface{} {
|
|
return "1.0.0"
|
|
}
|
|
|
|
// startKeygen initializes a keygen session
|
|
// Arguments: sessionId, partyId, partyIndex, thresholdT, thresholdN, participantsJSON, password, onMessage, onProgress, onComplete, onError
|
|
func startKeygen(this js.Value, args []js.Value) interface{} {
|
|
if len(args) < 11 {
|
|
return createErrorResult("Missing required arguments")
|
|
}
|
|
|
|
sessionID := args[0].String()
|
|
partyID := args[1].String()
|
|
partyIndex := args[2].Int()
|
|
thresholdT := args[3].Int()
|
|
thresholdN := args[4].Int()
|
|
participantsJSON := args[5].String()
|
|
password := args[6].String()
|
|
onMessage := args[7]
|
|
onProgress := args[8]
|
|
onComplete := args[9]
|
|
onError := args[10]
|
|
|
|
// Parse participants
|
|
var participants []Participant
|
|
if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err))
|
|
}
|
|
|
|
if len(participants) != thresholdN {
|
|
return createErrorResult(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), thresholdN))
|
|
}
|
|
|
|
// Create session
|
|
session := &TSSSession{
|
|
SessionID: sessionID,
|
|
PartyID: partyID,
|
|
PartyIndex: partyIndex,
|
|
ThresholdT: thresholdT,
|
|
ThresholdN: thresholdN,
|
|
Participants: participants,
|
|
OutCh: make(chan tss.Message, thresholdN*10),
|
|
EndChKeygen: make(chan *keygen.LocalPartySaveData, 1),
|
|
ErrCh: make(chan error, 1),
|
|
Password: password,
|
|
IsKeygen: true,
|
|
Done: make(chan struct{}),
|
|
OnMessage: onMessage,
|
|
OnProgress: onProgress,
|
|
OnComplete: onComplete,
|
|
OnError: onError,
|
|
}
|
|
|
|
// Create TSS party IDs
|
|
tssPartyIDs := make([]*tss.PartyID, len(participants))
|
|
var selfTSSID *tss.PartyID
|
|
|
|
for i, p := range participants {
|
|
partyKey := tss.NewPartyID(
|
|
p.PartyID,
|
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
|
big.NewInt(int64(p.PartyIndex+1)),
|
|
)
|
|
tssPartyIDs[i] = partyKey
|
|
if p.PartyID == partyID {
|
|
selfTSSID = partyKey
|
|
}
|
|
}
|
|
|
|
if selfTSSID == nil {
|
|
return createErrorResult("Self party not found in participants")
|
|
}
|
|
|
|
// Sort party IDs
|
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
|
|
|
// Create peer context and parameters
|
|
// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
|
|
// User says "2-of-3" meaning 2 signers needed, so we pass (thresholdT-1) to TSS-lib
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
tssThreshold := thresholdT - 1
|
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
|
|
|
|
// Build party index map
|
|
session.PartyIndexMap = make(map[int]*tss.PartyID)
|
|
for _, p := range sortedPartyIDs {
|
|
for _, orig := range participants {
|
|
if orig.PartyID == p.Id {
|
|
session.PartyIndexMap[orig.PartyIndex] = p
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create local party
|
|
session.LocalParty = keygen.NewLocalParty(params, session.OutCh, session.EndChKeygen)
|
|
|
|
// Store session
|
|
sessionMutex.Lock()
|
|
activeSessions[sessionID] = session
|
|
sessionMutex.Unlock()
|
|
|
|
// Start goroutines
|
|
go session.handleOutgoingMessages()
|
|
go session.waitForKeygenCompletion()
|
|
|
|
// Start the local party
|
|
go func() {
|
|
if err := session.LocalParty.Start(); err != nil {
|
|
session.ErrCh <- err
|
|
}
|
|
}()
|
|
|
|
return createSuccessResult(map[string]interface{}{
|
|
"sessionId": sessionID,
|
|
"started": true,
|
|
})
|
|
}
|
|
|
|
// startSigning initializes a signing session
|
|
// Arguments: sessionId, partyId, partyIndex, thresholdT, participantsJSON, saveDataJSON, messageHash, onMessage, onProgress, onComplete, onError
|
|
func startSigning(this js.Value, args []js.Value) interface{} {
|
|
if len(args) < 11 {
|
|
return createErrorResult("Missing required arguments")
|
|
}
|
|
|
|
sessionID := args[0].String()
|
|
partyID := args[1].String()
|
|
partyIndex := args[2].Int()
|
|
thresholdT := args[3].Int()
|
|
participantsJSON := args[4].String()
|
|
saveDataJSON := args[5].String()
|
|
messageHashB64 := args[6].String()
|
|
onMessage := args[7]
|
|
onProgress := args[8]
|
|
onComplete := args[9]
|
|
onError := args[10]
|
|
|
|
// Parse participants
|
|
var participants []Participant
|
|
if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to parse participants: %v", err))
|
|
}
|
|
|
|
// Parse save data (keygen result)
|
|
var saveData keygen.LocalPartySaveData
|
|
if err := json.Unmarshal([]byte(saveDataJSON), &saveData); err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to parse save data: %v", err))
|
|
}
|
|
|
|
// Decode message hash
|
|
messageHash, err := base64.StdEncoding.DecodeString(messageHashB64)
|
|
if err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to decode message hash: %v", err))
|
|
}
|
|
|
|
thresholdN := len(participants)
|
|
|
|
// Create session
|
|
session := &TSSSession{
|
|
SessionID: sessionID,
|
|
PartyID: partyID,
|
|
PartyIndex: partyIndex,
|
|
ThresholdT: thresholdT,
|
|
ThresholdN: thresholdN,
|
|
Participants: participants,
|
|
OutCh: make(chan tss.Message, thresholdN*10),
|
|
EndChSign: make(chan *common.SignatureData, 1),
|
|
ErrCh: make(chan error, 1),
|
|
IsKeygen: false,
|
|
Done: make(chan struct{}),
|
|
OnMessage: onMessage,
|
|
OnProgress: onProgress,
|
|
OnComplete: onComplete,
|
|
OnError: onError,
|
|
}
|
|
|
|
// Create TSS party IDs
|
|
tssPartyIDs := make([]*tss.PartyID, len(participants))
|
|
var selfTSSID *tss.PartyID
|
|
|
|
for i, p := range participants {
|
|
partyKey := tss.NewPartyID(
|
|
p.PartyID,
|
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
|
big.NewInt(int64(p.PartyIndex+1)),
|
|
)
|
|
tssPartyIDs[i] = partyKey
|
|
if p.PartyID == partyID {
|
|
selfTSSID = partyKey
|
|
}
|
|
}
|
|
|
|
if selfTSSID == nil {
|
|
return createErrorResult("Self party not found in participants")
|
|
}
|
|
|
|
// Sort party IDs
|
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
|
|
|
// Create peer context and parameters
|
|
// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
|
|
// This MUST match keygen exactly! Both use (thresholdT-1)
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
tssThreshold := thresholdT - 1
|
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
|
|
|
|
// Build party index map
|
|
session.PartyIndexMap = make(map[int]*tss.PartyID)
|
|
for _, p := range sortedPartyIDs {
|
|
for _, orig := range participants {
|
|
if orig.PartyID == p.Id {
|
|
session.PartyIndexMap[orig.PartyIndex] = p
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create message hash as big.Int
|
|
msgHashBig := new(big.Int).SetBytes(messageHash)
|
|
|
|
// CRITICAL: Build a subset of the keygen save data for the current signing parties
|
|
// This is required when signing with a subset of the original keygen participants.
|
|
subsetSaveData := keygen.BuildLocalSaveDataSubset(saveData, sortedPartyIDs)
|
|
|
|
// Create local signing party with the SUBSET save data
|
|
session.LocalParty = signing.NewLocalParty(msgHashBig, params, subsetSaveData, session.OutCh, session.EndChSign)
|
|
|
|
// Store session
|
|
sessionMutex.Lock()
|
|
activeSessions[sessionID] = session
|
|
sessionMutex.Unlock()
|
|
|
|
// Start goroutines
|
|
go session.handleOutgoingMessages()
|
|
go session.waitForSigningCompletion()
|
|
|
|
// Start the local party
|
|
go func() {
|
|
if err := session.LocalParty.Start(); err != nil {
|
|
session.ErrCh <- err
|
|
}
|
|
}()
|
|
|
|
return createSuccessResult(map[string]interface{}{
|
|
"sessionId": sessionID,
|
|
"started": true,
|
|
})
|
|
}
|
|
|
|
// handleMessage processes an incoming TSS message
|
|
// Arguments: sessionId, fromPartyIndex, isBroadcast, payloadBase64
|
|
func handleMessage(this js.Value, args []js.Value) interface{} {
|
|
if len(args) < 4 {
|
|
return createErrorResult("Missing required arguments")
|
|
}
|
|
|
|
sessionID := args[0].String()
|
|
fromPartyIndex := args[1].Int()
|
|
isBroadcast := args[2].Bool()
|
|
payloadB64 := args[3].String()
|
|
|
|
sessionMutex.RLock()
|
|
session, exists := activeSessions[sessionID]
|
|
sessionMutex.RUnlock()
|
|
|
|
if !exists {
|
|
return createErrorResult("Session not found")
|
|
}
|
|
|
|
fromParty, ok := session.PartyIndexMap[fromPartyIndex]
|
|
if !ok {
|
|
return createErrorResult("Unknown party index")
|
|
}
|
|
|
|
payload, err := base64.StdEncoding.DecodeString(payloadB64)
|
|
if err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to decode payload: %v", err))
|
|
}
|
|
|
|
parsedMsg, err := tss.ParseWireMessage(payload, fromParty, isBroadcast)
|
|
if err != nil {
|
|
return createErrorResult(fmt.Sprintf("Failed to parse message: %v", err))
|
|
}
|
|
|
|
go func() {
|
|
_, err := session.LocalParty.Update(parsedMsg)
|
|
if err != nil && !isDuplicateError(err) {
|
|
session.ErrCh <- err
|
|
}
|
|
}()
|
|
|
|
return createSuccessResult(nil)
|
|
}
|
|
|
|
// stopSession stops an active TSS session
|
|
func stopSession(this js.Value, args []js.Value) interface{} {
|
|
if len(args) < 1 {
|
|
return createErrorResult("Missing session ID")
|
|
}
|
|
|
|
sessionID := args[0].String()
|
|
|
|
sessionMutex.Lock()
|
|
session, exists := activeSessions[sessionID]
|
|
if exists {
|
|
close(session.Done)
|
|
delete(activeSessions, sessionID)
|
|
}
|
|
sessionMutex.Unlock()
|
|
|
|
if !exists {
|
|
return createErrorResult("Session not found")
|
|
}
|
|
|
|
return createSuccessResult(nil)
|
|
}
|
|
|
|
// extractRoundFromMessageType parses the round number from a tss-lib message type string.
|
|
// Returns 0 if parsing fails (safe fallback).
|
|
// Example: "binance.tsslib.ecdsa.keygen.KGRound2Message1" -> 2
|
|
func extractRoundFromMessageType(msgType string) int {
|
|
matches := roundRegex.FindStringSubmatch(msgType)
|
|
if len(matches) >= 2 {
|
|
if round, err := strconv.Atoi(matches[1]); err == nil {
|
|
return round
|
|
}
|
|
}
|
|
return 0 // Safe fallback - doesn't affect protocol, just shows 0 in UI
|
|
}
|
|
|
|
// handleOutgoingMessages processes messages from the TSS protocol
|
|
func (s *TSSSession) handleOutgoingMessages() {
|
|
totalRounds := 4 // GG20 keygen has 4 rounds
|
|
if !s.IsKeygen {
|
|
totalRounds = 9 // GG20 signing has 9 rounds (matching Electron and Android)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-s.Done:
|
|
return
|
|
case msg, ok := <-s.OutCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
msgBytes, _, err := msg.WireBytes()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
var toParties []string
|
|
if !msg.IsBroadcast() {
|
|
for _, to := range msg.GetTo() {
|
|
toParties = append(toParties, to.Id)
|
|
}
|
|
}
|
|
|
|
// Call JavaScript callback
|
|
jsMsg := map[string]interface{}{
|
|
"type": "outgoing",
|
|
"isBroadcast": msg.IsBroadcast(),
|
|
"toParties": toParties,
|
|
"payload": base64.StdEncoding.EncodeToString(msgBytes),
|
|
}
|
|
|
|
jsMsgJSON, _ := json.Marshal(jsMsg)
|
|
s.OnMessage.Invoke(string(jsMsgJSON))
|
|
|
|
// Extract current round from message type and send progress update
|
|
currentRound := extractRoundFromMessageType(msg.Type())
|
|
s.OnProgress.Invoke(currentRound, totalRounds)
|
|
}
|
|
}
|
|
}
|
|
|
|
// waitForKeygenCompletion waits for keygen to complete
|
|
func (s *TSSSession) waitForKeygenCompletion() {
|
|
select {
|
|
case <-s.Done:
|
|
return
|
|
case err := <-s.ErrCh:
|
|
s.OnError.Invoke(err.Error())
|
|
s.cleanup()
|
|
case saveData := <-s.EndChKeygen:
|
|
// Get public key (33 bytes compressed)
|
|
pubKey := saveData.ECDSAPub.ToECDSAPubKey()
|
|
pubKeyBytes := make([]byte, 33)
|
|
pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
|
|
xBytes := pubKey.X.Bytes()
|
|
copy(pubKeyBytes[33-len(xBytes):], xBytes)
|
|
|
|
// Serialize save data
|
|
saveDataBytes, err := json.Marshal(saveData)
|
|
if err != nil {
|
|
s.OnError.Invoke(fmt.Sprintf("Failed to serialize save data: %v", err))
|
|
s.cleanup()
|
|
return
|
|
}
|
|
|
|
// Encrypt with password
|
|
encryptedShare := encryptShare(saveDataBytes, s.Password)
|
|
|
|
// Call completion callback
|
|
result := map[string]interface{}{
|
|
"type": "result",
|
|
"publicKey": base64.StdEncoding.EncodeToString(pubKeyBytes),
|
|
"encryptedShare": base64.StdEncoding.EncodeToString(encryptedShare),
|
|
"partyIndex": s.PartyIndex,
|
|
}
|
|
|
|
resultJSON, _ := json.Marshal(result)
|
|
s.OnComplete.Invoke(string(resultJSON))
|
|
s.cleanup()
|
|
}
|
|
}
|
|
|
|
// waitForSigningCompletion waits for signing to complete
|
|
func (s *TSSSession) waitForSigningCompletion() {
|
|
select {
|
|
case <-s.Done:
|
|
return
|
|
case err := <-s.ErrCh:
|
|
s.OnError.Invoke(err.Error())
|
|
s.cleanup()
|
|
case sigData := <-s.EndChSign:
|
|
// Serialize signature (R, S format)
|
|
sigBytes := append(sigData.GetR(), sigData.GetS()...)
|
|
|
|
result := map[string]interface{}{
|
|
"type": "result",
|
|
"signature": base64.StdEncoding.EncodeToString(sigBytes),
|
|
"partyIndex": s.PartyIndex,
|
|
}
|
|
|
|
resultJSON, _ := json.Marshal(result)
|
|
s.OnComplete.Invoke(string(resultJSON))
|
|
s.cleanup()
|
|
}
|
|
}
|
|
|
|
func (s *TSSSession) cleanup() {
|
|
sessionMutex.Lock()
|
|
delete(activeSessions, s.SessionID)
|
|
sessionMutex.Unlock()
|
|
}
|
|
|
|
func isDuplicateError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
errStr := err.Error()
|
|
return contains(errStr, "duplicate") || contains(errStr, "already received")
|
|
}
|
|
|
|
// contains reports whether substr occurs anywhere in s. An empty substr
// always matches. It is a hand-rolled equivalent of strings.Contains.
func contains(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
|
|
|
|
func encryptShare(data []byte, password string) []byte {
|
|
// TODO: Use proper AES-256-GCM encryption
|
|
// For now, just prepend a marker and the password hash
|
|
result := make([]byte, len(data)+32)
|
|
copy(result[:32], hashPassword(password))
|
|
copy(result[32:], data)
|
|
return result
|
|
}
|
|
|
|
// hashPassword returns a 32-byte value derived from password.
//
// SECURITY(review): this is not a hash. It is the first 32 bytes of the
// password (bytes, not runes), right-padded with zeros. Longer passwords are
// truncated. It is fully reversible and must not be used as a key-derivation
// function.
func hashPassword(password string) []byte {
	out := make([]byte, 32)
	copy(out, password)
	return out
}
|
|
|
|
// createSuccessResult builds the object returned to JavaScript on success:
// {"success": true}, plus a "data" entry when data is non-nil.
func createSuccessResult(data interface{}) map[string]interface{} {
	if data == nil {
		return map[string]interface{}{"success": true}
	}
	return map[string]interface{}{
		"success": true,
		"data":    data,
	}
}
|
|
|
|
// createErrorResult builds the object returned to JavaScript on failure:
// {"success": false, "error": errMsg}.
func createErrorResult(errMsg string) map[string]interface{} {
	result := make(map[string]interface{}, 2)
	result["success"] = false
	result["error"] = errMsg
	return result
}
|