rwadurian/backend/mpc-system/services/service-party-android/tsslib/tsslib.go

658 lines
18 KiB
Go

// Package tsslib provides TSS (Threshold Signature Scheme) functionality for Android.
// It is designed to be compiled with gomobile for Android integration via JNI.
//
// Based on the verified tss-party implementation from service-party-app (Electron version).
package tsslib
import (
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
"regexp"
"strconv"
"sync"
"time"
"github.com/bnb-chain/tss-lib/v2/common"
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
"github.com/bnb-chain/tss-lib/v2/tss"
)
var (
	// roundRegex captures the round number embedded in a tss-lib message type
	// name. Capture group 1 holds the digits, e.g. "1" for
	// "binance.tsslib.ecdsa.keygen.KGRound1Message" and "3" for
	// "binance.tsslib.ecdsa.signing.SignRound3Message".
	roundRegex = regexp.MustCompile(`Round(\d+)`)
)
// MessageCallback is the interface through which this package reports TSS
// protocol events. The Android side implements it (bound via gomobile) to
// route messages between parties and to surface progress, errors and logs.
//
// Calls arrive from more than one goroutine. OnLog is called synchronously
// during StartKeygen/StartSign. Outgoing messages, progress updates and
// errors are delivered from background goroutines, possibly concurrently with
// each other. NOTE(review): implementations therefore need to be safe for
// concurrent use; confirm the Android implementation is.
type MessageCallback interface {
	// OnOutgoingMessage is called when TSS needs to send a message to other
	// parties. messageJSON is an object with the fields type ("outgoing"),
	// isBroadcast, toParties (recipient party IDs, omitted for broadcasts)
	// and payload (base64-encoded tss-lib wire bytes).
	OnOutgoingMessage(messageJSON string)
	// OnProgress is called to report protocol progress. round is parsed from
	// the tss-lib message type and is 0 when it cannot be determined.
	// totalRounds is 4 for keygen and 9 for signing. It is also called with
	// round == totalRounds when the final result is collected.
	OnProgress(round, totalRounds int)
	// OnError is called when an error occurs, including session timeout or
	// cancellation.
	OnError(errorMessage string)
	// OnLog is called for debug logging.
	OnLog(message string)
}
// Participant represents a party in the TSS protocol, as decoded from the
// participantsJSON argument of StartKeygen and StartSign.
type Participant struct {
	// PartyID identifies the party. It becomes the tss-lib PartyID.Id and is
	// matched against the partyID argument to find the local party.
	PartyID    string `json:"partyId"`
	// PartyIndex is the party's numeric index. The tss-lib key is derived as
	// PartyIndex+1, so a party must keep the same index between keygen and
	// later signing sessions. SendIncomingMessage identifies senders by it.
	PartyIndex int    `json:"partyIndex"`
}
// tssSession is the state of one keygen or signing run. It holds the tss-lib
// party, the context that bounds its lifetime, and the channels through
// which the protocol goroutines hand their outcome to the Wait* functions.
type tssSession struct {
	// Kind of run and where events are reported.
	isKeygen bool
	callback MessageCallback

	// The tss-lib party, plus a lookup from Participant.PartyIndex to the
	// sorted tss-lib ID used to attribute incoming messages.
	localParty    tss.Party
	partyIndexMap map[int]*tss.PartyID

	// mu is held by SendIncomingMessage while it resolves and parses a message.
	mu sync.Mutex

	// ctx bounds the session; cancel ends it early.
	ctx    context.Context
	cancel context.CancelFunc

	// Outcome channels. errCh carries a fatal protocol error. Exactly one of
	// keygenResultCh / signResultCh is allocated, depending on isKeygen.
	errCh          chan error
	keygenResultCh chan *keygen.LocalPartySaveData
	signResultCh   chan *common.SignatureData
}
// sessionMu guards currentSession; every read and write of it takes this lock.
var sessionMu sync.Mutex

