rwadurian/backend/mpc-system/services/service-party-app/tss-party/main.go

706 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package main provides the TSS party subprocess for the Electron app.
//
// This program handles TSS (Threshold Signature Scheme) protocol execution.
// It communicates with the parent Electron process via stdin/stdout using JSON messages.
package main
import (
"bufio"
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"math/big"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/bnb-chain/tss-lib/v2/common"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
"github.com/bnb-chain/tss-lib/v2/tss"
)
// Message is the JSON envelope exchanged with the parent process, one object
// per line on stdin/stdout. Type selects which of the optional fields are
// meaningful; the values used in this file are "incoming" (read from stdin)
// and "outgoing", "progress", "result" and "error" (written to stdout).
//
// NOTE(review): the integer fields carry omitempty, so a value of 0 (for
// example party index 0) is left out of the encoded JSON — confirm the reader
// treats a missing field as 0.
type Message struct {
// Type names the message kind.
Type string `json:"type"`
// IsBroadcast mirrors the broadcast flag of a protocol message.
IsBroadcast bool `json:"isBroadcast,omitempty"`
// ToParties lists the recipient party IDs of a non-broadcast outgoing message.
ToParties []string `json:"toParties,omitempty"`
// Payload carries protocol wire bytes, or the raw signature in a sign result.
Payload string `json:"payload,omitempty"` // base64 encoded
// PublicKey is set in a keygen result.
PublicKey string `json:"publicKey,omitempty"` // base64 encoded
// EncryptedShare is set in a keygen result.
EncryptedShare string `json:"encryptedShare,omitempty"` // base64 encoded
// PartyIndex is the local party index echoed back in results.
PartyIndex int `json:"partyIndex,omitempty"`
// Round and TotalRounds are set in progress messages.
Round int `json:"round,omitempty"`
TotalRounds int `json:"totalRounds,omitempty"`
// FromPartyIndex identifies the sender of an incoming message.
FromPartyIndex int `json:"fromPartyIndex,omitempty"`
// Error is the human-readable text of an error message.
Error string `json:"error,omitempty"`
}
// Participant describes one party of a session as supplied by the caller in
// the -participants JSON array.
type Participant struct {
// PartyID is the party's identifier; it is compared against -party-id to
// find the local party and is used as the tss-lib party ID.
PartyID string `json:"partyId"`
// PartyIndex is the party's index; PartyIndex+1 is used as the tss-lib
// party key and the index is matched against fromPartyIndex on incoming
// messages.
PartyIndex int `json:"partyIndex"`
}
// main dispatches to the sub-command named by the first command-line
// argument ("keygen" or "sign"). A missing or unknown command is reported on
// stderr and the process exits with status 1.
func main() {
	args := os.Args
	if len(args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: tss-party <command> [options]")
		os.Exit(1)
	}

	handlers := map[string]func(){
		"keygen": runKeygen,
		"sign":   runSign,
	}

	name := args[1]
	run, known := handlers[name]
	if !known {
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", name)
		os.Exit(1)
	}
	run()
}
// runKeygen implements the "keygen" sub-command: it parses the flags that
// follow the command name, validates them, runs executeKeygen and prints the
// result as a JSON "result" message on stdout.
//
// Every failure is reported to the parent as a JSON "error" message on stdout
// and the process exits with status 1. Asking for help (-h/-help) prints the
// flag usage and exits with status 0.
//
// NOTE(review): the share password arrives as a command-line flag, which is
// typically visible to other processes on the machine — confirm this is
// acceptable for the deployment.
func runKeygen() {
	// ContinueOnError makes Parse return the error instead of exiting the
	// process itself, so the failure can be forwarded to the parent as JSON.
	// With ExitOnError the error branch below could never run.
	fs := flag.NewFlagSet("keygen", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Threshold N")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array")
	password := fs.String("password", "", "Encryption password for share")

	if err := fs.Parse(os.Args[2:]); err != nil {
		if err == flag.ErrHelp {
			// Usage has already been printed by the flag package.
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields. Thresholds must be positive: a negative value
	// would otherwise slip past a plain "== 0" check.
	if *sessionID == "" || *partyID == "" || *thresholdT <= 0 || *thresholdN <= 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}

	// Parse participants
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	// Key generation involves all N parties.
	if len(participants) != *thresholdN {
		sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), *thresholdN))
		os.Exit(1)
	}

	// Run keygen protocol
	result, err := executeKeygen(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		*password,
	)
	if err != nil {
		sendError(fmt.Sprintf("Keygen failed: %v", err))
		os.Exit(1)
	}

	// Send result
	sendResult(result.PublicKey, result.EncryptedShare, *partyIndex)
}
// runSign implements the "sign" sub-command: it parses the flags that follow
// the command name, validates them, runs executeSign and prints the signature
// as a JSON "result" message on stdout.
//
// Every failure is reported to the parent as a JSON "error" message on stdout
// and the process exits with status 1. Asking for help (-h/-help) prints the
// flag usage and exits with status 0.
//
// NOTE(review): the share password arrives as a command-line flag, which is
// typically visible to other processes on the machine — confirm this is
// acceptable for the deployment.
func runSign() {
	// ContinueOnError makes Parse return the error instead of exiting the
	// process itself, so the failure can be forwarded to the parent as JSON.
	// With ExitOnError the error branch below could never run.
	fs := flag.NewFlagSet("sign", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Threshold N (total parties in keygen)")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array")
	messageHash := fs.String("message-hash", "", "Message hash to sign (hex encoded)")
	shareData := fs.String("share-data", "", "Encrypted share data (base64 encoded)")
	password := fs.String("password", "", "Password to decrypt share")

	if err := fs.Parse(os.Args[2:]); err != nil {
		if err == flag.ErrHelp {
			// Usage has already been printed by the flag package.
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields. Thresholds must be positive: a negative value
	// would otherwise slip past a plain "== 0" check.
	if *sessionID == "" || *partyID == "" || *thresholdT <= 0 || *thresholdN <= 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}
	if *messageHash == "" {
		sendError("Missing message hash")
		os.Exit(1)
	}
	if *shareData == "" {
		sendError("Missing share data")
		os.Exit(1)
	}

	// Parse participants
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	// Note: For signing, participant count equals threshold T (not N)
	// because only T parties participate in signing
	if len(participants) != *thresholdT {
		sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d (threshold T)", len(participants), *thresholdT))
		os.Exit(1)
	}

	// Run sign protocol
	result, err := executeSign(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		*messageHash,
		*shareData,
		*password,
	)
	if err != nil {
		sendError(fmt.Sprintf("Sign failed: %v", err))
		os.Exit(1)
	}

	// Send result
	sendSignResult(result.Signature, result.RecoveryID, *partyIndex)
}
// keygenResult is what executeKeygen returns on success.
type keygenResult struct {
// PublicKey is the 33-byte encoding built by executeKeygen: one prefix
// byte (0x02 plus the low bit of Y) followed by X left-padded to 32 bytes.
PublicKey []byte
// EncryptedShare is the output of encryptShare applied to the
// JSON-encoded keygen save data.
EncryptedShare []byte
}
// signResult is what executeSign returns on success.
type signResult struct {
// Signature is 64 bytes: R left-padded to 32 bytes followed by S
// left-padded to 32 bytes.
Signature []byte
// RecoveryID is the first byte of the SignatureRecovery field reported by
// tss-lib.
RecoveryID int
}
// executeKeygen runs the local party's side of the tss-lib ECDSA
// key-generation protocol and returns the group public key together with the
// serialized local share.
//
// Outgoing protocol messages are written to stdout as "outgoing" JSON lines
// and peer messages are read from stdin as "incoming" JSON lines. The run is
// bounded by a 10 minute timeout and is cancelled on SIGINT/SIGTERM.
//
// Parameters:
//   - sessionID, partyIndex: not used by this function.
//   - partyID: ID of the local party; it must appear in participants.
//   - thresholdT: passed unchanged to tss.NewParameters as the threshold.
//   - thresholdN: only sizes the outgoing-message buffer.
//   - participants: every party taking part in the key generation.
//   - password: handed to encryptShare together with the JSON-encoded save data.
//
// It returns an error when the local party is not listed in participants,
// when the context is cancelled or times out, when the protocol or the stdin
// reader reports an error, or when the save data cannot be serialized.
//
// NOTE(review): executeSign passes thresholdT-1 to tss.NewParameters while
// this function passes thresholdT — confirm both agree with the threshold
// convention tss-lib expects.
func executeKeygen(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	password string,
) (*keygenResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Cancel the run on SIGINT/SIGTERM. The watcher goroutine also exits when
	// the context ends, and the handler is unregistered on return.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Create TSS party IDs. The key is PartyIndex+1 — presumably so that the
	// key of index 0 is non-zero; verify against tss-lib's requirements.
	tssPartyIDs := make([]*tss.PartyID, len(participants))
	var selfTSSID *tss.PartyID
	for i, p := range participants {
		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyKey
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Create peer context and parameters
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT)

	// Create channels
	outCh := make(chan tss.Message, thresholdN*10)
	endCh := make(chan *keygen.LocalPartySaveData, 1)
	errCh := make(chan error, 1)

	// Create local party
	localParty := keygen.NewLocalParty(params, outCh, endCh)

	// Map the caller's party index to the sorted tss party ID so that the
	// sender of an incoming message can be resolved.
	partyIndexMap := make(map[int]*tss.PartyID)
	for _, p := range sortedPartyIDs {
		for _, orig := range participants {
			if orig.PartyID == p.Id {
				partyIndexMap[orig.PartyIndex] = p
				break
			}
		}
	}

	// Start the local party
	go func() {
		if err := localParty.Start(); err != nil {
			errCh <- err
		}
	}()

	// Forward outgoing protocol messages to stdout until the context ends.
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Feed incoming messages from stdin into the local party. This goroutine
	// is not waited for: it may stay blocked in a read on stdin.
	go func() {
		scanner := bufio.NewScanner(os.Stdin)
		// Raise the line limit to 1 MiB; the 64 KiB default may be too small
		// for large protocol messages.
		scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				// Lines that are not valid JSON are skipped.
				continue
			}
			if msg.Type == "incoming" {
				handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
			}
		}
		// A read failure (for example a line above the 1 MiB limit) stops the
		// scanner for good; report it instead of waiting for the timeout.
		// The send is non-blocking because errCh holds a single error.
		if err := scanner.Err(); err != nil {
			select {
			case errCh <- fmt.Errorf("reading stdin: %w", err):
			default:
			}
		}
	}()

	// Track progress
	totalRounds := 4 // GG20 keygen has 4 rounds

	// Wait for completion
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case saveData := <-endCh:
		// Stop the writer goroutine and emit whatever is still queued, so no
		// outgoing message is lost once the caller prints the result and the
		// process exits.
		cancel()
		outWg.Wait()
	drain:
		for {
			select {
			case msg := <-outCh:
				handleOutgoingMessage(msg)
			default:
				break drain
			}
		}

		// Keygen completed successfully
		sendProgress(totalRounds, totalRounds)

		// Encode the public key as prefix byte (0x02 + parity of Y) followed
		// by X left-padded to 32 bytes.
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		// Serialize save data
		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize save data: %w", err)
		}

		// encryptShare is a placeholder that does not actually encrypt; it
		// should be replaced by AES-GCM in production.
		encryptedShare := encryptShare(saveDataBytes, password)

		return &keygenResult{
			PublicKey:      pubKeyBytes,
			EncryptedShare: encryptedShare,
		}, nil
	}
}
// handleOutgoingMessage serializes one protocol message produced by the local
// party and writes it to stdout as an "outgoing" JSON line. The wire bytes are
// base64-encoded into Payload; for non-broadcast messages ToParties lists the
// IDs of the recipients.
//
// A message that cannot be serialized is dropped, and the reason is written
// to stderr so the loss is visible rather than silent.
func handleOutgoingMessage(msg tss.Message) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: %v\n", err)
		return
	}

	var toParties []string
	if !msg.IsBroadcast() {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	outMsg := Message{
		Type:        "outgoing",
		IsBroadcast: msg.IsBroadcast(),
		ToParties:   toParties,
		Payload:     base64.StdEncoding.EncodeToString(msgBytes),
	}
	data, err := json.Marshal(outMsg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: %v\n", err)
		return
	}
	fmt.Println(string(data))
}
// handleIncomingMessage decodes one "incoming" message from the parent and
// hands it to the local party.
//
// The sender is resolved through partyIndexMap using msg.FromPartyIndex, the
// payload is base64-decoded and parsed with tss.ParseWireMessage, and the
// party is updated in a separate goroutine. A message that cannot be resolved,
// decoded or parsed is dropped and the reason is written to stderr.
//
// Errors returned by Update are forwarded to errCh unless isDuplicateError
// classifies them as duplicates. The send is non-blocking: if errCh cannot
// accept the error right away it is dropped, so the update goroutine never
// stays blocked on the channel.
func handleIncomingMessage(
	msg Message,
	localParty tss.Party,
	partyIndexMap map[int]*tss.PartyID,
	errCh chan error,
) {
	fromParty, ok := partyIndexMap[msg.FromPartyIndex]
	if !ok {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from unknown party index %d\n", msg.FromPartyIndex)
		return
	}

	payload, err := base64.StdEncoding.DecodeString(msg.Payload)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: %v\n", msg.FromPartyIndex, err)
		return
	}

	parsedMsg, err := tss.ParseWireMessage(payload, fromParty, msg.IsBroadcast)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: %v\n", msg.FromPartyIndex, err)
		return
	}

	go func() {
		_, err := localParty.Update(parsedMsg)
		if err == nil || isDuplicateError(err) {
			// Only fatal errors are reported.
			return
		}
		select {
		case errCh <- err:
		default:
		}
	}()
}
// isDuplicateError reports whether err's text contains "duplicate" or
// "already received". A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{"duplicate", "already received"} {
		if contains(text, marker) {
			return true
		}
	}
	return false
}
// contains reports whether substr occurs within s. It gives the same answer
// as the standard library's strings.Contains.
func contains(s, substr string) bool {
	switch {
	case len(substr) > len(s):
		return false
	case s == substr:
		return true
	case len(s) == 0:
		return false
	}
	return containsImpl(s, substr)
}
// containsImpl slides a window of len(substr) bytes over s and reports
// whether any window equals substr. When substr is longer than s no window
// fits and the result is false.
func containsImpl(s, substr string) bool {
	width := len(substr)
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
// encryptShare returns a 32-byte header produced by hashPassword followed by
// data, unchanged.
//
// WARNING: despite the name nothing is encrypted here — the share is stored
// as-is after the header. This is a placeholder.
// TODO: Use proper AES-256-GCM encryption.
func encryptShare(data []byte, password string) []byte {
	const headerLen = 32
	out := make([]byte, headerLen, headerLen+len(data))
	copy(out, hashPassword(password))
	return append(out, data...)
}
// hashPassword returns a 32-byte slice holding the first 32 bytes of password,
// zero-padded on the right when the password is shorter.
//
// WARNING: this is not a hash — the password bytes are copied verbatim and
// anything past byte 32 is ignored. Should use PBKDF2 or Argon2 in production.
func hashPassword(password string) []byte {
	digest := make([]byte, 32)
	// copy stops at the shorter of the two lengths.
	copy(digest, password)
	return digest
}
// sendProgress writes a "progress" JSON line to stdout carrying the given
// round and total number of rounds.
func sendProgress(round, totalRounds int) {
	// The Marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice.
	encoded, _ := json.Marshal(Message{Type: "progress", Round: round, TotalRounds: totalRounds})
	fmt.Println(string(encoded))
}
// sendError writes an "error" JSON line to stdout with errMsg as its text.
func sendError(errMsg string) {
	report := Message{Type: "error"}
	report.Error = errMsg
	// The Marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice.
	line, _ := json.Marshal(report)
	fmt.Printf("%s\n", line)
}
// sendResult writes the keygen "result" JSON line to stdout. The public key
// and the share are base64-encoded; partyIndex is passed through as given.
func sendResult(publicKey, encryptedShare []byte, partyIndex int) {
	enc := base64.StdEncoding

	var out Message
	out.Type = "result"
	out.PublicKey = enc.EncodeToString(publicKey)
	out.EncryptedShare = enc.EncodeToString(encryptedShare)
	out.PartyIndex = partyIndex

	// The Marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice.
	encoded, _ := json.Marshal(out)
	fmt.Fprintln(os.Stdout, string(encoded))
}
// sendSignResult writes the signing "result" JSON line to stdout.
//
// The signature is base64-encoded into "payload" and partyIndex is passed
// through as given. The recovery ID, which used to be accepted and then
// dropped, is emitted as an additional "recoveryId" field; it is always
// present because 0 is a valid value. All other fields are encoded exactly as
// for a plain Message, so readers that ignore unknown keys are unaffected.
func sendSignResult(signature []byte, recoveryID int, partyIndex int) {
	// Embedding Message keeps its field names and order; the extra field is
	// local to this function so the shared Message type stays untouched.
	msg := struct {
		Message
		RecoveryID int `json:"recoveryId"`
	}{
		Message: Message{
			Type:       "result",
			Payload:    base64.StdEncoding.EncodeToString(signature),
			PartyIndex: partyIndex,
		},
		RecoveryID: recoveryID,
	}
	// The Marshal error is ignored: the value holds only strings, ints, a
	// bool and a string slice.
	data, _ := json.Marshal(msg)
	fmt.Println(string(data))
}
// decryptShare reverses encryptShare: it checks that the first 32 bytes of
// encryptedData equal hashPassword(password) and returns the remaining bytes.
//
// It returns an error when the input is shorter than the 32-byte header or
// when the header does not match the password. The returned slice shares its
// backing array with encryptedData.
func decryptShare(encryptedData []byte, password string) ([]byte, error) {
	const headerLen = 32
	if len(encryptedData) < headerLen {
		return nil, fmt.Errorf("encrypted data too short")
	}

	header, body := encryptedData[:headerLen], encryptedData[headerLen:]
	// hashPassword always yields 32 bytes, so every header byte is compared.
	for i, want := range hashPassword(password) {
		if header[i] != want {
			return nil, fmt.Errorf("incorrect password")
		}
	}
	return body, nil
}
// executeSign runs the local party's side of the tss-lib ECDSA signing
// protocol over a 32-byte message hash and returns the raw signature.
//
// Outgoing protocol messages are written to stdout as "outgoing" JSON lines
// and peer messages are read from stdin as "incoming" JSON lines. The run is
// bounded by a 10 minute timeout and is cancelled on SIGINT/SIGTERM.
//
// Parameters:
//   - sessionID, partyIndex, thresholdN: not used by this function.
//   - partyID: ID of the local party; it must appear in participants.
//   - thresholdT: thresholdT-1 is passed to tss.NewParameters; thresholdT
//     also sizes the outgoing-message buffer.
//   - participants: the parties taking part in this signing session.
//   - messageHashHex: hex encoding of the 32-byte hash to sign.
//   - shareDataBase64: base64 encoding of the blob produced by encryptShare.
//   - password: used by decryptShare to unwrap the share.
//
// It returns an error when the share cannot be decoded, unwrapped or parsed,
// when the hash is not 32 bytes of valid hex, when the local party or any
// signer is unknown, when the context is cancelled or times out, or when the
// protocol or the stdin reader reports an error.
//
// NOTE(review): executeKeygen passes thresholdT unchanged to
// tss.NewParameters while this function passes thresholdT-1 — confirm both
// agree with the threshold convention tss-lib expects.
func executeSign(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	messageHashHex string,
	shareDataBase64 string,
	password string,
) (*signResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Cancel the run on SIGINT/SIGTERM. The watcher goroutine also exits when
	// the context ends, and the handler is unregistered on return.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Decode and decrypt share data
	encryptedShare, err := base64.StdEncoding.DecodeString(shareDataBase64)
	if err != nil {
		return nil, fmt.Errorf("failed to decode share data: %w", err)
	}
	shareBytes, err := decryptShare(encryptedShare, password)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt share: %w", err)
	}

	// Parse keygen save data
	var keygenData keygen.LocalPartySaveData
	if err := json.Unmarshal(shareBytes, &keygenData); err != nil {
		return nil, fmt.Errorf("failed to parse keygen data: %w", err)
	}

	// Decode message hash
	messageHash, err := hex.DecodeString(messageHashHex)
	if err != nil {
		return nil, fmt.Errorf("failed to decode message hash: %w", err)
	}
	if len(messageHash) != 32 {
		return nil, fmt.Errorf("message hash must be 32 bytes, got %d", len(messageHash))
	}
	msgBigInt := new(big.Int).SetBytes(messageHash)

	// Create party IDs only for the signing participants. The IDs are built
	// exactly as in executeKeygen (key = PartyIndex+1) so that each signer's
	// key matches an entry of keygenData.Ks from the original key generation.
	tssPartyIDs := make([]*tss.PartyID, 0, len(participants))
	var selfTSSID *tss.PartyID
	for _, p := range participants {
		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs = append(tssPartyIDs, partyKey)
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs (important for tss-lib)
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// The save data covers every keygen party, while signing works on the
	// entries of the signers only. Check that each signer is known, so the
	// failure is an error rather than a panic inside tss-lib, then reduce the
	// save data to the signing subset in sorted order.
	for _, id := range sortedPartyIDs {
		known := false
		for _, k := range keygenData.Ks {
			if k != nil && k.Cmp(id.KeyInt()) == 0 {
				known = true
				break
			}
		}
		if !known {
			return nil, fmt.Errorf("participant %q is not part of the keygen save data", id.Id)
		}
	}
	keygenData = keygen.BuildLocalSaveDataSubset(keygenData, sortedPartyIDs)

	// Create peer context and parameters
	// For signing with T parties from an N-party keygen:
	// - The peer context contains only the T signing parties
	// - threshold parameter should be T-1 (since we need T parties to sign, threshold = T-1)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT-1)

	// Create channels
	outCh := make(chan tss.Message, thresholdT*10)
	endCh := make(chan *common.SignatureData, 1)
	errCh := make(chan error, 1)

	// Create local party for signing
	localParty := signing.NewLocalParty(msgBigInt, params, keygenData, outCh, endCh)

	// Map the caller's party index to the sorted tss party ID so that the
	// sender of an incoming message can be resolved.
	partyIndexMap := make(map[int]*tss.PartyID)
	for _, p := range sortedPartyIDs {
		for _, orig := range participants {
			if orig.PartyID == p.Id {
				partyIndexMap[orig.PartyIndex] = p
				break
			}
		}
	}

	// Start the local party
	go func() {
		if err := localParty.Start(); err != nil {
			errCh <- err
		}
	}()

	// Forward outgoing protocol messages to stdout until the context ends.
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Feed incoming messages from stdin into the local party. This goroutine
	// is not waited for: it may stay blocked in a read on stdin.
	go func() {
		scanner := bufio.NewScanner(os.Stdin)
		// Raise the line limit to 1 MiB.
		scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				// Lines that are not valid JSON are skipped.
				continue
			}
			if msg.Type == "incoming" {
				handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
			}
		}
		// A read failure (for example a line above the 1 MiB limit) stops the
		// scanner for good; report it instead of waiting for the timeout.
		// The send is non-blocking because errCh holds a single error.
		if err := scanner.Err(); err != nil {
			select {
			case errCh <- fmt.Errorf("reading stdin: %w", err):
			default:
			}
		}
	}()

	// Track progress - GG20 signing has 9 rounds
	totalRounds := 9

	// Wait for completion
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case sigData := <-endCh:
		// Stop the writer goroutine and emit whatever is still queued, so no
		// outgoing message is lost once the caller prints the result and the
		// process exits.
		cancel()
		outWg.Wait()
	drain:
		for {
			select {
			case msg := <-outCh:
				handleOutgoingMessage(msg)
			default:
				break drain
			}
		}

		// Signing completed successfully
		sendProgress(totalRounds, totalRounds)

		rBytes := sigData.R
		sBytes := sigData.S
		// Guard the fixed-width layout and the recovery byte below so that
		// unexpected signature data yields an error instead of a panic.
		if len(rBytes) > 32 || len(sBytes) > 32 {
			return nil, fmt.Errorf("unexpected signature component length: r=%d s=%d", len(rBytes), len(sBytes))
		}
		if len(sigData.SignatureRecovery) == 0 {
			return nil, fmt.Errorf("signature data is missing the recovery byte")
		}

		// Create raw signature: R (32 bytes) || S (32 bytes), each left-padded
		// with zeros.
		signature := make([]byte, 64)
		copy(signature[32-len(rBytes):32], rBytes)
		copy(signature[64-len(sBytes):64], sBytes)

		// Recovery ID for Ethereum-style signatures
		recoveryID := int(sigData.SignatureRecovery[0])

		return &signResult{
			Signature:  signature,
			RecoveryID: recoveryID,
		}, nil
	}
}