diff --git a/backend/mpc-system/services/message-router/adapters/input/grpc/message_grpc_handler.go b/backend/mpc-system/services/message-router/adapters/input/grpc/message_grpc_handler.go index 57cd95bd..a0f926f4 100644 --- a/backend/mpc-system/services/message-router/adapters/input/grpc/message_grpc_handler.go +++ b/backend/mpc-system/services/message-router/adapters/input/grpc/message_grpc_handler.go @@ -679,11 +679,24 @@ func (s *MessageRouterServer) GetSessionStatus( return nil, err } + // Convert participants from coordinator response + var participants []*pb.PartyInfo + if len(coordResp.Participants) > 0 { + participants = make([]*pb.PartyInfo, len(coordResp.Participants)) + for i, p := range coordResp.Participants { + participants[i] = &pb.PartyInfo{ + PartyId: p.PartyId, + PartyIndex: p.PartyIndex, + } + } + } + return &pb.GetSessionStatusResponse{ - SessionId: req.SessionId, - Status: coordResp.Status, - ThresholdN: coordResp.TotalParties, // Use TotalParties as N - ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info + SessionId: req.SessionId, + Status: coordResp.Status, + ThresholdN: coordResp.TotalParties, // Use TotalParties as N + ThresholdT: coordResp.CompletedParties, // Return completed count in ThresholdT for info + Participants: participants, // Include participants for co_managed_keygen }, nil } diff --git a/backend/mpc-system/services/server-party/adapters/output/grpc/message_router_client.go b/backend/mpc-system/services/server-party/adapters/output/grpc/message_router_client.go index 1d7a724d..77d0dff6 100644 --- a/backend/mpc-system/services/server-party/adapters/output/grpc/message_router_client.go +++ b/backend/mpc-system/services/server-party/adapters/output/grpc/message_router_client.go @@ -851,3 +851,45 @@ func (c *MessageRouterClient) SubmitDelegateShare( return nil }) } + +// GetSessionStatusFull gets the full session status including participants via Message Router +// This is used for co_managed_keygen sessions to wait for all parties to join +// Includes automatic retry with exponential backoff for transient failures +func (c *MessageRouterClient) GetSessionStatusFull( + ctx context.Context, + sessionID uuid.UUID, +) (*use_cases.SessionStatusInfo, error) { + req := &router.GetSessionStatusRequest{ + SessionId: sessionID.String(), + } + + return retry.Do(ctx, c.retryCfg, "GetSessionStatusFull", func() (*use_cases.SessionStatusInfo, error) { + resp := &router.GetSessionStatusResponse{} + err := c.getConn().Invoke(ctx, "/mpc.router.v1.MessageRouter/GetSessionStatus", req, resp) + if err != nil { + return nil, err + } + + // Convert participants from response + participants := make([]use_cases.ParticipantInfo, len(resp.Participants)) + for i, p := range resp.Participants { + participants[i] = use_cases.ParticipantInfo{ + PartyID: p.PartyId, + PartyIndex: int(p.PartyIndex), + } + } + + logger.Debug("GetSessionStatusFull response", + zap.String("session_id", sessionID.String()), + zap.String("status", resp.Status), + zap.Int32("threshold_n", resp.ThresholdN), + zap.Int("participants_count", len(participants))) + + return &use_cases.SessionStatusInfo{ + Status: resp.Status, + ThresholdN: int(resp.ThresholdN), + ThresholdT: int(resp.ThresholdT), + Participants: participants, + }, nil + }) +} diff --git a/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go b/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go index 78159a67..73bbc6f5 100644 --- a/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go +++ b/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go @@ -41,6 +41,15 @@ type ParticipateKeygenOutput struct { type SessionCoordinatorClient interface { JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error) ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error + GetSessionStatusFull(ctx context.Context, sessionID uuid.UUID) (*SessionStatusInfo, error) +} + +// SessionStatusInfo contains full session status information +type SessionStatusInfo struct { + Status string + ThresholdN int + ThresholdT int + Participants []ParticipantInfo } // MessageRouterClient defines the interface for message router communication @@ -115,6 +124,15 @@ func (uc *ParticipateKeygenUseCase) Execute( return nil, ErrInvalidSession } + // For co_managed_keygen: wait for all N participants to join before proceeding + // This is necessary because server parties join immediately but external party joins later + if sessionInfo.SessionType == "co_managed_keygen" { + sessionInfo, err = uc.waitForAllParticipants(ctx, input.SessionID, sessionInfo) + if err != nil { + return nil, err + } + } + // 2. Find self in participants and build party index map var selfIndex int partyIndexMap := make(map[string]int) @@ -369,3 +387,77 @@ func (uc *ParticipateKeygenUseCase) getPartyRole() string { } return role } + +// waitForAllParticipants waits for all N participants to join the session +// This is only used for co_managed_keygen sessions where server parties join first +// and need to wait for the external party to join via invite code +func (uc *ParticipateKeygenUseCase) waitForAllParticipants( + ctx context.Context, + sessionID uuid.UUID, + initialSessionInfo *SessionInfo, +) (*SessionInfo, error) { + logger.Info("Waiting for all participants to join co_managed_keygen session", + zap.String("session_id", sessionID.String()), + zap.Int("expected_n", initialSessionInfo.ThresholdN), + zap.Int("current_participants", len(initialSessionInfo.Participants))) + + // If already have all participants, return immediately + if len(initialSessionInfo.Participants) >= initialSessionInfo.ThresholdN { + logger.Info("All participants already joined", + zap.String("session_id", sessionID.String())) + return initialSessionInfo, nil + } + + // Poll for session status until all participants join or timeout + pollInterval := 2 * time.Second + maxWaitTime := 5 * time.Minute + deadline := time.Now().Add(maxWaitTime) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(pollInterval): + // Get full session status including participants + statusInfo, err := uc.sessionClient.GetSessionStatusFull(ctx, sessionID) + if err != nil { + logger.Warn("Failed to get session status, will retry", + zap.String("session_id", sessionID.String()), + zap.Error(err)) + continue + } + + logger.Debug("Polled session status", + zap.String("session_id", sessionID.String()), + zap.String("status", statusInfo.Status), + zap.Int("participants", len(statusInfo.Participants)), + zap.Int("expected_n", initialSessionInfo.ThresholdN)) + + // Check if session is in_progress (all parties joined and ready) + if statusInfo.Status == "in_progress" && len(statusInfo.Participants) >= initialSessionInfo.ThresholdN { + logger.Info("All participants joined, session is in_progress", + zap.String("session_id", sessionID.String()), + zap.Int("participants", len(statusInfo.Participants))) + + // Update session info with full participants list + initialSessionInfo.Participants = statusInfo.Participants + return initialSessionInfo, nil + } + + // Also accept if we have all N participants even if status hasn't changed + if len(statusInfo.Participants) >= initialSessionInfo.ThresholdN { + logger.Info("All participants joined", + zap.String("session_id", sessionID.String()), + zap.Int("participants", len(statusInfo.Participants))) + + initialSessionInfo.Participants = statusInfo.Participants + return initialSessionInfo, nil + } + } + } + + logger.Error("Timeout waiting for all participants", + zap.String("session_id", sessionID.String()), + zap.Int("expected_n", initialSessionInfo.ThresholdN)) + return nil, ErrKeygenTimeout +}