fix: add keygen index to sorted index mapping for signing session

When signing with a subset of parties (e.g., party-1 and party-3 in 2-of-3),
the TSS library creates a sorted array of party IDs. Messages contain the
original keygen party index, but we need to map it to the sorted array index.

This fixes the 'invalid FromPartyIndex' error when signing with non-consecutive
party indices.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
hailin 2025-12-06 11:04:19 -08:00
parent f769c7eebf
commit 9fc41cfa53
1 changed files with 52 additions and 19 deletions

View File

@ -66,6 +66,9 @@ type SigningSession struct {
msgHandler MessageHandler
mu sync.Mutex
started bool
// keygenIndexToSortedIndex maps keygen party index to sorted array index
// This is needed because TSS messages use keygen index, but tssPartyIDs is sorted
keygenIndexToSortedIndex map[int]int
}
// NewSigningSession creates a new signing session
@ -112,6 +115,22 @@ func NewSigningSession(
// Sort party IDs
sortedPartyIDs := tss.SortPartyIDs(tssPartyIDs)
// Build mapping from keygen index to sorted array index
// The sorted array is ordered by big.Int key (PartyIndex+1)
keygenIndexToSortedIndex := make(map[int]int)
for sortedIdx, partyID := range sortedPartyIDs {
// Find the original keygen index for this party
for _, p := range allParties {
if p.PartyID == partyID.Id {
keygenIndexToSortedIndex[p.PartyIndex] = sortedIdx
break
}
}
}
fmt.Printf("[TSS-SIGN] Built keygen index to sorted index mapping: %v party_id=%s\n",
keygenIndexToSortedIndex, selfParty.PartyID)
// Create peer context and parameters
// IMPORTANT: Use TotalParties from keygen, not len(sortedPartyIDs) which is current signers
// For 2-of-3: threshold=2, TotalParties=3, but only 2 parties might participate in signing
@ -122,18 +141,19 @@ func NewSigningSession(
msgHash := new(big.Int).SetBytes(messageHash)
return &SigningSession{
config: config,
selfParty: selfParty,
allParties: allParties,
messageHash: msgHash,
saveData: &saveData,
tssPartyIDs: sortedPartyIDs,
selfTSSID: selfTSSID,
params: params,
outCh: make(chan tss.Message, config.TotalSigners*10),
endCh: make(chan *common.SignatureData, 1),
errCh: make(chan error, 1),
msgHandler: msgHandler,
config: config,
selfParty: selfParty,
allParties: allParties,
messageHash: msgHash,
saveData: &saveData,
tssPartyIDs: sortedPartyIDs,
selfTSSID: selfTSSID,
params: params,
outCh: make(chan tss.Message, config.TotalSigners*10),
endCh: make(chan *common.SignatureData, 1),
errCh: make(chan error, 1),
msgHandler: msgHandler,
keygenIndexToSortedIndex: keygenIndexToSortedIndex,
}, nil
}
@ -232,18 +252,31 @@ func (s *SigningSession) handleIncomingMessages(ctx context.Context) {
return
}
fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_index=%d is_broadcast=%v msg_len=%d\n",
fmt.Printf("[TSS-SIGN] received incoming message party_id=%s from_keygen_index=%d is_broadcast=%v msg_len=%d\n",
s.selfParty.PartyID, msg.FromPartyIndex, msg.IsBroadcast, len(msg.MsgBytes))
// Check if FromPartyIndex is valid
if msg.FromPartyIndex < 0 || msg.FromPartyIndex >= len(s.tssPartyIDs) {
fmt.Printf("[TSS-SIGN] ERROR: invalid FromPartyIndex=%d, len(tssPartyIDs)=%d party_id=%s\n",
msg.FromPartyIndex, len(s.tssPartyIDs), s.selfParty.PartyID)
// Map keygen index to sorted array index
// msg.FromPartyIndex is the original keygen party index (e.g., 0, 1, 2)
// We need the sorted array index (e.g., 0, 1 for a 2-party signing session)
sortedIndex, exists := s.keygenIndexToSortedIndex[msg.FromPartyIndex]
if !exists {
fmt.Printf("[TSS-SIGN] ERROR: unknown keygen index=%d, mapping=%v party_id=%s\n",
msg.FromPartyIndex, s.keygenIndexToSortedIndex, s.selfParty.PartyID)
continue
}
// Parse the message
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[msg.FromPartyIndex], msg.IsBroadcast)
fmt.Printf("[TSS-SIGN] mapped keygen_index=%d to sorted_index=%d party_id=%s\n",
msg.FromPartyIndex, sortedIndex, s.selfParty.PartyID)
// Check if sorted index is valid
if sortedIndex < 0 || sortedIndex >= len(s.tssPartyIDs) {
fmt.Printf("[TSS-SIGN] ERROR: invalid sortedIndex=%d, len(tssPartyIDs)=%d party_id=%s\n",
sortedIndex, len(s.tssPartyIDs), s.selfParty.PartyID)
continue
}
// Parse the message using the sorted index
parsedMsg, err := tss.ParseWireMessage(msg.MsgBytes, s.tssPartyIDs[sortedIndex], msg.IsBroadcast)
if err != nil {
fmt.Printf("[TSS-SIGN] ERROR: failed to parse wire message party_id=%s from_index=%d error=%v\n",
s.selfParty.PartyID, msg.FromPartyIndex, err)