rwadurian/backend/mpc-system/services/service-party-app/tss-party/main.go

407 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package main provides the TSS party subprocess for the Electron app.
//
// This program handles TSS (Threshold Signature Scheme) protocol execution.
// It communicates with the parent Electron process via stdin/stdout using JSON messages.
package main
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"math/big"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/tss"
)
// Message is the JSON envelope exchanged with the parent process, one object
// per line. Type selects which of the remaining fields are meaningful.
//
// Every field except Type is tagged omitempty, so zero values (false, "", 0,
// empty slice) are left out of the encoded JSON. In particular a party index
// or round of 0 is not serialized.
type Message struct {
	// Type is the message kind, e.g. "incoming", "outgoing", "progress",
	// "error" or "result".
	Type string `json:"type"`

	// Routing and body of a protocol message.
	IsBroadcast bool     `json:"isBroadcast,omitempty"`
	ToParties   []string `json:"toParties,omitempty"`
	Payload     string   `json:"payload,omitempty"` // base64 encoded

	// Keygen result.
	PublicKey      string `json:"publicKey,omitempty"`      // base64 encoded
	EncryptedShare string `json:"encryptedShare,omitempty"` // base64 encoded
	PartyIndex     int    `json:"partyIndex,omitempty"`

	// Progress reporting.
	Round       int `json:"round,omitempty"`
	TotalRounds int `json:"totalRounds,omitempty"`

	// FromPartyIndex identifies the sender of an incoming protocol message.
	FromPartyIndex int `json:"fromPartyIndex,omitempty"`

	// Error carries a human-readable failure description.
	Error string `json:"error,omitempty"`
}
// Participant describes one member of the keygen session as supplied by the
// parent process in the -participants JSON array.
type Participant struct {
	// PartyID is the participant's identifier string.
	PartyID string `json:"partyId"`
	// PartyIndex is the participant's numeric index within the session.
	PartyIndex int `json:"partyIndex"`
}
// main dispatches to the sub-command named by the first command-line
// argument. A missing or unknown sub-command is reported on stderr and the
// process exits with status 1.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: tss-party <command> [options]")
		os.Exit(1)
	}

	// Table of supported sub-commands.
	handlers := map[string]func(){
		"keygen": runKeygen,
		"sign":   runSign,
	}

	name := os.Args[1]
	run, known := handlers[name]
	if !known {
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", name)
		os.Exit(1)
	}
	run()
}
// runKeygen implements the "keygen" sub-command: it parses and validates the
// flags that follow the sub-command name, runs the protocol via executeKeygen
// and reports the outcome to the parent process as a JSON line on stdout.
//
// Every failure is reported with sendError and ends the process with exit
// status 1. An explicit help request (-h / -help) prints the usage and exits
// with status 0.
//
// NOTE(review): the share password is accepted as a command-line flag, which
// typically makes it visible to other local processes via the process list —
// consider passing it over stdin or the environment instead.
func runKeygen() {
	// ContinueOnError makes Parse return the error instead of exiting on its
	// own, so a bad flag can be forwarded to the parent as a JSON "error"
	// message. With ExitOnError the error branch below would be unreachable.
	fs := flag.NewFlagSet("keygen", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Threshold N")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array")
	password := fs.String("password", "", "Encryption password for share")

	if err := fs.Parse(os.Args[2:]); err != nil {
		if err == flag.ErrHelp {
			// Usage has already been printed by the flag package.
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields.
	if *sessionID == "" || *partyID == "" || *thresholdT == 0 || *thresholdN == 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}
	// Negative values or a threshold larger than the party count can never
	// describe a valid scheme; reject them before starting the protocol.
	if *thresholdT < 0 || *thresholdN < 0 || *thresholdT > *thresholdN {
		sendError(fmt.Sprintf("Invalid threshold parameters: t=%d, n=%d", *thresholdT, *thresholdN))
		os.Exit(1)
	}

	// Parse participants and make sure the list matches the declared size.
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	if len(participants) != *thresholdN {
		sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), *thresholdN))
		os.Exit(1)
	}

	// Run the keygen protocol.
	result, err := executeKeygen(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		*password,
	)
	if err != nil {
		sendError(fmt.Sprintf("Keygen failed: %v", err))
		os.Exit(1)
	}

	// Hand the public key and the share back to the parent.
	sendResult(result.PublicKey, result.EncryptedShare, *partyIndex)
}
// runSign is the entry point of the "sign" sub-command. Signing is not
// implemented yet, so it only reports that to the parent process and exits
// with status 1.
func runSign() {
	// TODO: Implement signing
	const notImplemented = "Signing not implemented yet"
	sendError(notImplemented)
	os.Exit(1)
}
// keygenResult holds the raw (not yet base64 encoded) output of a completed
// keygen run: the group public key and this party's protected key share.
type keygenResult struct {
	PublicKey, EncryptedShare []byte
}
// executeKeygen runs the local party of a tss-lib ECDSA key generation and
// blocks until it finishes, fails, or is cancelled.
//
// Outgoing protocol messages are written to stdout as JSON lines; incoming
// ones are read from stdin (lines with type "incoming"). The run is bounded
// by a 10 minute timeout and is cancelled on SIGINT/SIGTERM.
//
// Parameters:
//   - partyID: identifier of the local party; it must appear in participants.
//   - thresholdT: threshold passed unchanged to tss.NewParameters.
//   - thresholdN: only used to size the outgoing message buffer.
//   - participants: all parties of the session, including the local one.
//   - password: passed to encryptShare together with the serialized save data.
//   - sessionID, partyIndex: currently unused.
//
// On success it returns the 33-byte compressed public key and the output of
// encryptShare. On failure it returns the context error, the first protocol
// error, or an error describing a stdin read failure.
func executeKeygen(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	password string,
) (*keygenResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Cancel the run on SIGINT/SIGTERM. The watcher also exits when ctx ends
	// so it does not outlive this call, and the signal registration is
	// removed again on return.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Create TSS party IDs; the key is PartyIndex+1 so that it is never zero.
	tssPartyIDs := make([]*tss.PartyID, len(participants))
	var selfTSSID *tss.PartyID
	for i, p := range participants {
		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyKey
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Create peer context and parameters
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT)

	// Create channels. errCh holds at most one pending error; only the first
	// one is ever consumed below.
	outCh := make(chan tss.Message, thresholdN*10)
	endCh := make(chan *keygen.LocalPartySaveData, 1)
	errCh := make(chan error, 1)

	// Create local party
	localParty := keygen.NewLocalParty(params, outCh, endCh)

	// Map the caller's party index to the TSS party ID so that the sender of
	// an incoming message can be resolved from Message.FromPartyIndex.
	partyIndexMap := make(map[int]*tss.PartyID)
	for _, p := range sortedPartyIDs {
		for _, orig := range participants {
			if orig.PartyID == p.Id {
				partyIndexMap[orig.PartyIndex] = p
				break
			}
		}
	}

	// Start the local party
	go func() {
		if err := localParty.Start(); err != nil {
			errCh <- err
		}
	}()

	// Forward outgoing protocol messages to stdout until the run ends.
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Feed incoming messages from stdin into the local party.
	var inWg sync.WaitGroup
	inWg.Add(1)
	go func() {
		defer inWg.Done()
		scanner := bufio.NewScanner(os.Stdin)
		// Raise the line buffer to 1MB; the 64KB default may be too small
		// for large messages.
		scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				return
			default:
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				// Lines that are not valid JSON are ignored.
				continue
			}
			if msg.Type == "incoming" {
				handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
			}
		}
		// A read failure (for example a line longer than the buffer) ends the
		// loop for good. Report it so the run fails immediately instead of
		// waiting for the timeout. The send is non-blocking because errCh may
		// already hold an error.
		if err := scanner.Err(); err != nil {
			select {
			case errCh <- fmt.Errorf("reading stdin: %w", err):
			default:
			}
		}
	}()

	// Track progress
	totalRounds := 4 // GG20 keygen has 4 rounds

	// Wait for completion
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case saveData := <-endCh:
		// Keygen completed successfully
		sendProgress(totalRounds, totalRounds)

		// Encode the public key in compressed form: a parity prefix
		// (0x02 for even Y, 0x03 for odd Y) followed by X, left-padded with
		// zeros to 32 bytes.
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		// Serialize save data
		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize save data: %w", err)
		}

		// NOTE(review): encryptShare is currently a placeholder that does not
		// actually encrypt; it should use AES-GCM in production.
		encryptedShare := encryptShare(saveDataBytes, password)

		return &keygenResult{
			PublicKey:      pubKeyBytes,
			EncryptedShare: encryptedShare,
		}, nil
	}
}
// handleOutgoingMessage serializes a protocol message produced by the local
// party and writes it to stdout as one JSON line of type "outgoing".
//
// The wire bytes are base64 encoded into Payload. For non-broadcast messages
// ToParties lists the Id of every recipient; for broadcasts it is left empty.
//
// A message that cannot be serialized is dropped, and the reason is logged to
// stderr so that a stalled protocol run can be diagnosed.
func handleOutgoingMessage(msg tss.Message) {
	wire, _, err := msg.WireBytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: %v\n", err)
		return
	}

	broadcast := msg.IsBroadcast()
	var toParties []string
	if !broadcast {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	data, err := json.Marshal(Message{
		Type:        "outgoing",
		IsBroadcast: broadcast,
		ToParties:   toParties,
		Payload:     base64.StdEncoding.EncodeToString(wire),
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: encoding outgoing message: %v\n", err)
		return
	}
	fmt.Println(string(data))
}
// handleIncomingMessage decodes a protocol message received from the parent
// process and applies it to the local party.
//
// The sender is resolved from msg.FromPartyIndex via partyIndexMap and the
// base64 Payload is parsed with tss.ParseWireMessage. A message that cannot
// be resolved, decoded or parsed is dropped, and the reason is logged to
// stderr.
//
// The update itself runs in its own goroutine. An update error that is not a
// duplicate-delivery error (see isDuplicateError) is forwarded on errCh.
func handleIncomingMessage(
	msg Message,
	localParty tss.Party,
	partyIndexMap map[int]*tss.PartyID,
	errCh chan error,
) {
	fromParty, ok := partyIndexMap[msg.FromPartyIndex]
	if !ok {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from unknown party index %d\n", msg.FromPartyIndex)
		return
	}

	payload, err := base64.StdEncoding.DecodeString(msg.Payload)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: invalid base64 payload: %v\n", msg.FromPartyIndex, err)
		return
	}

	parsedMsg, err := tss.ParseWireMessage(payload, fromParty, msg.IsBroadcast)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: %v\n", msg.FromPartyIndex, err)
		return
	}

	go func() {
		_, err := localParty.Update(parsedMsg)
		// Only forward fatal errors; duplicate deliveries are ignored.
		if err == nil || isDuplicateError(err) {
			return
		}
		// Non-blocking send: if an earlier error is already pending on errCh
		// this one is dropped, so the goroutine never blocks forever on a
		// channel that nobody reads any more.
		select {
		case errCh <- err:
		default:
		}
	}()
}
// isDuplicateError reports whether err looks like a duplicate-delivery error,
// judged by its message containing "duplicate" or "already received"
// (case-sensitive). A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range [...]string{"duplicate", "already received"} {
		if contains(text, marker) {
			return true
		}
	}
	return false
}
// contains reports whether substr occurs anywhere within s. The empty string
// is contained in every string, including the empty one.
func contains(s, substr string) bool {
	// A needle longer than the haystack can never match.
	if len(substr) > len(s) {
		return false
	}
	if s == substr {
		return true
	}
	return len(s) > 0 && containsImpl(s, substr)
}
// containsImpl is the naive substring search behind contains: it compares
// substr against every window of the same length in s, left to right.
func containsImpl(s, substr string) bool {
	width := len(substr)
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
// encryptShare wraps the serialized key share for storage. The result is a
// 32-byte header produced by hashPassword followed by data, unchanged.
//
// NOTE(review): despite its name this does NOT encrypt anything — data is
// stored in the clear. It is a placeholder only.
// TODO: Use proper AES-256-GCM encryption
func encryptShare(data []byte, password string) []byte {
	const headerLen = 32
	out := make([]byte, headerLen, headerLen+len(data))
	copy(out, hashPassword(password))
	return append(out, data...)
}
// hashPassword returns a 32-byte buffer holding the first 32 bytes of
// password, zero-padded on the right when the password is shorter.
//
// NOTE(review): this is not a hash — the password bytes are copied verbatim.
// Should use PBKDF2 or Argon2 in production.
func hashPassword(password string) []byte {
	padded := make([]byte, 32)
	// copy stops at the shorter of the two lengths, which gives both the
	// truncation and the zero padding.
	copy(padded, password)
	return padded
}
// sendProgress writes a "progress" message with the given round counters to
// stdout as a single JSON line.
func sendProgress(round, totalRounds int) {
	// The marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice, which are not expected to fail encoding.
	encoded, _ := json.Marshal(Message{
		Type:        "progress",
		Round:       round,
		TotalRounds: totalRounds,
	})
	fmt.Println(string(encoded))
}
// sendError writes an "error" message carrying errMsg to stdout as a single
// JSON line. It does not terminate the process; callers do that themselves.
func sendError(errMsg string) {
	var report Message
	report.Type = "error"
	report.Error = errMsg

	// The marshal error is ignored: Message holds only strings, ints, a bool
	// and a string slice, which are not expected to fail encoding.
	encoded, _ := json.Marshal(report)
	fmt.Println(string(encoded))
}
// sendResult writes the final "result" message to stdout as a single JSON
// line. publicKey and encryptedShare are base64 encoded (standard alphabet);
// partyIndex is the index of the local party.
//
// The message is built from a dedicated wire struct rather than Message:
// Message tags PartyIndex with omitempty, which would silently drop the field
// for party 0 even though party indices are 0-based. Here partyIndex is
// always present. Field names and order match Message.
func sendResult(publicKey, encryptedShare []byte, partyIndex int) {
	msg := struct {
		Type           string `json:"type"`
		PublicKey      string `json:"publicKey,omitempty"`
		EncryptedShare string `json:"encryptedShare,omitempty"`
		PartyIndex     int    `json:"partyIndex"`
	}{
		Type:           "result",
		PublicKey:      base64.StdEncoding.EncodeToString(publicKey),
		EncryptedShare: base64.StdEncoding.EncodeToString(encryptedShare),
		PartyIndex:     partyIndex,
	}

	data, err := json.Marshal(msg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: encoding result message: %v\n", err)
		return
	}
	fmt.Println(string(data))
}