feat(server-party): add ExecuteWithSessionInfo for co-managed keygen
Add new ExecuteWithSessionInfo method to ParticipateKeygenUseCase for server-party-co-managed to skip duplicate JoinSession call. - server-party-co-managed already calls JoinSession in session_created phase - ExecuteWithSessionInfo accepts pre-obtained SessionInfo and skips internal JoinSession - Refactor common execution logic to private executeWithSessionInfo method - Update server-party-co-managed to use ExecuteWithSessionInfo on session_started 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
fd6f84ce82
commit
2164664ca0
|
|
@ -29,10 +29,13 @@ import (
|
||||||
|
|
||||||
// PendingSession stores session info between session_created and session_started events
|
// PendingSession stores session info between session_created and session_started events
|
||||||
type PendingSession struct {
|
type PendingSession struct {
|
||||||
SessionID uuid.UUID
|
SessionID uuid.UUID
|
||||||
JoinToken string
|
JoinToken string
|
||||||
MessageHash []byte
|
MessageHash []byte
|
||||||
CreatedAt time.Time
|
ThresholdN int
|
||||||
|
ThresholdT int
|
||||||
|
SelectedParties []string
|
||||||
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// PendingSessionCache stores pending sessions waiting for session_started
|
// PendingSessionCache stores pending sessions waiting for session_started
|
||||||
|
|
@ -379,10 +382,13 @@ func createCoManagedSessionEventHandler(
|
||||||
|
|
||||||
// Store pending session for later use when session_started arrives
|
// Store pending session for later use when session_started arrives
|
||||||
pendingSessionCache.Store(event.SessionId, &PendingSession{
|
pendingSessionCache.Store(event.SessionId, &PendingSession{
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
JoinToken: joinToken,
|
JoinToken: joinToken,
|
||||||
MessageHash: event.MessageHash,
|
MessageHash: event.MessageHash,
|
||||||
CreatedAt: time.Now(),
|
ThresholdN: int(event.ThresholdN),
|
||||||
|
ThresholdT: int(event.ThresholdT),
|
||||||
|
SelectedParties: event.SelectedParties,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
})
|
})
|
||||||
|
|
||||||
case "session_started":
|
case "session_started":
|
||||||
|
|
@ -410,13 +416,32 @@ func createCoManagedSessionEventHandler(
|
||||||
zap.String("session_id", event.SessionId),
|
zap.String("session_id", event.SessionId),
|
||||||
zap.String("party_id", partyID))
|
zap.String("party_id", partyID))
|
||||||
|
|
||||||
input := use_cases.ParticipateKeygenInput{
|
// Build SessionInfo from pending session and event data
|
||||||
SessionID: pendingSession.SessionID,
|
// Note: We already called JoinSession in session_created phase,
|
||||||
PartyID: partyID,
|
// so we use ExecuteWithSessionInfo to skip the duplicate JoinSession call
|
||||||
JoinToken: pendingSession.JoinToken,
|
participants := make([]use_cases.ParticipantInfo, len(pendingSession.SelectedParties))
|
||||||
|
for i, p := range pendingSession.SelectedParties {
|
||||||
|
participants[i] = use_cases.ParticipantInfo{
|
||||||
|
PartyID: p,
|
||||||
|
PartyIndex: i,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := participateKeygenUC.Execute(participateCtx, input)
|
sessionInfo := &use_cases.SessionInfo{
|
||||||
|
SessionID: pendingSession.SessionID,
|
||||||
|
SessionType: "co_managed_keygen",
|
||||||
|
ThresholdN: pendingSession.ThresholdN,
|
||||||
|
ThresholdT: pendingSession.ThresholdT,
|
||||||
|
MessageHash: pendingSession.MessageHash,
|
||||||
|
Participants: participants,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := participateKeygenUC.ExecuteWithSessionInfo(
|
||||||
|
participateCtx,
|
||||||
|
pendingSession.SessionID,
|
||||||
|
partyID,
|
||||||
|
sessionInfo,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Co-managed keygen participation failed",
|
logger.Error("Co-managed keygen participation failed",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
|
|
|
||||||
|
|
@ -134,12 +134,47 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Find self in participants and build party index map
|
// Delegate to the common execution logic
|
||||||
|
return uc.executeWithSessionInfo(ctx, input.SessionID, input.PartyID, sessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteWithSessionInfo participates in a keygen session with pre-obtained SessionInfo
|
||||||
|
// This is used by server-party-co-managed which has already called JoinSession in session_created phase
|
||||||
|
// and receives session_started event when all participants have joined
|
||||||
|
func (uc *ParticipateKeygenUseCase) ExecuteWithSessionInfo(
|
||||||
|
ctx context.Context,
|
||||||
|
sessionID uuid.UUID,
|
||||||
|
partyID string,
|
||||||
|
sessionInfo *SessionInfo,
|
||||||
|
) (*ParticipateKeygenOutput, error) {
|
||||||
|
// Validate session type
|
||||||
|
if sessionInfo.SessionType != "keygen" && sessionInfo.SessionType != "co_managed_keygen" {
|
||||||
|
return nil, ErrInvalidSession
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("ExecuteWithSessionInfo: starting keygen with pre-obtained session info",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.String("party_id", partyID),
|
||||||
|
zap.String("session_type", sessionInfo.SessionType),
|
||||||
|
zap.Int("participants", len(sessionInfo.Participants)))
|
||||||
|
|
||||||
|
// Delegate to the common execution logic
|
||||||
|
return uc.executeWithSessionInfo(ctx, sessionID, partyID, sessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeWithSessionInfo is the common execution logic shared by Execute and ExecuteWithSessionInfo
|
||||||
|
func (uc *ParticipateKeygenUseCase) executeWithSessionInfo(
|
||||||
|
ctx context.Context,
|
||||||
|
sessionID uuid.UUID,
|
||||||
|
partyID string,
|
||||||
|
sessionInfo *SessionInfo,
|
||||||
|
) (*ParticipateKeygenOutput, error) {
|
||||||
|
// 1. Find self in participants and build party index map
|
||||||
var selfIndex int
|
var selfIndex int
|
||||||
partyIndexMap := make(map[string]int)
|
partyIndexMap := make(map[string]int)
|
||||||
for _, p := range sessionInfo.Participants {
|
for _, p := range sessionInfo.Participants {
|
||||||
partyIndexMap[p.PartyID] = p.PartyIndex
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
||||||
if p.PartyID == input.PartyID {
|
if p.PartyID == partyID {
|
||||||
selfIndex = p.PartyIndex
|
selfIndex = p.PartyIndex
|
||||||
}
|
}
|
||||||
logger.Debug("Added participant to index map",
|
logger.Debug("Added participant to index map",
|
||||||
|
|
@ -147,13 +182,13 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
zap.Int("party_index", p.PartyIndex))
|
zap.Int("party_index", p.PartyIndex))
|
||||||
}
|
}
|
||||||
logger.Info("Built party index map",
|
logger.Info("Built party index map",
|
||||||
zap.String("session_id", input.SessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.String("self_party_id", input.PartyID),
|
zap.String("self_party_id", partyID),
|
||||||
zap.Int("self_index", selfIndex),
|
zap.Int("self_index", selfIndex),
|
||||||
zap.Int("total_participants", len(sessionInfo.Participants)))
|
zap.Int("total_participants", len(sessionInfo.Participants)))
|
||||||
|
|
||||||
// 3. Subscribe to messages
|
// 3. Subscribe to messages
|
||||||
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, sessionID, partyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -161,8 +196,8 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
// 4. Run TSS Keygen protocol
|
// 4. Run TSS Keygen protocol
|
||||||
saveData, publicKey, err := uc.runKeygenProtocol(
|
saveData, publicKey, err := uc.runKeygenProtocol(
|
||||||
ctx,
|
ctx,
|
||||||
input.SessionID,
|
sessionID,
|
||||||
input.PartyID,
|
partyID,
|
||||||
selfIndex,
|
selfIndex,
|
||||||
sessionInfo.Participants,
|
sessionInfo.Participants,
|
||||||
sessionInfo.ThresholdN,
|
sessionInfo.ThresholdN,
|
||||||
|
|
@ -175,15 +210,15 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Encrypt the share
|
// 5. Encrypt the share
|
||||||
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
|
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, partyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyShare := entities.NewPartyKeyShare(
|
keyShare := entities.NewPartyKeyShare(
|
||||||
input.PartyID,
|
partyID,
|
||||||
selfIndex,
|
selfIndex,
|
||||||
input.SessionID,
|
sessionID,
|
||||||
sessionInfo.ThresholdN,
|
sessionInfo.ThresholdN,
|
||||||
sessionInfo.ThresholdT,
|
sessionInfo.ThresholdT,
|
||||||
encryptedShare,
|
encryptedShare,
|
||||||
|
|
@ -201,21 +236,21 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
return nil, ErrShareSaveFailed
|
return nil, ErrShareSaveFailed
|
||||||
}
|
}
|
||||||
logger.Info("Share saved to database (persistent party)",
|
logger.Info("Share saved to database (persistent party)",
|
||||||
zap.String("party_id", input.PartyID),
|
zap.String("party_id", partyID),
|
||||||
zap.String("session_id", input.SessionID.String()))
|
zap.String("session_id", sessionID.String()))
|
||||||
|
|
||||||
case "delegate":
|
case "delegate":
|
||||||
// Delegate Party: do NOT save to database, return to user
|
// Delegate Party: do NOT save to database, return to user
|
||||||
shareForUser = encryptedShare
|
shareForUser = encryptedShare
|
||||||
logger.Info("Share NOT saved, will be returned to user (delegate party)",
|
logger.Info("Share NOT saved, will be returned to user (delegate party)",
|
||||||
zap.String("party_id", input.PartyID),
|
zap.String("party_id", partyID),
|
||||||
zap.String("session_id", input.SessionID.String()),
|
zap.String("session_id", sessionID.String()),
|
||||||
zap.Int("share_size", len(shareForUser)))
|
zap.Int("share_size", len(shareForUser)))
|
||||||
|
|
||||||
case "temporary":
|
case "temporary":
|
||||||
// Temporary Party: optionally save to temp storage (not implemented yet)
|
// Temporary Party: optionally save to temp storage (not implemented yet)
|
||||||
logger.Info("Temporary party - share not saved",
|
logger.Info("Temporary party - share not saved",
|
||||||
zap.String("party_id", input.PartyID))
|
zap.String("party_id", partyID))
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Default to persistent for safety
|
// Default to persistent for safety
|
||||||
|
|
@ -223,12 +258,12 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
return nil, ErrShareSaveFailed
|
return nil, ErrShareSaveFailed
|
||||||
}
|
}
|
||||||
logger.Warn("Unknown party role, defaulting to persistent",
|
logger.Warn("Unknown party role, defaulting to persistent",
|
||||||
zap.String("party_id", input.PartyID),
|
zap.String("party_id", partyID),
|
||||||
zap.String("role", partyRole))
|
zap.String("role", partyRole))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Report completion to coordinator
|
// 7. Report completion to coordinator
|
||||||
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
|
if err := uc.sessionClient.ReportCompletion(ctx, sessionID, partyID, publicKey); err != nil {
|
||||||
logger.Error("failed to report completion", zap.Error(err))
|
logger.Error("failed to report completion", zap.Error(err))
|
||||||
// Don't fail - share is handled
|
// Don't fail - share is handled
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue