rwadurian/backend/mpc-system/services/service-party-app/tss-party/main.go

723 lines
18 KiB
Go

// Package main provides the TSS party subprocess for Electron app
//
// This program handles TSS (Threshold Signature Scheme) protocol execution
// It communicates with the parent Electron process via stdin/stdout using JSON messages
package main
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"math/big"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/bnb-chain/tss-lib/v2/common"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
"github.com/bnb-chain/tss-lib/v2/tss"
)
// Message is the JSON envelope exchanged with the parent process, one object
// per line on stdin/stdout. Type selects which of the optional fields are
// populated; the values used in this file are "incoming", "outgoing",
// "progress", "error", "result" and "sign_result".
//
// NOTE(review): the integer fields are tagged omitempty, so a zero value
// (for example PartyIndex 0 or RecoveryID 0) is left out of the encoded JSON.
// Presumably the consumer treats a missing field as zero — confirm.
type Message struct {
Type string `json:"type"`
// Routing and body of a protocol message ("incoming" / "outgoing").
IsBroadcast bool `json:"isBroadcast,omitempty"`
ToParties []string `json:"toParties,omitempty"`
Payload string `json:"payload,omitempty"` // base64 encoded
// Keygen result fields.
PublicKey string `json:"publicKey,omitempty"` // base64 encoded
EncryptedShare string `json:"encryptedShare,omitempty"` // base64 encoded
PartyIndex int `json:"partyIndex,omitempty"`
// Progress reporting fields.
Round int `json:"round,omitempty"`
TotalRounds int `json:"totalRounds,omitempty"`
// Index of the sending party on "incoming" messages.
FromPartyIndex int `json:"fromPartyIndex,omitempty"`
// Human-readable failure description on "error" messages.
Error string `json:"error,omitempty"`
// Signing result fields
Signature string `json:"signature,omitempty"` // base64 encoded (R || S, 64 bytes)
R string `json:"r,omitempty"` // hex encoded
S string `json:"s,omitempty"` // hex encoded
RecoveryID int `json:"recoveryId,omitempty"` // for ecrecover
}
// Participant describes one party of a session as supplied in the
// -participants JSON array.
type Participant struct {
// PartyID is the identifier matched against the -party-id flag to find
// the local party.
PartyID string `json:"partyId"`
// PartyIndex is the party's index; PartyIndex+1 is used as the TSS party
// key and incoming messages are attributed by this index.
PartyIndex int `json:"partyIndex"`
}
// main dispatches to the sub-command named by the first command-line
// argument. A missing or unrecognized sub-command is reported on stderr and
// the process exits with status 1.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: tss-party <command> [options]")
		os.Exit(1)
	}

	subcommand := os.Args[1]
	if subcommand == "keygen" {
		runKeygen()
		return
	}
	if subcommand == "sign" {
		runSign()
		return
	}

	fmt.Fprintf(os.Stderr, "Unknown command: %s\n", subcommand)
	os.Exit(1)
}
// runKeygen implements the "keygen" sub-command: it parses and validates the
// flags that follow the sub-command name, runs the key generation protocol
// and prints the result as a JSON message on stdout. Every failure is
// reported to the parent as a JSON "error" message followed by exit status 1.
//
// NOTE(review): the password is received as a command-line flag, which other
// local processes can typically observe — consider passing it over stdin.
func runKeygen() {
	// ContinueOnError makes Parse return the error instead of exiting the
	// process itself, so the failure below actually reaches the parent as a
	// JSON message (with ExitOnError that branch could never run).
	fs := flag.NewFlagSet("keygen", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Threshold N")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array")
	password := fs.String("password", "", "Encryption password for share")
	if err := fs.Parse(os.Args[2:]); err != nil {
		// Parse returns the ErrHelp sentinel itself for -h/-help; keep the
		// conventional "usage printed, exit 0" behavior for that case.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields
	if *sessionID == "" || *partyID == "" || *thresholdT == 0 || *thresholdN == 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}

	// Parse participants; keygen requires exactly N of them.
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	if len(participants) != *thresholdN {
		sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), *thresholdN))
		os.Exit(1)
	}

	// Run keygen protocol
	result, err := executeKeygen(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		*password,
	)
	if err != nil {
		sendError(fmt.Sprintf("Keygen failed: %v", err))
		os.Exit(1)
	}

	// Send result
	sendResult(result.PublicKey, result.EncryptedShare, *partyIndex)
}
// runSign implements the "sign" sub-command: it parses and validates the
// flags that follow the sub-command name, decodes the base64 message hash and
// share data, runs the signing protocol and prints the signature as a JSON
// message on stdout. Every failure is reported to the parent as a JSON
// "error" message followed by exit status 1.
//
// NOTE(review): the decrypted share is received as a command-line flag, which
// other local processes can typically observe — consider passing it over stdin.
func runSign() {
	// ContinueOnError makes Parse return the error instead of exiting the
	// process itself, so the failure below actually reaches the parent as a
	// JSON message (with ExitOnError that branch could never run).
	fs := flag.NewFlagSet("sign", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Original Threshold N from keygen")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array (current signers)")
	messageHashB64 := fs.String("message-hash", "", "Message hash to sign (base64 encoded)")
	shareDataB64 := fs.String("share-data", "", "Decrypted share data from keygen (base64 encoded)")
	if err := fs.Parse(os.Args[2:]); err != nil {
		// Parse returns the ErrHelp sentinel itself for -h/-help; keep the
		// conventional "usage printed, exit 0" behavior for that case.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields
	if *sessionID == "" || *partyID == "" || *thresholdT == 0 || *thresholdN == 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}
	if *messageHashB64 == "" {
		sendError("Missing message hash")
		os.Exit(1)
	}
	if *shareDataB64 == "" {
		sendError("Missing share data")
		os.Exit(1)
	}

	// Parse participants (current signers, may be subset of original keygen participants)
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	if len(participants) < *thresholdT {
		sendError(fmt.Sprintf("Not enough signers: got %d, need at least %d", len(participants), *thresholdT))
		os.Exit(1)
	}

	// Decode message hash
	messageHash, err := base64.StdEncoding.DecodeString(*messageHashB64)
	if err != nil {
		sendError(fmt.Sprintf("Failed to decode message hash: %v", err))
		os.Exit(1)
	}

	// Decode share data
	shareData, err := base64.StdEncoding.DecodeString(*shareDataB64)
	if err != nil {
		sendError(fmt.Sprintf("Failed to decode share data: %v", err))
		os.Exit(1)
	}

	// Run sign protocol
	result, err := executeSign(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		messageHash,
		shareData,
	)
	if err != nil {
		sendError(fmt.Sprintf("Signing failed: %v", err))
		os.Exit(1)
	}

	// Send result
	sendSignResult(result.Signature, result.R, result.S, result.RecoveryID)
}
// keygenResult is what executeKeygen hands back on success: the 33-byte
// compressed public key and the output of encryptShare for the local
// party's save data.
type keygenResult struct {
	PublicKey, EncryptedShare []byte
}
// executeKeygen runs the distributed key generation protocol for the local
// party and blocks until it finishes, fails, is interrupted by SIGINT/SIGTERM
// or the 10 minute deadline expires.
//
// Protocol messages produced by the local party are written to stdout via
// handleOutgoingMessage; messages from the other parties are read as JSON
// lines from stdin and applied via handleIncomingMessage.
//
// Parameters:
//   - sessionID, partyIndex: accepted for interface symmetry; not used here.
//   - partyID: identifies the local party inside participants.
//   - thresholdT: threshold passed to the TSS parameters.
//   - thresholdN: used to size the outgoing message buffer.
//   - participants: all parties taking part in the key generation.
//   - password: forwarded to encryptShare.
//
// On success it returns the compressed public key and the protected share.
// It returns an error if the local party is not listed in participants, the
// protocol reports an error, stdin cannot be read, or the context ends first.
func executeKeygen(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	password string,
) (*keygenResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Handle signals for graceful shutdown. The watcher also exits when the
	// context ends so it does not outlive this call.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Create TSS party IDs
	tssPartyIDs := make([]*tss.PartyID, len(participants))
	var selfTSSID *tss.PartyID
	for i, p := range participants {
		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyKey
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Create peer context and parameters
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT)

	// Create channels
	outCh := make(chan tss.Message, thresholdN*10)
	endCh := make(chan *keygen.LocalPartySaveData, 1)
	errCh := make(chan error, 1)

	// reportErr records a fatal error without blocking: only the first error
	// is ever consumed, so later ones are dropped rather than stranding the
	// goroutine that produced them.
	reportErr := func(err error) {
		select {
		case errCh <- err:
		default:
		}
	}

	// Create local party
	localParty := keygen.NewLocalParty(params, outCh, endCh)

	// Build party index map for incoming messages: participant index -> TSS
	// party ID (first participant with a matching PartyID wins).
	partyIndexMap := make(map[int]*tss.PartyID)
	for _, p := range sortedPartyIDs {
		for _, orig := range participants {
			if orig.PartyID == p.Id {
				partyIndexMap[orig.PartyIndex] = p
				break
			}
		}
	}

	// Start the local party
	go func() {
		if err := localParty.Start(); err != nil {
			reportErr(err)
		}
	}()

	// Handle outgoing messages
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Handle incoming messages from stdin. This goroutine cannot be waited
	// for: it may stay blocked in a read until the parent closes the pipe.
	go func() {
		scanner := bufio.NewScanner(os.Stdin)
		// Increase buffer for large messages (TSS messages can be ~200KB)
		buf := make([]byte, 1024*1024)
		scanner.Buffer(buf, len(buf))
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				// Lines that are not valid JSON are skipped.
				continue
			}
			if msg.Type == "incoming" {
				handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
			}
		}
		// A read failure (for example a line longer than the buffer) ends
		// the loop for good; surface it instead of waiting for the timeout.
		if err := scanner.Err(); err != nil {
			reportErr(fmt.Errorf("reading stdin: %w", err))
		}
	}()

	// Track progress
	const totalRounds = 4 // GG20 keygen has 4 rounds

	// Wait for completion
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case saveData := <-endCh:
		// Keygen completed successfully. Stop the writer goroutine, wait for
		// it, then flush whatever is still queued so no protocol message is
		// lost before the caller prints the result and exits.
		cancel()
		outWg.Wait()
		for drained := false; !drained; {
			select {
			case msg := <-outCh:
				handleOutgoingMessage(msg)
			default:
				drained = true
			}
		}
		sendProgress(totalRounds, totalRounds)

		// Get public key in compressed form: a 0x02/0x03 prefix chosen by
		// the parity of Y, followed by X left-padded to 32 bytes.
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		// Serialize save data
		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize save data: %w", err)
		}

		// Protect with password (placeholder - should use AES-GCM in production)
		encryptedShare := encryptShare(saveDataBytes, password)
		return &keygenResult{
			PublicKey:      pubKeyBytes,
			EncryptedShare: encryptedShare,
		}, nil
	}
}
// handleOutgoingMessage serializes a protocol message produced by the local
// party and prints it to stdout as a JSON "outgoing" message. The wire bytes
// are base64 encoded; for non-broadcast messages the recipient party IDs are
// listed in toParties.
//
// A message that cannot be serialized is dropped, with a diagnostic on
// stderr so the failure is visible; stdout stays reserved for JSON messages.
func handleOutgoingMessage(msg tss.Message) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: %v\n", err)
		return
	}

	var toParties []string
	if !msg.IsBroadcast() {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	outMsg := Message{
		Type:        "outgoing",
		IsBroadcast: msg.IsBroadcast(),
		ToParties:   toParties,
		Payload:     base64.StdEncoding.EncodeToString(msgBytes),
	}
	data, err := json.Marshal(outMsg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: %v\n", err)
		return
	}
	fmt.Println(string(data))
}
// handleIncomingMessage decodes one "incoming" message from the parent and
// applies it to the local party.
//
// The sender is looked up in partyIndexMap by msg.FromPartyIndex and the
// base64 payload is parsed as a TSS wire message. Messages from an unknown
// sender, with an undecodable payload, or that fail to parse are ignored.
//
// The update itself runs in its own goroutine. Errors recognized by
// isDuplicateError are ignored; any other error is offered to errCh.
func handleIncomingMessage(
	msg Message,
	localParty tss.Party,
	partyIndexMap map[int]*tss.PartyID,
	errCh chan error,
) {
	fromParty, ok := partyIndexMap[msg.FromPartyIndex]
	if !ok {
		return
	}

	payload, err := base64.StdEncoding.DecodeString(msg.Payload)
	if err != nil {
		return
	}

	parsedMsg, err := tss.ParseWireMessage(payload, fromParty, msg.IsBroadcast)
	if err != nil {
		return
	}

	go func() {
		_, err := localParty.Update(parsedMsg)
		if err != nil && !isDuplicateError(err) {
			// Non-blocking send: only the first fatal error is consumed, so
			// a full channel means one is already pending and this goroutine
			// must not be left blocked forever.
			select {
			case errCh <- err:
			default:
			}
		}
	}()
}
// isDuplicateError reports whether err's text contains "duplicate" or
// "already received". A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err != nil {
		text := err.Error()
		if contains(text, "duplicate") {
			return true
		}
		return contains(text, "already received")
	}
	return false
}
// contains reports whether substr occurs within s. An empty substr is
// contained in every string, including the empty one.
func contains(s, substr string) bool {
	// A needle longer than the haystack can never match.
	if len(s) < len(substr) {
		return false
	}
	if s == substr {
		return true
	}
	if len(s) == 0 {
		return false
	}
	return containsImpl(s, substr)
}
// containsImpl scans s from left to right and reports whether any window of
// len(substr) bytes equals substr. It returns false when substr is longer
// than s.
func containsImpl(s, substr string) bool {
	width := len(substr)
	for start, end := 0, width; end <= len(s); start, end = start+1, end+1 {
		if s[start:end] == substr {
			return true
		}
	}
	return false
}
// encryptShare returns a 32-byte header produced by hashPassword followed by
// data, unchanged.
//
// WARNING: despite its name this performs no encryption. It is a placeholder:
// the share is stored in the clear behind the header.
// TODO: Use proper AES-256-GCM encryption.
func encryptShare(data []byte, password string) []byte {
	const headerLen = 32

	// Allocate the final size up front so the append below never reallocates.
	out := make([]byte, headerLen, headerLen+len(data))
	copy(out, hashPassword(password))
	return append(out, data...)
}
// hashPassword returns a 32-byte slice holding the first 32 bytes of
// password, zero-padded on the right when the password is shorter.
//
// WARNING: this is not a hash; the password bytes are copied verbatim.
// Should use PBKDF2 or Argon2 in production.
func hashPassword(password string) []byte {
	digest := make([]byte, 32)
	// copy stops at the shorter of the two, which gives both the truncation
	// and the zero padding.
	copy(digest, password)
	return digest
}
// sendProgress prints a JSON "progress" message carrying the current round
// and the total number of rounds as one line on stdout.
func sendProgress(round, totalRounds int) {
	// The marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice.
	encoded, _ := json.Marshal(Message{
		Type:        "progress",
		Round:       round,
		TotalRounds: totalRounds,
	})
	fmt.Fprintf(os.Stdout, "%s\n", encoded)
}
// sendError prints a JSON "error" message carrying errMsg as one line on
// stdout. It only reports; terminating the process is left to the caller.
func sendError(errMsg string) {
	report := Message{Type: "error", Error: errMsg}
	// The marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice.
	encoded, _ := json.Marshal(report)
	os.Stdout.Write(append(encoded, '\n'))
}
// sendResult prints the keygen outcome as a JSON "result" message on stdout.
// publicKey and encryptedShare are base64 encoded.
//
// The message is built from a dedicated payload type rather than the shared
// envelope so that partyIndex is always present: with omitempty the field
// vanished for party 0, making that result indistinguishable from one that
// carries no index at all. Field names and order match the shared envelope.
func sendResult(publicKey, encryptedShare []byte, partyIndex int) {
	payload := struct {
		Type           string `json:"type"`
		PublicKey      string `json:"publicKey,omitempty"`
		EncryptedShare string `json:"encryptedShare,omitempty"`
		PartyIndex     int    `json:"partyIndex"`
	}{
		Type:           "result",
		PublicKey:      base64.StdEncoding.EncodeToString(publicKey),
		EncryptedShare: base64.StdEncoding.EncodeToString(encryptedShare),
		PartyIndex:     partyIndex,
	}
	// The marshal error is ignored: the payload holds only strings and an int.
	data, _ := json.Marshal(payload)
	fmt.Println(string(data))
}
// =============================================================================
// Signing Implementation
// =============================================================================
// signResult is what executeSign hands back on success.
type signResult struct {
	// Signature is the 64-byte concatenation of R and S, each left-padded
	// to 32 bytes.
	Signature []byte
	// R and S are the signature components as integers.
	R, S *big.Int
	// RecoveryID is the first byte of the signature recovery data.
	RecoveryID int
}
// executeSign runs the threshold signing protocol for the local party and
// blocks until it finishes, fails, is interrupted by SIGINT/SIGTERM or the
// 10 minute deadline expires.
//
// Protocol messages produced by the local party are written to stdout via
// handleOutgoingMessage; messages from the other signers are read as JSON
// lines from stdin and applied via handleSignIncomingMessage.
//
// Parameters:
//   - sessionID, partyIndex: accepted for interface symmetry; not used here.
//   - partyID: identifies the local party inside participants.
//   - thresholdT, thresholdN: threshold and party count passed to the TSS
//     parameters; thresholdN is the original N from keygen.
//   - participants: the current signers (may be a subset of the keygen set).
//   - messageHash: the hash to sign, interpreted as a big-endian integer.
//   - shareData: JSON-encoded keygen save data of the local party.
//
// On success it returns the 64-byte R||S signature, R and S as integers and
// the recovery id. It returns an error if shareData cannot be decoded, the
// local party is not listed in participants, the protocol reports an error,
// stdin cannot be read, the produced signature is malformed, or the context
// ends first.
func executeSign(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	messageHash []byte,
	shareData []byte,
) (*signResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Handle signals for graceful shutdown. The watcher also exits when the
	// context ends so it does not outlive this call.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Deserialize keygen save data
	var saveData keygen.LocalPartySaveData
	if err := json.Unmarshal(shareData, &saveData); err != nil {
		return nil, fmt.Errorf("failed to deserialize share data: %w", err)
	}

	// Create TSS party IDs for current signers (may be subset of original keygen participants)
	tssPartyIDs := make([]*tss.PartyID, len(participants))
	var selfTSSID *tss.PartyID
	for i, p := range participants {
		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyKey
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Build both lookups used for incoming messages in one pass:
	//   keygen index -> position in sortedPartyIDs
	//   keygen index -> TSS party ID
	// The first participant with a matching PartyID wins.
	keygenIndexToSortedIndex := make(map[int]int, len(sortedPartyIDs))
	partyIndexMap := make(map[int]*tss.PartyID, len(sortedPartyIDs))
	for sortedIdx, pid := range sortedPartyIDs {
		for _, p := range participants {
			if p.PartyID == pid.Id {
				keygenIndexToSortedIndex[p.PartyIndex] = sortedIdx
				partyIndexMap[p.PartyIndex] = pid
				break
			}
		}
	}

	// Create peer context and parameters
	// IMPORTANT: Use original thresholdN from keygen, not len(participants)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, thresholdN, thresholdT)

	// Convert message hash to big.Int
	msgHashBigInt := new(big.Int).SetBytes(messageHash)

	// Create channels
	signerCount := len(participants)
	outCh := make(chan tss.Message, signerCount*10)
	endCh := make(chan *common.SignatureData, 1)
	errCh := make(chan error, 1)

	// reportErr records a fatal error without blocking: only the first error
	// is ever consumed, so later ones are dropped rather than stranding the
	// goroutine that produced them.
	reportErr := func(err error) {
		select {
		case errCh <- err:
		default:
		}
	}

	// Create local signing party
	localParty := signing.NewLocalParty(msgHashBigInt, params, saveData, outCh, endCh)

	// Start the local party
	go func() {
		if err := localParty.Start(); err != nil {
			reportErr(err)
		}
	}()

	// Handle outgoing messages
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Handle incoming messages from stdin. This goroutine cannot be waited
	// for: it may stay blocked in a read until the parent closes the pipe.
	go func() {
		scanner := bufio.NewScanner(os.Stdin)
		// Increase buffer for large messages
		buf := make([]byte, 1024*1024)
		scanner.Buffer(buf, len(buf))
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				// Lines that are not valid JSON are skipped.
				continue
			}
			if msg.Type == "incoming" {
				handleSignIncomingMessage(msg, localParty, partyIndexMap, keygenIndexToSortedIndex, sortedPartyIDs, errCh)
			}
		}
		// A read failure (for example a line longer than the buffer) ends
		// the loop for good; surface it instead of waiting for the timeout.
		if err := scanner.Err(); err != nil {
			reportErr(fmt.Errorf("reading stdin: %w", err))
		}
	}()

	// Track progress
	const totalRounds = 9 // GG20 signing has 9 rounds

	// Wait for completion
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case signData := <-endCh:
		// Signing completed. Stop the writer goroutine, wait for it, then
		// flush whatever is still queued so no protocol message is lost
		// before the caller prints the result and exits.
		cancel()
		outWg.Wait()
		for drained := false; !drained; {
			select {
			case msg := <-outCh:
				handleOutgoingMessage(msg)
			default:
				drained = true
			}
		}

		// Reject malformed output up front: the slicing and indexing below
		// would otherwise panic.
		rBytes, sBytes := signData.R, signData.S
		if len(rBytes) > 32 || len(sBytes) > 32 {
			return nil, fmt.Errorf("unexpected signature component length: r=%d s=%d bytes", len(rBytes), len(sBytes))
		}
		if len(signData.SignatureRecovery) == 0 {
			return nil, fmt.Errorf("signature is missing the recovery id")
		}
		sendProgress(totalRounds, totalRounds)

		// Build signature (R || S), each component left-padded to 32 bytes
		signature := make([]byte, 64)
		copy(signature[32-len(rBytes):32], rBytes)
		copy(signature[64-len(sBytes):64], sBytes)

		return &signResult{
			Signature:  signature,
			R:          new(big.Int).SetBytes(rBytes),
			S:          new(big.Int).SetBytes(sBytes),
			RecoveryID: int(signData.SignatureRecovery[0]),
		}, nil
	}
}
// handleSignIncomingMessage decodes one "incoming" message from the parent
// during signing and applies it to the local party.
//
// The sender is resolved by mapping msg.FromPartyIndex through
// keygenIndexToSortedIndex to a position in sortedPartyIDs, and the base64
// payload is parsed as a TSS wire message. Messages from an unknown sender,
// with an out-of-range position, an undecodable payload, or that fail to
// parse are ignored. partyIndexMap is accepted but not consulted.
//
// The update itself runs in its own goroutine. Errors recognized by
// isDuplicateError are ignored; any other error is offered to errCh.
func handleSignIncomingMessage(
	msg Message,
	localParty tss.Party,
	partyIndexMap map[int]*tss.PartyID,
	keygenIndexToSortedIndex map[int]int,
	sortedPartyIDs []*tss.PartyID,
	errCh chan error,
) {
	// Map keygen index to sorted array index
	sortedIndex, exists := keygenIndexToSortedIndex[msg.FromPartyIndex]
	if !exists {
		return
	}
	if sortedIndex < 0 || sortedIndex >= len(sortedPartyIDs) {
		return
	}

	payload, err := base64.StdEncoding.DecodeString(msg.Payload)
	if err != nil {
		return
	}

	parsedMsg, err := tss.ParseWireMessage(payload, sortedPartyIDs[sortedIndex], msg.IsBroadcast)
	if err != nil {
		return
	}

	go func() {
		_, err := localParty.Update(parsedMsg)
		if err != nil && !isDuplicateError(err) {
			// Non-blocking send: only the first fatal error is consumed, so
			// a full channel means one is already pending and this goroutine
			// must not be left blocked forever.
			select {
			case errCh <- err:
			default:
			}
		}
	}()
}
// sendSignResult prints the signing outcome as a JSON "sign_result" message
// on stdout. signature is base64 encoded; r and s are rendered as lowercase
// hex, zero-padded to 64 digits.
//
// The message is built from a dedicated payload type rather than the shared
// envelope so that recoveryId is always present: with omitempty the field
// vanished whenever the recovery id was 0, which is a legitimate value.
// Field names and order match the shared envelope.
func sendSignResult(signature []byte, r, s *big.Int, recoveryID int) {
	payload := struct {
		Type       string `json:"type"`
		Signature  string `json:"signature,omitempty"`
		R          string `json:"r,omitempty"`
		S          string `json:"s,omitempty"`
		RecoveryID int    `json:"recoveryId"`
	}{
		Type:       "sign_result",
		Signature:  base64.StdEncoding.EncodeToString(signature),
		R:          fmt.Sprintf("%064x", r),
		S:          fmt.Sprintf("%064x", s),
		RecoveryID: recoveryID,
	}
	// The marshal error is ignored: the payload holds only strings and an int.
	data, _ := json.Marshal(payload)
	fmt.Println(string(data))
}