// currentSession is the single keygen or signing session currently running,
// or nil when idle. StartKeygen and StartSign refuse to start while it is set.
var currentSession *tssSession
// StartKeygen initiates a new distributed key generation session.
//
// It is the entry point called from Android via gomobile/JNI. It validates
// the inputs, builds the tss-lib party for this device, starts the protocol
// goroutines and returns. The caller then delivers peer messages with
// SendIncomingMessage and collects the result with WaitForKeygenResult.
//
// Parameters:
//   - partyID: this device's party ID; must appear in participantsJSON.
//   - thresholdT: number of signers that will be required ("T" of T-of-N);
//     must satisfy 1 <= thresholdT <= thresholdN.
//   - thresholdN: total number of parties; must equal the participant count.
//   - participantsJSON: JSON array of Participant objects with unique partyId
//     values and unique, non-negative partyIndex values.
//   - callback: receives outgoing messages, progress, errors and logs; must
//     not be nil.
//   - sessionID, partyIndex, password: accepted for API compatibility but not
//     used by this function.
//
// An error is returned if a session is already active or an input is
// invalid. Failures after startup are reported through callback.OnError. A
// fatal protocol error also cancels the session, so a pending
// WaitForKeygenResult returns promptly instead of waiting for the timeout.
func StartKeygen(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participantsJSON, password string,
	callback MessageCallback,
) error {
	sessionMu.Lock()
	defer sessionMu.Unlock()

	if currentSession != nil {
		return fmt.Errorf("a session is already in progress")
	}
	// A nil callback would otherwise surface later as a panic inside a
	// protocol goroutine, which crashes the host app.
	if callback == nil {
		return fmt.Errorf("callback must not be nil")
	}
	// A T-of-N scheme is only meaningful for 1 <= T <= N; anything else would
	// produce an invalid tss-lib threshold (thresholdT-1) below.
	if thresholdT < 1 || thresholdT > thresholdN {
		return fmt.Errorf("invalid threshold: t=%d, n=%d (need 1 <= t <= n)", thresholdT, thresholdN)
	}

	// Parse participants
	var participants []Participant
	if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
		return fmt.Errorf("failed to parse participants: %w", err)
	}
	if len(participants) != thresholdN {
		return fmt.Errorf("participant count mismatch: got %d, expected %d", len(participants), thresholdN)
	}

	// Create TSS party IDs. The key derivation (partyIndex+1) matches the
	// verified Electron version. Indices must be non-negative so every key is
	// positive. IDs and indices must be unique so the keys and the
	// incoming-message index map are unambiguous.
	tssPartyIDs := make([]*tss.PartyID, len(participants))
	indexByID := make(map[string]int, len(participants))
	seenIndex := make(map[int]bool, len(participants))
	var selfTSSID *tss.PartyID
	for i, p := range participants {
		if p.PartyIndex < 0 {
			return fmt.Errorf("invalid party index %d for party %q", p.PartyIndex, p.PartyID)
		}
		if _, dup := indexByID[p.PartyID]; dup || seenIndex[p.PartyIndex] {
			return fmt.Errorf("duplicate participant: partyId=%q, partyIndex=%d", p.PartyID, p.PartyIndex)
		}
		indexByID[p.PartyID] = p.PartyIndex
		seenIndex[p.PartyIndex] = true

		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs[i] = partyKey
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// All validation is done; only now allocate the session context, so no
	// early return above can leak it.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	session := &tssSession{
		ctx:            ctx,
		cancel:         cancel,
		callback:       callback,
		partyIndexMap:  make(map[int]*tss.PartyID, len(sortedPartyIDs)),
		errCh:          make(chan error, 1),
		keygenResultCh: make(chan *keygen.LocalPartySaveData, 1),
		isKeygen:       true,
	}

	// Map each Participant.PartyIndex to its sorted tss-lib ID so that
	// SendIncomingMessage can attribute incoming messages.
	for _, id := range sortedPartyIDs {
		session.partyIndexMap[indexByID[id.Id]] = id
	}

	// Create peer context and parameters
	// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
	// User says "2-of-3" meaning 2 signers needed, so we pass (thresholdT-1) to TSS-lib
	// For 2-of-3: thresholdT=2, tss-lib threshold=1, signers_needed=1+1=2
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	tssThreshold := thresholdT - 1
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
	callback.OnLog(fmt.Sprintf("[TSS-KEYGEN] NewParameters: partyCount=%d, tssThreshold=%d (from thresholdT=%d, means %d signers needed)",
		len(sortedPartyIDs), tssThreshold, thresholdT, thresholdT))

	// Create channels
	outCh := make(chan tss.Message, thresholdN*10)
	endCh := make(chan *keygen.LocalPartySaveData, 1)

	// Create local party
	localParty := keygen.NewLocalParty(params, outCh, endCh)
	session.localParty = localParty

	// Start the local party. The error send is non-blocking: errCh holds at
	// most one pending error, and once the completion goroutine has returned
	// nobody drains it, so a blocking send could leak this goroutine.
	go func() {
		if err := localParty.Start(); err != nil {
			select {
			case session.errCh <- err:
			default: // a fatal error is already pending
			}
		}
	}()

	// Forward outgoing messages to the callback until the session ends.
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				session.handleOutgoingMessage(msg)
			}
		}
	}()

	// Wait for the outcome and hand it on.
	go func() {
		select {
		case <-ctx.Done():
			callback.OnError("session timeout or cancelled")
		case err := <-session.errCh:
			callback.OnError(fmt.Sprintf("keygen error: %v", err))
			// The protocol cannot recover from a fatal error. Cancel so the
			// outgoing pump stops and WaitForKeygenResult returns now rather
			// than after the 10-minute timeout.
			cancel()
		case saveData := <-endCh:
			session.keygenResultCh <- saveData
		}
	}()

	currentSession = session
	return nil
}
// StartSign initiates a new threshold signing session.
//
// It follows the verified executeSign flow of the Electron version and is
// called from Android via gomobile/JNI. It validates the inputs, unwraps the
// stored key share, builds the tss-lib signing party for this device, starts
// the protocol goroutines and returns. The caller then delivers peer messages
// with SendIncomingMessage and collects the signature with WaitForSignResult.
//
// Parameters:
//   - partyID: this device's party ID; must appear in participantsJSON.
//   - thresholdT: number of signers; must equal the participant count.
//   - participantsJSON: JSON array of Participant objects for the signing
//     subset, using the same partyIndex values as during keygen.
//   - messageHashHex: hex encoding of the 32-byte digest to sign.
//   - shareDataBase64, password: the base64 share blob (as produced in the
//     encryptedShare field by WaitForKeygenResult) and its password.
//   - callback: receives outgoing messages, progress, errors and logs; must
//     not be nil.
//   - sessionID, partyIndex, thresholdN: accepted for API compatibility but
//     not used by this function.
//
// An error is returned if a session is already active or an input is
// invalid. Failures after startup are reported through callback.OnError. A
// fatal protocol error also cancels the session, so a pending
// WaitForSignResult returns promptly instead of waiting for the timeout.
func StartSign(
	sessionID, partyID string,
	partyIndex, thresholdT, thresholdN int,
	participantsJSON, messageHashHex, shareDataBase64, password string,
	callback MessageCallback,
) error {
	sessionMu.Lock()
	defer sessionMu.Unlock()

	if currentSession != nil {
		return fmt.Errorf("a session is already in progress")
	}
	// A nil callback would otherwise surface later as a panic inside a
	// protocol goroutine, which crashes the host app.
	if callback == nil {
		return fmt.Errorf("callback must not be nil")
	}

	// Parse participants
	var participants []Participant
	if err := json.Unmarshal([]byte(participantsJSON), &participants); err != nil {
		return fmt.Errorf("failed to parse participants: %w", err)
	}
	// For signing, the participant count equals threshold T (not N), because
	// only T parties participate in signing.
	if len(participants) != thresholdT {
		return fmt.Errorf("participant count mismatch: got %d, expected %d (threshold T)", len(participants), thresholdT)
	}

	// Decode and decrypt share data
	encryptedShare, err := base64.StdEncoding.DecodeString(shareDataBase64)
	if err != nil {
		return fmt.Errorf("failed to decode share data: %w", err)
	}
	shareBytes, err := decryptShare(encryptedShare, password)
	if err != nil {
		return fmt.Errorf("failed to decrypt share: %w", err)
	}

	// Parse keygen save data
	var keygenData keygen.LocalPartySaveData
	if err := json.Unmarshal(shareBytes, &keygenData); err != nil {
		return fmt.Errorf("failed to parse keygen data: %w", err)
	}

	// Decode message hash
	messageHash, err := hex.DecodeString(messageHashHex)
	if err != nil {
		return fmt.Errorf("failed to decode message hash: %w", err)
	}
	if len(messageHash) != 32 {
		return fmt.Errorf("message hash must be 32 bytes, got %d", len(messageHash))
	}
	msgBigInt := new(big.Int).SetBytes(messageHash)

	// Create TSS party IDs for the signing participants.
	// IMPORTANT: for tss-lib signing the party IDs must be reconstructed exactly
	// as during keygen (key = partyIndex+1), so the T signers map onto their
	// original keygen keys. Indices must be non-negative and IDs/indices
	// unique so keys and the incoming-message index map are unambiguous.
	tssPartyIDs := make([]*tss.PartyID, 0, len(participants))
	indexByID := make(map[string]int, len(participants))
	seenIndex := make(map[int]bool, len(participants))
	var selfTSSID *tss.PartyID
	for _, p := range participants {
		if p.PartyIndex < 0 {
			return fmt.Errorf("invalid party index %d for party %q", p.PartyIndex, p.PartyID)
		}
		if _, dup := indexByID[p.PartyID]; dup || seenIndex[p.PartyIndex] {
			return fmt.Errorf("duplicate participant: partyId=%q, partyIndex=%d", p.PartyID, p.PartyIndex)
		}
		indexByID[p.PartyID] = p.PartyIndex
		seenIndex[p.PartyIndex] = true

		partyKey := tss.NewPartyID(
			p.PartyID,
			fmt.Sprintf("party-%d", p.PartyIndex),
			big.NewInt(int64(p.PartyIndex+1)),
		)
		tssPartyIDs = append(tssPartyIDs, partyKey)
		if p.PartyID == partyID {
			selfTSSID = partyKey
		}
	}
	if selfTSSID == nil {
		return fmt.Errorf("self party not found in participants")
	}

	// Sort party IDs (important for tss-lib)
	sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)

	// Every signer must be one of the keygen parties recorded in the share.
	// Otherwise the subset built below would not correspond to real keys, and
	// the session would fail later with a far less clear error.
	for _, id := range sortedPartyIDs {
		key := id.KeyInt()
		found := false
		for _, k := range keygenData.Ks {
			if k != nil && k.Cmp(key) == 0 {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("signing party %q (key %v) is not part of the stored keygen share", id.Id, key)
		}
	}

	// All validation is done; only now allocate the session context, so no
	// early return above can leak it.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	session := &tssSession{
		ctx:           ctx,
		cancel:        cancel,
		callback:      callback,
		partyIndexMap: make(map[int]*tss.PartyID, len(sortedPartyIDs)),
		errCh:         make(chan error, 1),
		signResultCh:  make(chan *common.SignatureData, 1),
		isKeygen:      false,
	}

	// Map each Participant.PartyIndex to its sorted tss-lib ID so that
	// SendIncomingMessage can attribute incoming messages.
	for _, id := range sortedPartyIDs {
		session.partyIndexMap[indexByID[id.Id]] = id
	}

	// Create peer context and parameters
	// IMPORTANT: TSS-lib threshold convention: threshold=t means (t+1) signers required
	// This MUST match keygen exactly!
	peerCtx := tss.NewPeerContext(sortedPartyIDs)
	tssThreshold := thresholdT - 1
	params := tss.NewParameters(tss.S256(), peerCtx, selfTSSID, len(sortedPartyIDs), tssThreshold)
	callback.OnLog(fmt.Sprintf("[TSS-SIGN] NewParameters: partyCount=%d, tssThreshold=%d (from thresholdT=%d, means %d signers needed)",
		len(sortedPartyIDs), tssThreshold, thresholdT, thresholdT))
	callback.OnLog(fmt.Sprintf("[TSS-SIGN] Original keygenData has %d parties (Ks length)", len(keygenData.Ks)))
	callback.OnLog(fmt.Sprintf("[TSS-SIGN] Building subset for %d signing parties", len(sortedPartyIDs)))

	// CRITICAL: Build a subset of the keygen save data for the current signing parties.
	// This is required when signing with a subset of the original keygen participants.
	subsetKeygenData := keygen.BuildLocalSaveDataSubset(keygenData, sortedPartyIDs)
	callback.OnLog(fmt.Sprintf("[TSS-SIGN] Subset keygenData has %d parties (Ks length)", len(subsetKeygenData.Ks)))

	// Create channels
	outCh := make(chan tss.Message, thresholdT*10)
	endCh := make(chan *common.SignatureData, 1)

	// Create local party for signing with the SUBSET keygen data
	localParty := signing.NewLocalParty(msgBigInt, params, subsetKeygenData, outCh, endCh)
	session.localParty = localParty

	// Start the local party. The error send is non-blocking: errCh holds at
	// most one pending error, and once the completion goroutine has returned
	// nobody drains it, so a blocking send could leak this goroutine.
	go func() {
		if err := localParty.Start(); err != nil {
			select {
			case session.errCh <- err:
			default: // a fatal error is already pending
			}
		}
	}()

	// Forward outgoing messages to the callback until the session ends.
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-outCh:
				if !ok {
					return
				}
				session.handleOutgoingMessage(msg)
			}
		}
	}()

	// Wait for the outcome and hand it on.
	go func() {
		select {
		case <-ctx.Done():
			callback.OnError("session timeout or cancelled")
		case err := <-session.errCh:
			callback.OnError(fmt.Sprintf("sign error: %v", err))
			// The protocol cannot recover from a fatal error. Cancel so the
			// outgoing pump stops and WaitForSignResult returns now rather
			// than after the 10-minute timeout.
			cancel()
		case sigData := <-endCh:
			session.signResultCh <- sigData
		}
	}()

	currentSession = session
	return nil
}
// SendIncomingMessage delivers a protocol message received from another party
// to the active session.
//
// fromPartyIndex is the sender's Participant.PartyIndex. isBroadcast states how
// the message was routed. payloadBase64 is the base64-encoded tss-lib wire
// message, i.e. the payload field emitted by handleOutgoingMessage on the
// sender.
//
// The message is resolved and parsed synchronously, and those errors are
// returned. It is then applied to the local party on a separate goroutine.
// Errors from that step, other than duplicate-message errors, are recorded as
// the session's fatal error.
func SendIncomingMessage(fromPartyIndex int, isBroadcast bool, payloadBase64 string) error {
	sessionMu.Lock()
	session := currentSession
	sessionMu.Unlock()

	if session == nil {
		return fmt.Errorf("no active session")
	}

	session.mu.Lock()
	defer session.mu.Unlock()

	fromParty, ok := session.partyIndexMap[fromPartyIndex]
	if !ok {
		return fmt.Errorf("unknown party index: %d", fromPartyIndex)
	}

	payload, err := base64.StdEncoding.DecodeString(payloadBase64)
	if err != nil {
		return fmt.Errorf("failed to decode payload: %w", err)
	}

	parsedMsg, err := tss.ParseWireMessage(payload, fromParty, isBroadcast)
	if err != nil {
		return fmt.Errorf("failed to parse message: %w", err)
	}

	// Apply the message on its own goroutine so this call returns without
	// waiting for tss-lib to process it.
	go func() {
		_, updateErr := session.localParty.Update(parsedMsg)
		// Duplicate messages are not fatal.
		if updateErr == nil || isDuplicateError(updateErr) {
			return
		}
		// Never block here. errCh buffers a single error and is no longer
		// drained once the session's completion goroutine has returned, so a
		// blocking send would leak this goroutine forever.
		select {
		case session.errCh <- updateErr:
		default: // a fatal error is already pending
		}
	}()
	return nil
}
// WaitForKeygenResult blocks until the active keygen session finishes and
// returns its result as JSON: {"publicKey": ..., "encryptedShare": ...}.
//
// publicKey is the base64 of the 33-byte compressed public key: a 0x02/0x03
// prefix chosen by the parity of Y, followed by the 32-byte big-endian X
// coordinate. encryptedShare is the base64 of the tss-lib save data,
// serialized to JSON and wrapped with encryptShare using password.
//
// Whatever the outcome, the session is released before returning, so a new
// session can be started. If the session's context ends first (timeout or
// cancellation), the context's error is returned.
func WaitForKeygenResult(password string) (string, error) {
	sessionMu.Lock()
	session := currentSession
	sessionMu.Unlock()

	if session == nil {
		return "", fmt.Errorf("no active session")
	}
	if !session.isKeygen {
		return "", fmt.Errorf("current session is not a keygen session")
	}

	// Release the session on every exit path: success, timeout, cancellation
	// or a serialization failure. Otherwise the library would keep refusing
	// new sessions until CancelSession is called. The identity check avoids
	// clearing a newer session started after a CancelSession.
	defer func() {
		session.cancel()
		sessionMu.Lock()
		if currentSession == session {
			currentSession = nil
		}
		sessionMu.Unlock()
	}()

	// Track progress - GG20 keygen has 4 rounds
	totalRounds := 4

	select {
	case <-session.ctx.Done():
		return "", session.ctx.Err()
	case saveData := <-session.keygenResultCh:
		// Keygen completed successfully
		session.callback.OnProgress(totalRounds, totalRounds)

		// Compressed public key - same as Electron version
		pubKey := saveData.ECDSAPub.ToECDSAPubKey()
		pubKeyBytes := make([]byte, 33)
		pubKeyBytes[0] = 0x02 + byte(pubKey.Y.Bit(0))
		xBytes := pubKey.X.Bytes()
		copy(pubKeyBytes[33-len(xBytes):], xBytes)

		// Serialize and wrap save data
		saveDataBytes, err := json.Marshal(saveData)
		if err != nil {
			return "", fmt.Errorf("failed to serialize save data: %w", err)
		}

		// Encrypt with password (same as Electron version)
		encryptedShare := encryptShare(saveDataBytes, password)

		result := struct {
			PublicKey      string `json:"publicKey"`
			EncryptedShare string `json:"encryptedShare"`
		}{
			PublicKey:      base64.StdEncoding.EncodeToString(pubKeyBytes),
			EncryptedShare: base64.StdEncoding.EncodeToString(encryptedShare),
		}
		resultJSON, err := json.Marshal(result)
		if err != nil {
			return "", fmt.Errorf("failed to serialize keygen result: %w", err)
		}
		return string(resultJSON), nil
	}
}
// WaitForSignResult blocks until the active signing session finishes and
// returns its result as JSON: {"signature": ..., "recoveryId": ...}.
//
// signature is the base64 of the 65-byte r || s || v form used for EVM
// transaction signing. r and s are each left-padded to 32 bytes, and v is
// the recovery id, which is also returned separately as recoveryId.
//
// Whatever the outcome, the session is released before returning, so a new
// session can be started. If the session's context ends first (timeout or
// cancellation), the context's error is returned. Malformed signature data
// yields an error instead of a panic.
func WaitForSignResult() (string, error) {
	sessionMu.Lock()
	session := currentSession
	sessionMu.Unlock()

	if session == nil {
		return "", fmt.Errorf("no active session")
	}
	if session.isKeygen {
		return "", fmt.Errorf("current session is not a sign session")
	}

	// Release the session on every exit path: success, timeout, cancellation
	// or malformed output. Otherwise the library would keep refusing new
	// sessions until CancelSession is called. The identity check avoids
	// clearing a newer session started after a CancelSession.
	defer func() {
		session.cancel()
		sessionMu.Lock()
		if currentSession == session {
			currentSession = nil
		}
		sessionMu.Unlock()
	}()

	// Track progress - GG20 signing has 9 rounds
	totalRounds := 9

	select {
	case <-session.ctx.Done():
		return "", session.ctx.Err()
	case sigData := <-session.signResultCh:
		// Signing completed successfully
		session.callback.OnProgress(totalRounds, totalRounds)
		if sigData == nil {
			return "", fmt.Errorf("signing finished without signature data")
		}
		return packSignatureResult(sigData.R, sigData.S, sigData.SignatureRecovery)
	}
}

