fix(server-party): co_managed_keygen 等待所有参与者加入后再开始 keygen
- Message Router GetSessionStatus 透传 participants 列表
- Server Party 新增 GetSessionStatusFull 方法获取完整会话状态
- participate_keygen.go 对 co_managed_keygen 类型轮询等待所有 N 个参与者加入
- 不影响原有 keygen/sign 功能(仅 co_managed_keygen 触发等待逻辑)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
48c8c071d5
commit
e222279d77
|
|
@ -679,11 +679,24 @@ func (s *MessageRouterServer) GetSessionStatus(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert participants from coordinator response
|
||||||
|
var participants []*pb.PartyInfo
|
||||||
|
if len(coordResp.Participants) > 0 {
|
||||||
|
participants = make([]*pb.PartyInfo, len(coordResp.Participants))
|
||||||
|
for i, p := range coordResp.Participants {
|
||||||
|
participants[i] = &pb.PartyInfo{
|
||||||
|
PartyId: p.PartyId,
|
||||||
|
PartyIndex: p.PartyIndex,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &pb.GetSessionStatusResponse{
|
return &pb.GetSessionStatusResponse{
|
||||||
SessionId: req.SessionId,
|
SessionId: req.SessionId,
|
||||||
Status: coordResp.Status,
|
Status: coordResp.Status,
|
||||||
ThresholdN: coordResp.TotalParties, // Use TotalParties as N
|
ThresholdN: coordResp.TotalParties, // Use TotalParties as N
|
||||||
ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info
|
ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info
|
||||||
|
Participants: participants, // Include participants for co_managed_keygen
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -851,3 +851,45 @@ func (c *MessageRouterClient) SubmitDelegateShare(
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionStatusFull gets the full session status including participants via Message Router
|
||||||
|
// This is used for co_managed_keygen sessions to wait for all parties to join
|
||||||
|
// Includes automatic retry with exponential backoff for transient failures
|
||||||
|
func (c *MessageRouterClient) GetSessionStatusFull(
|
||||||
|
ctx context.Context,
|
||||||
|
sessionID uuid.UUID,
|
||||||
|
) (*use_cases.SessionStatusInfo, error) {
|
||||||
|
req := &router.GetSessionStatusRequest{
|
||||||
|
SessionId: sessionID.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return retry.Do(ctx, c.retryCfg, "GetSessionStatusFull", func() (*use_cases.SessionStatusInfo, error) {
|
||||||
|
resp := &router.GetSessionStatusResponse{}
|
||||||
|
err := c.getConn().Invoke(ctx, "/mpc.router.v1.MessageRouter/GetSessionStatus", req, resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert participants from response
|
||||||
|
participants := make([]use_cases.ParticipantInfo, len(resp.Participants))
|
||||||
|
for i, p := range resp.Participants {
|
||||||
|
participants[i] = use_cases.ParticipantInfo{
|
||||||
|
PartyID: p.PartyId,
|
||||||
|
PartyIndex: int(p.PartyIndex),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("GetSessionStatusFull response",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.String("status", resp.Status),
|
||||||
|
zap.Int32("threshold_n", resp.ThresholdN),
|
||||||
|
zap.Int("participants_count", len(participants)))
|
||||||
|
|
||||||
|
return &use_cases.SessionStatusInfo{
|
||||||
|
Status: resp.Status,
|
||||||
|
ThresholdN: int(resp.ThresholdN),
|
||||||
|
ThresholdT: int(resp.ThresholdT),
|
||||||
|
Participants: participants,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,15 @@ type ParticipateKeygenOutput struct {
|
||||||
type SessionCoordinatorClient interface {
|
type SessionCoordinatorClient interface {
|
||||||
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
||||||
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
||||||
|
GetSessionStatusFull(ctx context.Context, sessionID uuid.UUID) (*SessionStatusInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStatusInfo contains full session status information
|
||||||
|
type SessionStatusInfo struct {
|
||||||
|
Status string
|
||||||
|
ThresholdN int
|
||||||
|
ThresholdT int
|
||||||
|
Participants []ParticipantInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageRouterClient defines the interface for message router communication
|
// MessageRouterClient defines the interface for message router communication
|
||||||
|
|
@ -115,6 +124,15 @@ func (uc *ParticipateKeygenUseCase) Execute(
|
||||||
return nil, ErrInvalidSession
|
return nil, ErrInvalidSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For co_managed_keygen: wait for all N participants to join before proceeding
|
||||||
|
// This is necessary because server parties join immediately but external party joins later
|
||||||
|
if sessionInfo.SessionType == "co_managed_keygen" {
|
||||||
|
sessionInfo, err = uc.waitForAllParticipants(ctx, input.SessionID, sessionInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 2. Find self in participants and build party index map
|
// 2. Find self in participants and build party index map
|
||||||
var selfIndex int
|
var selfIndex int
|
||||||
partyIndexMap := make(map[string]int)
|
partyIndexMap := make(map[string]int)
|
||||||
|
|
@ -369,3 +387,77 @@ func (uc *ParticipateKeygenUseCase) getPartyRole() string {
|
||||||
}
|
}
|
||||||
return role
|
return role
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForAllParticipants waits for all N participants to join the session
|
||||||
|
// This is only used for co_managed_keygen sessions where server parties join first
|
||||||
|
// and need to wait for the external party to join via invite code
|
||||||
|
func (uc *ParticipateKeygenUseCase) waitForAllParticipants(
|
||||||
|
ctx context.Context,
|
||||||
|
sessionID uuid.UUID,
|
||||||
|
initialSessionInfo *SessionInfo,
|
||||||
|
) (*SessionInfo, error) {
|
||||||
|
logger.Info("Waiting for all participants to join co_managed_keygen session",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.Int("expected_n", initialSessionInfo.ThresholdN),
|
||||||
|
zap.Int("current_participants", len(initialSessionInfo.Participants)))
|
||||||
|
|
||||||
|
// If already have all participants, return immediately
|
||||||
|
if len(initialSessionInfo.Participants) >= initialSessionInfo.ThresholdN {
|
||||||
|
logger.Info("All participants already joined",
|
||||||
|
zap.String("session_id", sessionID.String()))
|
||||||
|
return initialSessionInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Poll for session status until all participants join or timeout
|
||||||
|
pollInterval := 2 * time.Second
|
||||||
|
maxWaitTime := 5 * time.Minute
|
||||||
|
deadline := time.Now().Add(maxWaitTime)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(pollInterval):
|
||||||
|
// Get full session status including participants
|
||||||
|
statusInfo, err := uc.sessionClient.GetSessionStatusFull(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to get session status, will retry",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Polled session status",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.String("status", statusInfo.Status),
|
||||||
|
zap.Int("participants", len(statusInfo.Participants)),
|
||||||
|
zap.Int("expected_n", initialSessionInfo.ThresholdN))
|
||||||
|
|
||||||
|
// Check if session is in_progress (all parties joined and ready)
|
||||||
|
if statusInfo.Status == "in_progress" && len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
|
||||||
|
logger.Info("All participants joined, session is in_progress",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.Int("participants", len(statusInfo.Participants)))
|
||||||
|
|
||||||
|
// Update session info with full participants list
|
||||||
|
initialSessionInfo.Participants = statusInfo.Participants
|
||||||
|
return initialSessionInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also accept if we have all N participants even if status hasn't changed
|
||||||
|
if len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
|
||||||
|
logger.Info("All participants joined",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.Int("participants", len(statusInfo.Participants)))
|
||||||
|
|
||||||
|
initialSessionInfo.Participants = statusInfo.Participants
|
||||||
|
return initialSessionInfo, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("Timeout waiting for all participants",
|
||||||
|
zap.String("session_id", sessionID.String()),
|
||||||
|
zap.Int("expected_n", initialSessionInfo.ThresholdN))
|
||||||
|
return nil, ErrKeygenTimeout
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue