feat(server-party): add ExecuteWithSessionInfo for co-managed keygen

Add new ExecuteWithSessionInfo method to ParticipateKeygenUseCase
for server-party-co-managed to skip duplicate JoinSession call.

- server-party-co-managed already calls JoinSession in session_created phase
- ExecuteWithSessionInfo accepts pre-obtained SessionInfo and skips internal JoinSession
- Refactor common execution logic to private executeWithSessionInfo method
- Update server-party-co-managed to use ExecuteWithSessionInfo on session_started

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
hailin 2025-12-30 00:43:09 -08:00
parent fd6f84ce82
commit 2164664ca0
2 changed files with 90 additions and 30 deletions

View File

@ -29,10 +29,13 @@ import (
// PendingSession stores session info between session_created and session_started events // PendingSession stores session info between session_created and session_started events
type PendingSession struct { type PendingSession struct {
SessionID uuid.UUID SessionID uuid.UUID
JoinToken string JoinToken string
MessageHash []byte MessageHash []byte
CreatedAt time.Time ThresholdN int
ThresholdT int
SelectedParties []string
CreatedAt time.Time
} }
// PendingSessionCache stores pending sessions waiting for session_started // PendingSessionCache stores pending sessions waiting for session_started
@ -379,10 +382,13 @@ func createCoManagedSessionEventHandler(
// Store pending session for later use when session_started arrives // Store pending session for later use when session_started arrives
pendingSessionCache.Store(event.SessionId, &PendingSession{ pendingSessionCache.Store(event.SessionId, &PendingSession{
SessionID: sessionID, SessionID: sessionID,
JoinToken: joinToken, JoinToken: joinToken,
MessageHash: event.MessageHash, MessageHash: event.MessageHash,
CreatedAt: time.Now(), ThresholdN: int(event.ThresholdN),
ThresholdT: int(event.ThresholdT),
SelectedParties: event.SelectedParties,
CreatedAt: time.Now(),
}) })
case "session_started": case "session_started":
@ -410,13 +416,32 @@ func createCoManagedSessionEventHandler(
zap.String("session_id", event.SessionId), zap.String("session_id", event.SessionId),
zap.String("party_id", partyID)) zap.String("party_id", partyID))
input := use_cases.ParticipateKeygenInput{ // Build SessionInfo from pending session and event data
SessionID: pendingSession.SessionID, // Note: We already called JoinSession in session_created phase,
PartyID: partyID, // so we use ExecuteWithSessionInfo to skip the duplicate JoinSession call
JoinToken: pendingSession.JoinToken, participants := make([]use_cases.ParticipantInfo, len(pendingSession.SelectedParties))
for i, p := range pendingSession.SelectedParties {
participants[i] = use_cases.ParticipantInfo{
PartyID: p,
PartyIndex: i,
}
} }
result, err := participateKeygenUC.Execute(participateCtx, input) sessionInfo := &use_cases.SessionInfo{
SessionID: pendingSession.SessionID,
SessionType: "co_managed_keygen",
ThresholdN: pendingSession.ThresholdN,
ThresholdT: pendingSession.ThresholdT,
MessageHash: pendingSession.MessageHash,
Participants: participants,
}
result, err := participateKeygenUC.ExecuteWithSessionInfo(
participateCtx,
pendingSession.SessionID,
partyID,
sessionInfo,
)
if err != nil { if err != nil {
logger.Error("Co-managed keygen participation failed", logger.Error("Co-managed keygen participation failed",
zap.Error(err), zap.Error(err),

View File

@ -134,12 +134,47 @@ func (uc *ParticipateKeygenUseCase) Execute(
} }
} }
// 2. Find self in participants and build party index map // Delegate to the common execution logic
return uc.executeWithSessionInfo(ctx, input.SessionID, input.PartyID, sessionInfo)
}
// ExecuteWithSessionInfo participates in a keygen session with pre-obtained SessionInfo
// This is used by server-party-co-managed which has already called JoinSession in session_created phase
// and receives session_started event when all participants have joined
func (uc *ParticipateKeygenUseCase) ExecuteWithSessionInfo(
ctx context.Context,
sessionID uuid.UUID,
partyID string,
sessionInfo *SessionInfo,
) (*ParticipateKeygenOutput, error) {
// Validate session type
if sessionInfo.SessionType != "keygen" && sessionInfo.SessionType != "co_managed_keygen" {
return nil, ErrInvalidSession
}
logger.Info("ExecuteWithSessionInfo: starting keygen with pre-obtained session info",
zap.String("session_id", sessionID.String()),
zap.String("party_id", partyID),
zap.String("session_type", sessionInfo.SessionType),
zap.Int("participants", len(sessionInfo.Participants)))
// Delegate to the common execution logic
return uc.executeWithSessionInfo(ctx, sessionID, partyID, sessionInfo)
}
// executeWithSessionInfo is the common execution logic shared by Execute and ExecuteWithSessionInfo
func (uc *ParticipateKeygenUseCase) executeWithSessionInfo(
ctx context.Context,
sessionID uuid.UUID,
partyID string,
sessionInfo *SessionInfo,
) (*ParticipateKeygenOutput, error) {
// 1. Find self in participants and build party index map
var selfIndex int var selfIndex int
partyIndexMap := make(map[string]int) partyIndexMap := make(map[string]int)
for _, p := range sessionInfo.Participants { for _, p := range sessionInfo.Participants {
partyIndexMap[p.PartyID] = p.PartyIndex partyIndexMap[p.PartyID] = p.PartyIndex
if p.PartyID == input.PartyID { if p.PartyID == partyID {
selfIndex = p.PartyIndex selfIndex = p.PartyIndex
} }
logger.Debug("Added participant to index map", logger.Debug("Added participant to index map",
@ -147,13 +182,13 @@ func (uc *ParticipateKeygenUseCase) Execute(
zap.Int("party_index", p.PartyIndex)) zap.Int("party_index", p.PartyIndex))
} }
logger.Info("Built party index map", logger.Info("Built party index map",
zap.String("session_id", input.SessionID.String()), zap.String("session_id", sessionID.String()),
zap.String("self_party_id", input.PartyID), zap.String("self_party_id", partyID),
zap.Int("self_index", selfIndex), zap.Int("self_index", selfIndex),
zap.Int("total_participants", len(sessionInfo.Participants))) zap.Int("total_participants", len(sessionInfo.Participants)))
// 3. Subscribe to messages // 3. Subscribe to messages
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID) msgChan, err := uc.messageRouter.SubscribeMessages(ctx, sessionID, partyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -161,8 +196,8 @@ func (uc *ParticipateKeygenUseCase) Execute(
// 4. Run TSS Keygen protocol // 4. Run TSS Keygen protocol
saveData, publicKey, err := uc.runKeygenProtocol( saveData, publicKey, err := uc.runKeygenProtocol(
ctx, ctx,
input.SessionID, sessionID,
input.PartyID, partyID,
selfIndex, selfIndex,
sessionInfo.Participants, sessionInfo.Participants,
sessionInfo.ThresholdN, sessionInfo.ThresholdN,
@ -175,15 +210,15 @@ func (uc *ParticipateKeygenUseCase) Execute(
} }
// 5. Encrypt the share // 5. Encrypt the share
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID) encryptedShare, err := uc.cryptoService.EncryptShare(saveData, partyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
keyShare := entities.NewPartyKeyShare( keyShare := entities.NewPartyKeyShare(
input.PartyID, partyID,
selfIndex, selfIndex,
input.SessionID, sessionID,
sessionInfo.ThresholdN, sessionInfo.ThresholdN,
sessionInfo.ThresholdT, sessionInfo.ThresholdT,
encryptedShare, encryptedShare,
@ -201,21 +236,21 @@ func (uc *ParticipateKeygenUseCase) Execute(
return nil, ErrShareSaveFailed return nil, ErrShareSaveFailed
} }
logger.Info("Share saved to database (persistent party)", logger.Info("Share saved to database (persistent party)",
zap.String("party_id", input.PartyID), zap.String("party_id", partyID),
zap.String("session_id", input.SessionID.String())) zap.String("session_id", sessionID.String()))
case "delegate": case "delegate":
// Delegate Party: do NOT save to database, return to user // Delegate Party: do NOT save to database, return to user
shareForUser = encryptedShare shareForUser = encryptedShare
logger.Info("Share NOT saved, will be returned to user (delegate party)", logger.Info("Share NOT saved, will be returned to user (delegate party)",
zap.String("party_id", input.PartyID), zap.String("party_id", partyID),
zap.String("session_id", input.SessionID.String()), zap.String("session_id", sessionID.String()),
zap.Int("share_size", len(shareForUser))) zap.Int("share_size", len(shareForUser)))
case "temporary": case "temporary":
// Temporary Party: optionally save to temp storage (not implemented yet) // Temporary Party: optionally save to temp storage (not implemented yet)
logger.Info("Temporary party - share not saved", logger.Info("Temporary party - share not saved",
zap.String("party_id", input.PartyID)) zap.String("party_id", partyID))
default: default:
// Default to persistent for safety // Default to persistent for safety
@ -223,12 +258,12 @@ func (uc *ParticipateKeygenUseCase) Execute(
return nil, ErrShareSaveFailed return nil, ErrShareSaveFailed
} }
logger.Warn("Unknown party role, defaulting to persistent", logger.Warn("Unknown party role, defaulting to persistent",
zap.String("party_id", input.PartyID), zap.String("party_id", partyID),
zap.String("role", partyRole)) zap.String("role", partyRole))
} }
// 7. Report completion to coordinator // 7. Report completion to coordinator
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil { if err := uc.sessionClient.ReportCompletion(ctx, sessionID, partyID, publicKey); err != nil {
logger.Error("failed to report completion", zap.Error(err)) logger.Error("failed to report completion", zap.Error(err))
// Don't fail - share is handled // Don't fail - share is handled
} }