407 lines
9.3 KiB
Go
407 lines
9.3 KiB
Go
// Package main provides the TSS party subprocess for the Electron app.
//
// This program handles TSS (Threshold Signature Scheme) protocol execution.
// It communicates with the parent Electron process via stdin/stdout using
// JSON messages.
package main
|
|
|
|
import (
	"bufio"
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"flag"
	"fmt"
	"math/big"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
	"github.com/bnb-chain/tss-lib/v2/tss"
)
|
|
|
|
// Message is the JSON envelope used for IPC with the parent process: one
// object per line. Type selects which of the optional fields carry data;
// every optional field is dropped from the encoding when it holds its zero
// value.
//
// Type values used in this file: "incoming" (read from stdin), and
// "outgoing", "progress", "error", "result" (written to stdout).
type Message struct {
	// Type is the message discriminator and is always encoded.
	Type string `json:"type"`

	// Routing and body of a protocol message.
	IsBroadcast bool     `json:"isBroadcast,omitempty"`
	ToParties   []string `json:"toParties,omitempty"`
	Payload     string   `json:"payload,omitempty"` // base64 encoded

	// Final keygen output ("result").
	PublicKey      string `json:"publicKey,omitempty"`      // base64 encoded
	EncryptedShare string `json:"encryptedShare,omitempty"` // base64 encoded
	PartyIndex     int    `json:"partyIndex,omitempty"`

	// Progress reporting ("progress").
	Round       int `json:"round,omitempty"`
	TotalRounds int `json:"totalRounds,omitempty"`

	// FromPartyIndex identifies the sender of an "incoming" message.
	FromPartyIndex int `json:"fromPartyIndex,omitempty"`

	// Error is the human-readable failure text of an "error" message.
	Error string `json:"error,omitempty"`
}
|
|
|
|
// Participant is one entry of the JSON array passed via the -participants
// flag: the party's string identifier together with its numeric index.
type Participant struct {
	// PartyID is matched against the -party-id flag to locate the local party.
	PartyID string `json:"partyId"`
	// PartyIndex is the index peers use to refer to this party in incoming
	// messages (Message.FromPartyIndex).
	PartyIndex int `json:"partyIndex"`
}
|
|
|
|
func main() {
|
|
// Parse command
|
|
if len(os.Args) < 2 {
|
|
fmt.Fprintln(os.Stderr, "Usage: tss-party <command> [options]")
|
|
os.Exit(1)
|
|
}
|
|
|
|
command := os.Args[1]
|
|
|
|
switch command {
|
|
case "keygen":
|
|
runKeygen()
|
|
case "sign":
|
|
runSign()
|
|
default:
|
|
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func runKeygen() {
|
|
// Parse keygen flags
|
|
fs := flag.NewFlagSet("keygen", flag.ExitOnError)
|
|
sessionID := fs.String("session-id", "", "Session ID")
|
|
partyID := fs.String("party-id", "", "Party ID")
|
|
partyIndex := fs.Int("party-index", 0, "Party index (0-based)")
|
|
thresholdT := fs.Int("threshold-t", 0, "Threshold T")
|
|
thresholdN := fs.Int("threshold-n", 0, "Threshold N")
|
|
participantsJSON := fs.String("participants", "[]", "Participants JSON array")
|
|
password := fs.String("password", "", "Encryption password for share")
|
|
|
|
if err := fs.Parse(os.Args[2:]); err != nil {
|
|
sendError(fmt.Sprintf("Failed to parse flags: %v", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Validate required fields
|
|
if *sessionID == "" || *partyID == "" || *thresholdT == 0 || *thresholdN == 0 {
|
|
sendError("Missing required parameters")
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Parse participants
|
|
var participants []Participant
|
|
if err := json.Unmarshal([]byte(*participantsJSON), &participants); err != nil {
|
|
sendError(fmt.Sprintf("Failed to parse participants: %v", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
if len(participants) != *thresholdN {
|
|
sendError(fmt.Sprintf("Participant count mismatch: got %d, expected %d", len(participants), *thresholdN))
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Run keygen protocol
|
|
result, err := executeKeygen(
|
|
*sessionID,
|
|
*partyID,
|
|
*partyIndex,
|
|
*thresholdT,
|
|
*thresholdN,
|
|
participants,
|
|
*password,
|
|
)
|
|
|
|
if err != nil {
|
|
sendError(fmt.Sprintf("Keygen failed: %v", err))
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Send result
|
|
sendResult(result.PublicKey, result.EncryptedShare, *partyIndex)
|
|
}
|
|
|
|
func runSign() {
|
|
// TODO: Implement signing
|
|
sendError("Signing not implemented yet")
|
|
os.Exit(1)
|
|
}
|
|
|
|
// keygenResult carries the output of a completed keygen run: the group
// public key and the local party's encrypted share, both as raw bytes.
type keygenResult struct {
	PublicKey, EncryptedShare []byte
}
|
|
|
|
func executeKeygen(
|
|
sessionID, partyID string,
|
|
partyIndex, thresholdT, thresholdN int,
|
|
participants []Participant,
|
|
password string,
|
|
) (*keygenResult, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
defer cancel()
|
|
|
|
// Handle signals for graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
go func() {
|
|
<-sigChan
|
|
cancel()
|
|
}()
|
|
|
|
// Create TSS party IDs
|
|
tssPartyIDs := make([]*tss.PartyID, len(participants))
|
|
var selfTSSID *tss.PartyID
|
|
|
|
for i, p := range participants {
|
|
partyKey := tss.NewPartyID(
|
|
p.PartyID,
|
|
fmt.Sprintf("party-%d", p.PartyIndex),
|
|
tss.S256(),
|
|
)
|
|
tssPartyIDs[i] = partyKey
|
|
if p.PartyID == partyID {
|
|
selfTSSID = partyKey
|
|
}
|
|
}
|
|
|
|
if selfTSSID == nil {
|
|
return nil, fmt.Errorf("self party not found in participants")
|
|
}
|
|
|
|
// Sort party IDs
|
|
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
|
|
|
|
// Create peer context and parameters
|
|
peerCtx := tss.NewPeerContext(sortedPartyIDs)
|
|
params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), thresholdT)
|
|
|
|
// Create channels
|
|
outCh := make(chan tss.Message, thresholdN*10)
|
|
endCh := make(chan *keygen.LocalPartySaveData, 1)
|
|
errCh := make(chan error, 1)
|
|
|
|
// Create local party
|
|
localParty := keygen.NewLocalParty(params, outCh, endCh)
|
|
|
|
// Build party index map for incoming messages
|
|
partyIndexMap := make(map[int]*tss.PartyID)
|
|
for i, p := range sortedPartyIDs {
|
|
for _, orig := range participants {
|
|
if orig.PartyID == p.Id {
|
|
partyIndexMap[orig.PartyIndex] = p
|
|
break
|
|
}
|
|
}
|
|
_ = i
|
|
}
|
|
|
|
// Start the local party
|
|
go func() {
|
|
if err := localParty.Start(); err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
|
|
// Handle outgoing messages
|
|
var outWg sync.WaitGroup
|
|
outWg.Add(1)
|
|
go func() {
|
|
defer outWg.Done()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-outCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
handleOutgoingMessage(msg)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Handle incoming messages from stdin
|
|
var inWg sync.WaitGroup
|
|
inWg.Add(1)
|
|
go func() {
|
|
defer inWg.Done()
|
|
scanner := bufio.NewScanner(os.Stdin)
|
|
for scanner.Scan() {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
var msg Message
|
|
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
|
continue
|
|
}
|
|
|
|
if msg.Type == "incoming" {
|
|
handleIncomingMessage(msg, localParty, partyIndexMap, errCh)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Track progress
|
|
totalRounds := 4 // GG20 keygen has 4 rounds
|
|
currentRound := 0
|
|
|
|
// Wait for completion
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case err := <-errCh:
|
|
return nil, err
|
|
case saveData := <-endCh:
|
|
// Keygen completed successfully
|
|
sendProgress(totalRounds, totalRounds)
|
|
|
|
// Get public key
|
|
pubKey := saveData.ECDSAPub.ToECDSAPubKey()
|
|
pubKeyBytes := make([]byte, 33)
|
|
pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
|
|
xBytes := pubKey.X.Bytes()
|
|
copy(pubKeyBytes[33-len(xBytes):], xBytes)
|
|
|
|
// Serialize and encrypt save data
|
|
saveDataBytes, err := json.Marshal(saveData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize save data: %w", err)
|
|
}
|
|
|
|
// Encrypt with password (simple XOR for now - should use AES-GCM in production)
|
|
encryptedShare := encryptShare(saveDataBytes, password)
|
|
|
|
return &keygenResult{
|
|
PublicKey: pubKeyBytes,
|
|
EncryptedShare: encryptedShare,
|
|
}, nil
|
|
}
|
|
|
|
_ = currentRound
|
|
}
|
|
|
|
func handleOutgoingMessage(msg tss.Message) {
|
|
msgBytes, _, err := msg.WireBytes()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var toParties []string
|
|
if !msg.IsBroadcast() {
|
|
for _, to := range msg.GetTo() {
|
|
toParties = append(toParties, to.Id)
|
|
}
|
|
}
|
|
|
|
outMsg := Message{
|
|
Type: "outgoing",
|
|
IsBroadcast: msg.IsBroadcast(),
|
|
ToParties: toParties,
|
|
Payload: base64.StdEncoding.EncodeToString(msgBytes),
|
|
}
|
|
|
|
data, _ := json.Marshal(outMsg)
|
|
fmt.Println(string(data))
|
|
}
|
|
|
|
func handleIncomingMessage(
|
|
msg Message,
|
|
localParty tss.Party,
|
|
partyIndexMap map[int]*tss.PartyID,
|
|
errCh chan error,
|
|
) {
|
|
fromParty, ok := partyIndexMap[msg.FromPartyIndex]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
payload, err := base64.StdEncoding.DecodeString(msg.Payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
parsedMsg, err := tss.ParseWireMessage(payload, fromParty, msg.IsBroadcast)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
_, err := localParty.Update(parsedMsg)
|
|
if err != nil {
|
|
// Only send fatal errors
|
|
if !isDuplicateError(err) {
|
|
errCh <- err
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// isDuplicateError reports whether err describes a message that was already
// received, recognised by the substrings "duplicate" or "already received"
// in its text (case-sensitive). A nil error is never a duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	return strings.Contains(text, "duplicate") || strings.Contains(text, "already received")
}
|
|
|
|
// contains reports whether substr occurs within s. An empty substr is
// contained in every string, including the empty string.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
|
|
// containsImpl scans s left to right and reports whether any window of
// len(substr) bytes equals substr. It returns false when substr is longer
// than s, and true for an empty substr.
func containsImpl(s, substr string) bool {
	width := len(substr)
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
|
|
|
|
// Layout parameters of an encrypted share.
const (
	// shareMagic identifies the format and its version.
	shareMagic = "TSS1"
	// shareKDFIterations is the PBKDF2 iteration count; it is stored in the
	// header so it can be raised later without breaking old shares.
	shareKDFIterations = 200000
	// shareSaltLen is the length in bytes of the random KDF salt.
	shareSaltLen = 16
)

// encryptShare encrypts data with AES-256-GCM under a key derived from
// password and returns the self-describing ciphertext.
//
// Output layout:
//
//	magic "TSS1" (4) | PBKDF2 iterations, big-endian uint32 (4) |
//	salt (16) | GCM nonce (12) | ciphertext with 16-byte tag
//
// The first 24 bytes (magic, iterations, salt) are bound to the ciphertext
// as GCM additional data. Salt and nonce are freshly random on every call,
// so encrypting the same input twice gives different output. To decrypt,
// derive the key with deriveShareKey from the stored salt and iteration
// count and open the GCM ciphertext with the header as additional data.
//
// The function panics if the system random source fails, because no safe
// output can be produced in that case.
func encryptShare(data []byte, password string) []byte {
	header := make([]byte, len(shareMagic)+4+shareSaltLen)
	copy(header, shareMagic)
	binary.BigEndian.PutUint32(header[len(shareMagic):], shareKDFIterations)
	salt := header[len(shareMagic)+4:]
	if _, err := rand.Read(salt); err != nil {
		panic("tss-party: reading random salt: " + err.Error())
	}

	block, err := aes.NewCipher(deriveShareKey(password, salt, shareKDFIterations))
	if err != nil {
		panic("tss-party: creating AES cipher: " + err.Error())
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		panic("tss-party: creating GCM: " + err.Error())
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		panic("tss-party: reading random nonce: " + err.Error())
	}

	out := make([]byte, 0, len(header)+len(nonce)+len(data)+gcm.Overhead())
	out = append(out, header...)
	out = append(out, nonce...)
	return gcm.Seal(out, nonce, data, header)
}

// deriveShareKey derives a 32-byte key from password and salt using
// PBKDF2-HMAC-SHA256 with the given iteration count. Only one output block
// is needed because the key length equals the SHA-256 digest length.
func deriveShareKey(password string, salt []byte, iterations int) []byte {
	mac := hmac.New(sha256.New, []byte(password))
	mac.Write(salt)
	mac.Write([]byte{0, 0, 0, 1}) // PBKDF2 block index 1, big-endian
	u := mac.Sum(nil)

	key := make([]byte, len(u))
	copy(key, u)
	for i := 1; i < iterations; i++ {
		mac.Reset()
		mac.Write(u)
		u = mac.Sum(u[:0])
		for j := range key {
			key[j] ^= u[j]
		}
	}
	return key
}
|
|
|
|
// hashPassword returns the 32-byte SHA-256 digest of password.
//
// The digest is deterministic and unsalted. It never exposes the password
// bytes themselves, and passwords that differ only after the 32nd byte give
// different results. A plain SHA-256 is fast to brute-force, so use a salted,
// iterated KDF (PBKDF2, Argon2) wherever the result protects stored data.
func hashPassword(password string) []byte {
	sum := sha256.Sum256([]byte(password))
	return sum[:]
}
|
|
|
|
func sendProgress(round, totalRounds int) {
|
|
msg := Message{
|
|
Type: "progress",
|
|
Round: round,
|
|
TotalRounds: totalRounds,
|
|
}
|
|
data, _ := json.Marshal(msg)
|
|
fmt.Println(string(data))
|
|
}
|
|
|
|
func sendError(errMsg string) {
|
|
msg := Message{
|
|
Type: "error",
|
|
Error: errMsg,
|
|
}
|
|
data, _ := json.Marshal(msg)
|
|
fmt.Println(string(data))
|
|
}
|
|
|
|
func sendResult(publicKey, encryptedShare []byte, partyIndex int) {
|
|
msg := Message{
|
|
Type: "result",
|
|
PublicKey: base64.StdEncoding.EncodeToString(publicKey),
|
|
EncryptedShare: base64.StdEncoding.EncodeToString(encryptedShare),
|
|
PartyIndex: partyIndex,
|
|
}
|
|
data, _ := json.Marshal(msg)
|
|
fmt.Println(string(data))
|
|
}
|