fix: add keygen index to sorted index mapping for signing session

When signing with a subset of parties (e.g., party-1 and party-3 in 2-of-3),
the TSS library creates a sorted array of party IDs. Messages contain the
original keygen party index, but we need to map it to the sorted array index.

This fixes the 'invalid FromPartyIndex' error when signing with non-consecutive
party indices.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
hailin 2025-12-06 11:04:19 -08:00
parent f769c7eebf
commit 9fc41cfa53
1 changed files with 52 additions and 19 deletions

View File

@ -66,6 +66,9 @@ type SigningSession struct {
msgHandler MessageHandler msgHandler MessageHandler
mu sync.Mutex mu sync.Mutex
started bool started bool
// keygenIndexToSortedIndex maps keygen party index to sorted array index
// This is needed because TSS messages use keygen index, but tssPartyIDs is sorted
keygenIndexToSortedIndex map[int]int
} }
// NewSigningSession creates a new signing session // NewSigningSession creates a new signing session
@ -112,6 +115,22 @@ func NewSigningSession(
// Sort party IDs // Sort party IDs
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs) sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
// Build mapping from keygen index to sorted array index
// The sorted array is ordered by big.Int key (PartyIndex+1)
keygenIndexToSortedIndex := make(map[int]int)
for sortedIdx, partyID := range sortedPartyIDs {
// Find the original keygen index for this party
for _, p := range allParties {
if p.PartyID == partyID.Id {
keygenIndexToSortedIndex[p.PartyIndex] = sortedIdx
break
}
}
}
fmt.Printf("[TSS-SIGN] Built keygen index to sorted index mapping: %v party_id=%s\n",
keygenIndexToSortedIndex, selfParty.PartyID)
// Create peer context and parameters // Create peer context and parameters
// IMPORTANT: Use TotalParties from keygen, not len(sortedPartyIDs) which is current signers // IMPORTANT: Use TotalParties from keygen, not len(sortedPartyIDs) which is current signers
// For 2-of-3: threshold=2, TotalParties=3, but only 2 parties might participate in signing // For 2-of-3: threshold=2, TotalParties=3, but only 2 parties might participate in signing
@ -122,18 +141,19 @@ func NewSigningSession(
msgHash := new(big.Int).SetBytes(messageHash) msgHash := new(big.Int).SetBytes(messageHash)
return &SigningSession{ return &SigningSession{
config: config, config: config,
selfParty: selfParty, selfParty: selfParty,
allParties: allParties, allParties: allParties,
messageHash: msgHash, messageHash: msgHash,
saveData: &saveData, saveData: &saveData,
tssPartyIDs: sortedPartyIDs, tssPartyIDs: sortedPartyIDs,
selfTSSID: selfTSSID, selfTSSID: selfTSSID,
params: params, params: params,
outCh: make(chan tss.Message, config.TotalSigners*10), outCh: make(chan tss.Message, config.TotalSigners*10),
endCh: make(chan *common.SignatureData, 1), endCh: make(chan *common.SignatureData, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
msgHandler: msgHandler, msgHandler: msgHandler,
keygenIndexToSortedIndex: keygenIndexToSortedIndex,
}, nil }, nil
} }
@ -232,18 +252,31 @@ func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
return return
} }
fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_index=%d is_broadcast=%v msg_len=%d\n", fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_keygen_index=%d is_broadcast=%v msg_len=%d\n",
s.selfParty.PartyID, msg.FromPartyIndex, msg.IsBroadcast, len(msg.MsgBytes)) s.selfParty.PartyID, msg.FromPartyIndex, msg.IsBroadcast, len(msg.MsgBytes))
// Check if FromPartyIndex is valid // Map keygen index to sorted array index
if msg.FromPartyIndex < 0 || msg.FromPartyIndex >= len(s.tssPartyIDs) { // msg.FromPartyIndex is the original keygen party index (e.g., 0, 1, 2)
fmt.Printf("[TSS-SIGN] ERROR: invalid FromPartyIndex=%d, len(tssPartyIDs)=%d party_id=%s\n", // We need the sorted array index (e.g., 0, 1 for a 2-party signing session)
msg.FromPartyIndex, len(s.tssPartyIDs), s.selfParty.PartyID) sortedIndex, exists := s.keygenIndexToSortedIndex[msg.FromPartyIndex]
if !exists {
fmt.Printf("[TSS-SIGN] ERROR: unknown keygen index=%d, mapping=%v party_id=%s\n",
msg.FromPartyIndex, s.keygenIndexToSortedIndex, s.selfParty.PartyID)
continue continue
} }
// Parse the message fmt.Printf("[TSS-SIGN] mapped keygen_index=%d to sorted_index=%d party_id=%s\n",
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast) msg.FromPartyIndex, sortedIndex, s.selfParty.PartyID)
// Check if sorted index is valid
if sortedIndex < 0 || sortedIndex >= len(s.tssPartyIDs) {
fmt.Printf("[TSS-SIGN] ERROR: invalid sortedIndex=%d, len(tssPartyIDs)=%d party_id=%s\n",
sortedIndex, len(s.tssPartyIDs), s.selfParty.PartyID)
continue
}
// Parse the message using the sorted index
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[sortedIndex], msg.IsBroadcast)
if err != nil { if err != nil {
fmt.Printf("[TSS-SIGN] ERROR: failed to parse wire message party_id=%s from_index=%d error=%v\n", fmt.Printf("[TSS-SIGN] ERROR: failed to parse wire message party_id=%s from_index=%d error=%v\n",
s.selfParty.PartyID, msg.FromPartyIndex, err) s.selfParty.PartyID, msg.FromPartyIndex, err)