// packSignatureResult builds the WaitForSignResult JSON from the raw r, s and
// recovery bytes. It rejects components that do not fit the 65-byte layout
// instead of panicking on them.
func packSignatureResult(rBytes, sBytes, recovery []byte) (string, error) {
	if len(rBytes) > 32 || len(sBytes) > 32 {
		return "", fmt.Errorf("invalid signature component length: r=%d, s=%d bytes (max 32)", len(rBytes), len(sBytes))
	}
	if len(recovery) == 0 {
		return "", fmt.Errorf("signature is missing the recovery id")
	}
	recoveryID := int(recovery[0])

	// r || s || v, with r and s each left-padded to 32 bytes.
	signatureWithV := make([]byte, 65)
	copy(signatureWithV[32-len(rBytes):32], rBytes)
	copy(signatureWithV[64-len(sBytes):64], sBytes)
	signatureWithV[64] = byte(recoveryID)

	result := struct {
		Signature  string `json:"signature"`
		RecoveryID int    `json:"recoveryId"`
	}{
		Signature:  base64.StdEncoding.EncodeToString(signatureWithV),
		RecoveryID: recoveryID,
	}
	resultJSON, err := json.Marshal(result)
	if err != nil {
		return "", fmt.Errorf("failed to serialize signature result: %w", err)
	}
	return string(resultJSON), nil
}
// CancelSession aborts the active keygen or signing session, if any, and frees
// the slot so a new session can be started.
//
// Cancelling the session's context unblocks a pending Wait* call. If the
// session is still running, its completion goroutine also reports
// "session timeout or cancelled" through OnError. This is a no-op when no
// session is active.
func CancelSession() {
	sessionMu.Lock()
	defer sessionMu.Unlock()

	active := currentSession
	if active == nil {
		return
	}
	active.cancel()
	currentSession = nil
}
// extractRoundFromMessageType returns the round number embedded in a tss-lib
// message type string. For example, "binance.tsslib.ecdsa.keygen.KGRound2Message1"
// yields 2.
//
// It returns 0 when no round number can be parsed. In this file the value is
// only used for progress reporting, so the fallback does not affect the
// protocol.
func extractRoundFromMessageType(msgType string) int {
	groups := roundRegex.FindStringSubmatch(msgType)
	if len(groups) < 2 {
		return 0
	}
	round, convErr := strconv.Atoi(groups[1])
	if convErr != nil {
		return 0
	}
	return round
}
// handleOutgoingMessage serializes one protocol message produced by the local
// party and passes it to callback.OnOutgoingMessage. It then reports progress
// using the round number parsed from the message type.
//
// The JSON handed to OnOutgoingMessage has the form
// {"type":"outgoing","isBroadcast":bool,"toParties":[partyId...],"payload":base64}.
// toParties is omitted when empty, which is always the case for broadcasts.
//
// A message that cannot be serialized can never be delivered, so the protocol
// could not complete. Such a failure is recorded as the session's fatal error
// instead of being silently dropped.
func (s *tssSession) handleOutgoingMessage(msg tss.Message) {
	msgBytes, _, err := msg.WireBytes()
	if err != nil {
		// Non-blocking: errCh buffers one error and may no longer be drained.
		select {
		case s.errCh <- fmt.Errorf("failed to serialize outgoing %s message: %w", msg.Type(), err):
		default: // a fatal error is already pending
		}
		return
	}

	var toParties []string
	if !msg.IsBroadcast() {
		for _, to := range msg.GetTo() {
			toParties = append(toParties, to.Id)
		}
	}

	outMsg := struct {
		Type        string   `json:"type"`
		IsBroadcast bool     `json:"isBroadcast"`
		ToParties   []string `json:"toParties,omitempty"`
		Payload     string   `json:"payload"`
	}{
		Type:        "outgoing",
		IsBroadcast: msg.IsBroadcast(),
		ToParties:   toParties,
		Payload:     base64.StdEncoding.EncodeToString(msgBytes),
	}
	// Marshal cannot fail here: the struct holds only strings, a bool and a
	// string slice.
	data, _ := json.Marshal(outMsg)
	s.callback.OnOutgoingMessage(string(data))

	// Extract current round from message type and send progress update
	totalRounds := 4 // GG20 keygen has 4 rounds
	if !s.isKeygen {
		totalRounds = 9 // GG20 signing has 9 rounds
	}
	currentRound := extractRoundFromMessageType(msg.Type())
	s.callback.OnProgress(currentRound, totalRounds)
}
// isDuplicateError reports whether err looks like tss-lib rejecting a message
// it has already seen. The check is a substring match on the error text for
// "duplicate" or "already received". Callers treat such errors as non-fatal
// (presumably they come from a resent message). A nil error is never a
// duplicate.
func isDuplicateError(err error) bool {
	if err == nil {
		return false
	}
	switch text := err.Error(); {
	case contains(text, "duplicate"), contains(text, "already received"):
		return true
	default:
		return false
	}
}
// contains reports whether substr occurs anywhere in s, comparing bytes.
// An empty substr is contained in every string, including "".
func contains(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// encryptShare wraps share data for storage in the format shared with the
// Electron version: the 32-byte header from hashPassword(password), followed
// by the data bytes unchanged.
//
// SECURITY: this is NOT encryption. The share is stored in the clear, and the
// header is the password's own bytes (see hashPassword). Anyone holding the
// blob can therefore read both the key share and up to 32 bytes of the
// password.
// TODO: replace with AES-256-GCM keyed by a proper KDF (PBKDF2/Argon2). That
// change requires a share-format migration coordinated with the Electron app.
func encryptShare(data []byte, password string) []byte {
	const headerLen = 32
	blob := make([]byte, headerLen+len(data))
	copy(blob[:headerLen], hashPassword(password))
	copy(blob[headerLen:], data)
	return blob
}
// decryptShare reverses encryptShare (same format as the Electron version).
// It checks that the first 32 bytes equal hashPassword(password) and returns
// the remaining bytes. The returned slice aliases encryptedData.
//
// It returns an error if the blob is shorter than the 32-byte header, or if
// the header does not match, which is reported as an incorrect password.
// SECURITY: the payload is plaintext and the header is the password itself,
// so this check provides no confidentiality. The byte comparison is also not
// constant-time.
func decryptShare(encryptedData []byte, password string) ([]byte, error) {
	const headerLen = 32
	if len(encryptedData) < headerLen {
		return nil, fmt.Errorf("encrypted data too short")
	}

	header, payload := encryptedData[:headerLen], encryptedData[headerLen:]
	want := hashPassword(password)
	for i := range header {
		if header[i] != want[i] {
			return nil, fmt.Errorf("incorrect password")
		}
	}
	return payload, nil
}
// hashPassword derives the 32-byte header used by encryptShare and
// decryptShare (same implementation as the Electron version).
//
// Despite its name this is not a hash. The header is the password's first
// 32 bytes, zero-padded, which has two consequences:
//   - passwords that share their first 32 bytes are indistinguishable;
//   - trailing NUL bytes are ignored ("pw" and "pw\x00" give the same header).
//
// It is kept as-is for compatibility with existing stored shares.
// TODO: use PBKDF2 or Argon2 once the share format can change.
func hashPassword(password string) []byte {
	hash := make([]byte, 32)
	copy(hash, password) // copies min(32, len(password)) bytes
	return hash
}