fix(server-party): co_managed_keygen 等待所有参与者加入后再开始 keygen

- Message Router GetSessionStatus 透传 participants 列表
- Server Party 新增 GetSessionStatusFull 方法获取完整会话状态
- participate_keygen.go 对 co_managed_keygen 类型轮询等待所有 N 个参与者加入
- 不影响原有 keygen/sign 功能(仅 co_managed_keygen 触发等待逻辑)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
hailin 2025-12-29 09:55:52 -08:00
parent 48c8c071d5
commit e222279d77
3 changed files with 151 additions and 4 deletions

View File

@ -679,11 +679,24 @@ func (s *MessageRouterServer) GetSessionStatus(
return nil, err
}
// Convert participants from coordinator response
var participants []*pb.PartyInfo
if len(coordResp.Participants) > 0 {
participants = make([]*pb.PartyInfo, len(coordResp.Participants))
for i, p := range coordResp.Participants {
participants[i] = &pb.PartyInfo{
PartyId: p.PartyId,
PartyIndex: p.PartyIndex,
}
}
}
return &pb.GetSessionStatusResponse{
SessionId: req.SessionId,
Status: coordResp.Status,
ThresholdN: coordResp.TotalParties, // Use TotalParties as N
ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info
SessionId: req.SessionId,
Status: coordResp.Status,
ThresholdN: coordResp.TotalParties, // Use TotalParties as N
ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info
Participants: participants, // Include participants for co_managed_keygen
}, nil
}

View File

@ -851,3 +851,45 @@ func (c *MessageRouterClient) SubmitDelegateShare(
return nil
})
}
// GetSessionStatusFull gets the full session status including participants via Message Router
// This is used for co_managed_keygen sessions to wait for all parties to join
// Includes automatic retry with exponential backoff for transient failures
func (c *MessageRouterClient) GetSessionStatusFull(
ctx context.Context,
sessionID uuid.UUID,
) (*use_cases.SessionStatusInfo, error) {
req := &router.GetSessionStatusRequest{
SessionId: sessionID.String(),
}
return retry.Do(ctx, c.retryCfg, "GetSessionStatusFull", func() (*use_cases.SessionStatusInfo, error) {
resp := &router.GetSessionStatusResponse{}
err := c.getConn().Invoke(ctx, "/mpc.router.v1.MessageRouter/GetSessionStatus", req, resp)
if err != nil {
return nil, err
}
// Convert participants from response
participants := make([]use_cases.ParticipantInfo, len(resp.Participants))
for i, p := range resp.Participants {
participants[i] = use_cases.ParticipantInfo{
PartyID: p.PartyId,
PartyIndex: int(p.PartyIndex),
}
}
logger.Debug("GetSessionStatusFull response",
zap.String("session_id", sessionID.String()),
zap.String("status", resp.Status),
zap.Int32("threshold_n", resp.ThresholdN),
zap.Int("participants_count", len(participants)))
return &use_cases.SessionStatusInfo{
Status: resp.Status,
ThresholdN: int(resp.ThresholdN),
ThresholdT: int(resp.ThresholdT),
Participants: participants,
}, nil
})
}

View File

@ -41,6 +41,15 @@ type ParticipateKeygenOutput struct {
type SessionCoordinatorClient interface {
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
GetSessionStatusFull(ctx context.Context, sessionID uuid.UUID) (*SessionStatusInfo, error)
}
// SessionStatusInfo contains full session status information
type SessionStatusInfo struct {
Status string
ThresholdN int
ThresholdT int
Participants []ParticipantInfo
}
// MessageRouterClient defines the interface for message router communication
@ -115,6 +124,15 @@ func (uc *ParticipateKeygenUseCase) Execute(
return nil, ErrInvalidSession
}
// For co_managed_keygen: wait for all N participants to join before proceeding
// This is necessary because server parties join immediately but external party joins later
if sessionInfo.SessionType == "co_managed_keygen" {
sessionInfo, err = uc.waitForAllParticipants(ctx, input.SessionID, sessionInfo)
if err != nil {
return nil, err
}
}
// 2. Find self in participants and build party index map
var selfIndex int
partyIndexMap := make(map[string]int)
@ -369,3 +387,77 @@ func (uc *ParticipateKeygenUseCase) getPartyRole() string {
}
return role
}
// waitForAllParticipants waits for all N participants to join the session
// This is only used for co_managed_keygen sessions where server parties join first
// and need to wait for the external party to join via invite code
func (uc *ParticipateKeygenUseCase) waitForAllParticipants(
ctx context.Context,
sessionID uuid.UUID,
initialSessionInfo *SessionInfo,
) (*SessionInfo, error) {
logger.Info("Waiting for all participants to join co_managed_keygen session",
zap.String("session_id", sessionID.String()),
zap.Int("expected_n", initialSessionInfo.ThresholdN),
zap.Int("current_participants", len(initialSessionInfo.Participants)))
// If already have all participants, return immediately
if len(initialSessionInfo.Participants) >= initialSessionInfo.ThresholdN {
logger.Info("All participants already joined",
zap.String("session_id", sessionID.String()))
return initialSessionInfo, nil
}
// Poll for session status until all participants join or timeout
pollInterval := 2 * time.Second
maxWaitTime := 5 * time.Minute
deadline := time.Now().Add(maxWaitTime)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
// Get full session status including participants
statusInfo, err := uc.sessionClient.GetSessionStatusFull(ctx, sessionID)
if err != nil {
logger.Warn("Failed to get session status, will retry",
zap.String("session_id", sessionID.String()),
zap.Error(err))
continue
}
logger.Debug("Polled session status",
zap.String("session_id", sessionID.String()),
zap.String("status", statusInfo.Status),
zap.Int("participants", len(statusInfo.Participants)),
zap.Int("expected_n", initialSessionInfo.ThresholdN))
// Check if session is in_progress (all parties joined and ready)
if statusInfo.Status == "in_progress" && len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
logger.Info("All participants joined, session is in_progress",
zap.String("session_id", sessionID.String()),
zap.Int("participants", len(statusInfo.Participants)))
// Update session info with full participants list
initialSessionInfo.Participants = statusInfo.Participants
return initialSessionInfo, nil
}
// Also accept if we have all N participants even if status hasn't changed
if len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
logger.Info("All participants joined",
zap.String("session_id", sessionID.String()),
zap.Int("participants", len(statusInfo.Participants)))
initialSessionInfo.Participants = statusInfo.Participants
return initialSessionInfo, nil
}
}
}
logger.Error("Timeout waiting for all participants",
zap.String("session_id", sessionID.String()),
zap.Int("expected_n", initialSessionInfo.ThresholdN))
return nil, ErrKeygenTimeout
}