// Source listing: rwadurian/backend/mpc-system/services/service-party-app/tss-party/main.go
// Listing metadata: 407 lines, 9.3 KiB, Go

// Package main provides the TSS party subprocess for the Electron app.
//
// This program handles TSS (Threshold Signature Scheme) protocol execution.
// It communicates with the parent Electron process via stdin/stdout using
// JSON messages.
package main
import (
	"bufio"
	"context"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"math/big"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
	"github.com/bnb-chain/tss-lib/v2/tss"
)
// Message is the JSON envelope used for IPC with the parent process. Messages
// are read from stdin one per line and written to stdout one per line.
//
// Type selects which of the remaining fields are meaningful. The values used
// in this file are "incoming", "outgoing", "progress", "error" and "result".
// Every field except Type is tagged omitempty, so zero values (false, 0, "",
// nil) are left out of the encoded JSON.
//
// NOTE(review): party indexes are 0-based, so with omitempty an index of 0 is
// never emitted for PartyIndex/FromPartyIndex and cannot be told apart from a
// missing field when decoding — confirm the parent process tolerates this.
type Message struct {
	// Type is the message discriminator.
	Type string `json:"type"`
	// IsBroadcast reports whether a protocol message is addressed to all parties.
	IsBroadcast bool `json:"isBroadcast,omitempty"`
	// ToParties lists the recipient party IDs of a non-broadcast outgoing message.
	ToParties []string `json:"toParties,omitempty"`
	// Payload is the protocol wire message, base64 encoded.
	Payload string `json:"payload,omitempty"`
	// PublicKey is the generated public key of a "result" message, base64 encoded.
	PublicKey string `json:"publicKey,omitempty"`
	// EncryptedShare is the protected key share of a "result" message, base64 encoded.
	EncryptedShare string `json:"encryptedShare,omitempty"`
	// PartyIndex is the local party's index, echoed back in a "result" message.
	PartyIndex int `json:"partyIndex,omitempty"`
	// Round and TotalRounds carry the values of a "progress" message.
	Round int `json:"round,omitempty"`
	TotalRounds int `json:"totalRounds,omitempty"`
	// FromPartyIndex is the sender's party index of an "incoming" message.
	FromPartyIndex int `json:"fromPartyIndex,omitempty"`
	// Error is the human-readable text of an "error" message.
	Error string `json:"error,omitempty"`
}
// Participant describes one member of a keygen session. The list of
// participants is decoded from the JSON array given in the -participants flag.
type Participant struct {
	// PartyID is the party's identifier. It is matched against -party-id to
	// find the local party and is used as the recipient ID in outgoing messages.
	PartyID string `json:"partyId"`
	// PartyIndex is the party's numeric index. Incoming messages name their
	// sender by this value (Message.FromPartyIndex).
	PartyIndex int `json:"partyIndex"`
}
// main dispatches to the sub-command named by the first command-line
// argument. A missing or unrecognised command prints a diagnostic to stderr
// and exits with status 1.
func main() {
	args := os.Args
	if len(args) < 2 {
		fmt.Fprintln(os.Stderr, "Usage: tss-party <command> [options]")
		os.Exit(1)
	}

	// Sub-command name -> entry point.
	handlers := map[string]func(){
		"keygen": runKeygen,
		"sign":   runSign,
	}

	name := args[1]
	run, known := handlers[name]
	if !known {
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", name)
		os.Exit(1)
	}
	run()
}
// runKeygen implements the "keygen" sub-command.
//
// It parses the flags that follow the command name in os.Args, validates
// them, runs executeKeygen and reports the outcome to the parent process:
// sendResult on success, or sendError followed by exit status 1 on failure.
//
// NOTE(review): -password travels on the command line, which on most systems
// is visible to other local processes — consider a different channel.
func runKeygen() {
	// ContinueOnError rather than ExitOnError: with ExitOnError the flag
	// package exits by itself and the error branch below is never reached, so
	// the parent would get no JSON error message for a malformed invocation.
	fs := flag.NewFlagSet("keygen", flag.ContinueOnError)
	sessionID := fs.String("session-id", "", "Session ID")
	partyID := fs.String("party-id", "", "Party ID")
	partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
	thresholdT := fs.Int("threshold-t", 0, "Threshold T")
	thresholdN := fs.Int("threshold-n", 0, "Threshold N")
	participantsJSON := fs.String("participants", "[]", "Participants JSON array")
	password := fs.String("password", "", "Encryption password for share")

	if err := fs.Parse(os.Args[2:]); err != nil {
		if err == flag.ErrHelp {
			// -h / -help: usage has already been printed; this is not a failure.
			os.Exit(0)
		}
		sendError(fmt.Sprintf("Failed to parse flags: %v", err))
		os.Exit(1)
	}

	// Validate required fields.
	if *sessionID == "" || *partyID == "" || *thresholdT == 0 || *thresholdN == 0 {
		sendError("Missing required parameters")
		os.Exit(1)
	}
	// Negative values, or a threshold larger than the party count, can never
	// describe a valid session; reject them before starting the protocol.
	if *thresholdT < 0 || *thresholdN < 0 || *thresholdT > *thresholdN {
		sendError(fmt.Sprintf("Invalid threshold: t=%d, n=%d", *thresholdT, *thresholdN))
		os.Exit(1)
	}

	// Parse participants.
	var participants []Participant
	if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
		sendError(fmt.Sprintf("Failed to parse participants: %v", err))
		os.Exit(1)
	}
	if len(participants) != *thresholdN {
		sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), *thresholdN))
		os.Exit(1)
	}

	// Run the keygen protocol.
	result, err := executeKeygen(
		*sessionID,
		*partyID,
		*partyIndex,
		*thresholdT,
		*thresholdN,
		participants,
		*password,
	)
	if err != nil {
		sendError(fmt.Sprintf("Keygen failed: %v", err))
		os.Exit(1)
	}

	sendResult(result.PublicKey, result.EncryptedShare, *partyIndex)
}
// runSign is the entry point of the "sign" sub-command. Signing is not
// implemented yet, so it only reports that to the parent process and exits
// with status 1.
func runSign() {
	// TODO: Implement signing
	const notImplemented = "Signing not implemented yet"
	sendError(notImplemented)
	os.Exit(1)
}
// keygenResult is what executeKeygen returns on success: the generated public
// key and the output of encryptShare for this party's save data.
type keygenResult struct {
	PublicKey, EncryptedShare []byte
}
// executeKeygen runs one ECDSA threshold key-generation session for the local
// party and blocks until it finishes, fails, is interrupted by SIGINT/SIGTERM,
// or a 10 minute timeout expires.
//
// The local party is the entry of participants whose PartyID equals partyID.
// Protocol messages produced by the local party are written to stdout through
// handleOutgoingMessage; lines read from stdin whose Type is "incoming" are
// fed to the party through handleIncomingMessage.
//
// On success it returns the 33-byte compressed public key and the JSON-encoded
// save data passed through encryptShare. sessionID and partyIndex are accepted
// for the caller's convenience but are not used here.
func executeKeygen(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participants []Participant,
	password string,
) (*keygenResult, error) {
	const (
		// GG20 keygen has 4 rounds.
		totalRounds = 4
		// Largest stdin line accepted. bufio.Scanner's default limit is
		// 64 KiB, which a single base64-encoded protocol message can exceed;
		// the scanner would then stop with ErrTooLong and the session stall.
		maxIncomingLine = 16 << 20
	)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Cancel the session on SIGINT/SIGTERM. The watcher also ends when ctx
	// does, and the signal registration is removed on return.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)
	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Create the TSS party IDs and, in the same pass, the index -> party map
	// used to resolve the sender of incoming messages.
	tssPartyIDs := make([]*tss.PartyID, 0, len(participants))
	partyIndexMap := make(map[int]*tss.PartyID, len(participants))
	var selfTSSID *tss.PartyID
	for _, p := range participants {
		if p.PartyIndex < 0 {
			return nil, fmt.Errorf("invalid party index %d for party %q", p.PartyIndex, p.PartyID)
		}
		if _, dup := partyIndexMap[p.PartyIndex]; dup {
			return nil, fmt.Errorf("duplicate party index %d", p.PartyIndex)
		}
		// NewPartyID expects a *big.Int key. PartyIndex+1 keeps the key
		// non-zero for the 0-based index.
		// NOTE(review): every participant must derive the same key for a given
		// party — confirm the other implementations also use PartyIndex+1.
		id := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex)+1),
		)
		tssPartyIDs = append(tssPartyIDs, id)
		partyIndexMap[p.PartyIndex] = id
		if p.PartyID == partyID {
			selfTSSID = id
		}
	}
	if selfTSSID == nil {
		return nil, fmt.Errorf("self party not found in participants")
	}

	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT)

	outCh := make(chan tss.Message, thresholdN*10)
	endCh := make(chan *keygen.LocalPartySaveData, 1)
	errCh := make(chan error, 1)

	// reportErr never blocks: the waiter below consumes at most one error, so
	// once one is pending any further error is dropped instead of leaving its
	// goroutine stuck on the send.
	reportErr := func(err error) {
		select {
		case errCh <- err:
		default:
		}
	}

	localParty := keygen.NewLocalParty(params, outCh, endCh)

	go func() {
		if err := localParty.Start(); err != nil {
			reportErr(err)
		}
	}()

	// Forward outgoing protocol messages to the parent process.
	var outWg sync.WaitGroup
	outWg.Add(1)
	go func() {
		defer outWg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				handleOutgoingMessage(msg)
			}
		}
	}()

	// Feed incoming messages from stdin to the local party. This goroutine is
	// not waited for: it may be blocked in a read on stdin.
	go func() {
		scanner := bufio.NewScanner(os.Stdin)
		scanner.Buffer(make([]byte, 0, 64*1024), maxIncomingLine)
		for scanner.Scan() {
			if ctx.Err() != nil {
				return
			}
			var msg Message
			if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
				continue // not a JSON message; ignored
			}
			if msg.Type == "incoming" {
				handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
			}
		}
		if err := scanner.Err(); err != nil {
			reportErr(fmt.Errorf("reading stdin: %w", err))
		}
	}()

	// Wait for completion.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case saveData := <-endCh:
		// Stop the forwarder, then flush anything still buffered so that no
		// outgoing message is lost when the process exits after the result.
		cancel()
		outWg.Wait()
		flushOutgoing := func() {
			for {
				select {
				case msg := <-outCh:
					handleOutgoingMessage(msg)
				default:
					return
				}
			}
		}
		flushOutgoing()

		sendProgress(totalRounds, totalRounds)

		// Compressed public key: a 0x02/0x03 prefix selected by the parity of
		// Y, followed by X left-padded with zeros to 32 bytes.
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize save data: %w", err)
		}

		// encryptShare is currently a placeholder and should be replaced with
		// AES-GCM in production.
		encryptedShare := encryptShare(saveDataBytes, password)

		return &keygenResult{
			PublicKey:      pubKeyBytes,
			EncryptedShare: encryptedShare,
		}, nil
	}
}
// handleOutgoingMessage serialises one protocol message produced by the local
// party and writes it to stdout as an "outgoing" JSON message.
//
// The payload is the message's wire bytes, base64 encoded. For non-broadcast
// messages ToParties lists the IDs of the recipients; for broadcasts it is
// left empty. A message that cannot be serialised is dropped, with a
// diagnostic on stderr so that a stalled session can be investigated.
func handleOutgoingMessage(msg tss.Message) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: wire encoding failed: %v\n", err)
		return
	}

	broadcast := msg.IsBroadcast()
	var toParties []string
	if !broadcast {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	data, err := json.Marshal(Message{
		Type:        "outgoing",
		IsBroadcast: broadcast,
		ToParties:   toParties,
		Payload:     base64.StdEncoding.EncodeToString(msgBytes),
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping outgoing message: JSON encoding failed: %v\n", err)
		return
	}
	fmt.Println(string(data))
}
// handleIncomingMessage decodes one "incoming" IPC message and applies it to
// the local party.
//
// The sender is looked up in partyIndexMap by msg.FromPartyIndex and the
// payload is base64-decoded and parsed as a wire message. A message that fails
// any of these steps is dropped with a diagnostic on stderr. The update itself
// runs in its own goroutine; an error it returns is forwarded to errCh unless
// isDuplicateError classifies it as a duplicate.
func handleIncomingMessage(
	msg Message,
	localParty tss.Party,
	partyIndexMap map[int]*tss.PartyID,
	errCh chan error,
) {
	fromParty, ok := partyIndexMap[msg.FromPartyIndex]
	if !ok {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message: unknown party index %d\n", msg.FromPartyIndex)
		return
	}

	payload, err := base64.StdEncoding.DecodeString(msg.Payload)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: invalid base64 payload: %v\n", msg.FromPartyIndex, err)
		return
	}

	parsedMsg, err := tss.ParseWireMessage(payload, fromParty, msg.IsBroadcast)
	if err != nil {
		fmt.Fprintf(os.Stderr, "tss-party: dropping incoming message from party index %d: %v\n", msg.FromPartyIndex, err)
		return
	}

	go func() {
		// Only fatal errors are reported; duplicates are ignored.
		if _, err := localParty.Update(parsedMsg); err != nil && !isDuplicateError(err) {
			// Non-blocking send: if an earlier error is still pending, this
			// one is dropped instead of leaving the goroutine blocked forever.
			select {
			case errCh <- err:
			default:
			}
		}
	}()
}
// isDuplicateError reports whether err looks like a "message seen twice"
// condition, judged by its text containing "duplicate" or "already received".
// A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	return strings.Contains(text, "duplicate") || strings.Contains(text, "already received")
}
// contains reports whether substr occurs within s. An empty substr is
// contained in every string, including the empty string.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
// containsImpl reports whether substr occurs in s by comparing substr against
// every window of s that has the same length, left to right.
func containsImpl(s, substr string) bool {
	width := len(substr)
	// end is the exclusive upper bound of the window s[end-width:end].
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
// encryptShare builds the stored form of a key share: a 32-byte header taken
// from hashPassword(password), followed by data unchanged.
//
// WARNING: this is a placeholder and is NOT secure — data is not encrypted at
// all. TODO: replace with proper AES-256-GCM encryption.
func encryptShare(data []byte, password string) []byte {
	const headerLen = 32

	blob := make([]byte, headerLen+len(data))
	header, body := blob[:headerLen], blob[headerLen:]
	copy(header, hashPassword(password))
	copy(body, data)
	return blob
}
// hashPassword returns a 32-byte slice holding the first 32 bytes of password,
// zero-padded on the right when password is shorter.
//
// WARNING: despite its name this performs no hashing; production code should
// use PBKDF2 or Argon2 instead.
func hashPassword(password string) []byte {
	digest := make([]byte, 32)
	// copy stops at the shorter of the two, which gives both the truncation
	// and the zero padding.
	copy(digest, password)
	return digest
}
// sendProgress writes a "progress" message carrying round and totalRounds to
// stdout as a single JSON line.
func sendProgress(round, totalRounds int) {
	var progress Message
	progress.Type = "progress"
	progress.Round = round
	progress.TotalRounds = totalRounds

	// The marshal error is deliberately ignored, as there is no other channel
	// on which to report it.
	encoded, _ := json.Marshal(progress)
	fmt.Fprintln(os.Stdout, string(encoded))
}
// sendError writes an "error" message with the given text to stdout as a
// single JSON line. It does not terminate the process; callers decide that.
func sendError(errMsg string) {
	var failure Message
	failure.Type = "error"
	failure.Error = errMsg

	// The marshal error is deliberately ignored, as there is no other channel
	// on which to report it.
	encoded, _ := json.Marshal(failure)
	fmt.Fprintln(os.Stdout, string(encoded))
}
// sendResult writes the final "result" message to stdout as a single JSON
// line. publicKey and encryptedShare are base64 encoded (standard alphabet);
// partyIndex is passed through as given.
func sendResult(publicKey, encryptedShare []byte, partyIndex int) {
	enc := base64.StdEncoding

	var outcome Message
	outcome.Type = "result"
	outcome.PublicKey = enc.EncodeToString(publicKey)
	outcome.EncryptedShare = enc.EncodeToString(encryptedShare)
	outcome.PartyIndex = partyIndex

	// The marshal error is deliberately ignored, as there is no other channel
	// on which to report it.
	encoded, _ := json.Marshal(outcome)
	fmt.Fprintln(os.Stdout, string(encoded))
}