From 7531cbd07aed8256fb5095b85cd63de46834a2ed Mon Sep 17 00:00:00 2001 From: hailin Date: Sat, 29 Nov 2025 00:31:24 -0800 Subject: [PATCH] fix: Implement MarkPartyReady and StartSession handlers, update domain logic - Add sessionRepo to HTTP handler for database operations - Implement MarkPartyReady handler to update participant status - Implement StartSession handler to start MPC sessions - Update CanStart() to accept participants in 'ready' status - Make Start() method idempotent to handle automatic + explicit starts - Fix repository injection through dependency chain in main.go - Add party_id parameter to test completion request --- .../input/grpc/session_grpc_handler.go | 326 +++++++++++ .../input/http/session_http_handler.go | 543 ++++++++++++++++++ .../session-coordinator/cmd/server/main.go | 306 ++++++++++ .../domain/entities/mpc_session.go | 342 +++++++++++ .../mpc-system/tests/e2e/keygen_flow_test.go | 354 ++++++++++++ 5 files changed, 1871 insertions(+) create mode 100644 backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go create mode 100644 backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go create mode 100644 backend/mpc-system/services/session-coordinator/cmd/server/main.go create mode 100644 backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go create mode 100644 backend/mpc-system/tests/e2e/keygen_flow_test.go diff --git a/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go b/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go new file mode 100644 index 00000000..91274322 --- /dev/null +++ b/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go @@ -0,0 +1,326 @@ +package grpc + +import ( + "context" + "time" + + "github.com/google/uuid" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// SessionCoordinatorServer implements the gRPC SessionCoordinator service +type SessionCoordinatorServer struct { + createSessionUC *use_cases.CreateSessionUseCase + joinSessionUC *use_cases.JoinSessionUseCase + getSessionStatusUC *use_cases.GetSessionStatusUseCase + reportCompletionUC *use_cases.ReportCompletionUseCase + closeSessionUC *use_cases.CloseSessionUseCase +} + +// NewSessionCoordinatorServer creates a new gRPC server +func NewSessionCoordinatorServer( + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, +) *SessionCoordinatorServer { + return &SessionCoordinatorServer{ + createSessionUC: createSessionUC, + joinSessionUC: joinSessionUC, + getSessionStatusUC: getSessionStatusUC, + reportCompletionUC: reportCompletionUC, + closeSessionUC: closeSessionUC, + } +} + +// CreateSession creates a new MPC session +func (s *SessionCoordinatorServer) CreateSession( + ctx context.Context, + req *CreateSessionRequest, +) (*CreateSessionResponse, error) { + // Convert request to input + participants := make([]input.ParticipantInfo, len(req.Participants)) + for i, p := range req.Participants { + participants[i] = input.ParticipantInfo{ + PartyID: p.PartyId, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceId, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + inputData := input.CreateSessionInput{ + InitiatorID: "", // Could be extracted from auth context + SessionType: req.SessionType, + ThresholdN: int(req.ThresholdN), + ThresholdT: int(req.ThresholdT), + Participants: participants, + MessageHash: req.MessageHash, + ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second, + } + + // Execute use case + output, err := s.createSessionUC.Execute(ctx, inputData) + if err != nil { + return nil, toGRPCError(err) + } + + // Convert output to response + return &CreateSessionResponse{ + SessionId: output.SessionID.String(), + JoinTokens: output.JoinTokens, + ExpiresAt: output.ExpiresAt.UnixMilli(), + }, nil +} + +// JoinSession allows a participant to join a session +func (s *SessionCoordinatorServer) JoinSession( + ctx context.Context, + req *JoinSessionRequest, +) (*JoinSessionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + inputData := input.JoinSessionInput{ + SessionID: sessionID, + PartyID: req.PartyId, + JoinToken: req.JoinToken, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), + DeviceID: req.DeviceInfo.DeviceId, + Platform: req.DeviceInfo.Platform, + AppVersion: req.DeviceInfo.AppVersion, + }, + } + + output, err := s.joinSessionUC.Execute(ctx, inputData) + if err != nil { + return nil, toGRPCError(err) + } + + // Convert other parties to response format + otherParties := make([]*PartyInfo, len(output.OtherParties)) + for i, p := range output.OtherParties { + otherParties[i] = &PartyInfo{ + PartyId: p.PartyID, + PartyIndex: int32(p.PartyIndex), + DeviceInfo: &DeviceInfo{ + DeviceType: string(p.DeviceInfo.DeviceType), + DeviceId: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + return &JoinSessionResponse{ + Success: output.Success, + SessionInfo: &SessionInfo{ + SessionId: output.SessionInfo.SessionID.String(), + SessionType: output.SessionInfo.SessionType, + ThresholdN: int32(output.SessionInfo.ThresholdN), + ThresholdT: int32(output.SessionInfo.ThresholdT), + MessageHash: output.SessionInfo.MessageHash, + Status: output.SessionInfo.Status, + }, + OtherParties: otherParties, + }, nil +} + +// GetSessionStatus retrieves the status of a session +func (s *SessionCoordinatorServer) GetSessionStatus( + ctx context.Context, + req *GetSessionStatusRequest, +) (*GetSessionStatusResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + output, err := s.getSessionStatusUC.Execute(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + + // Calculate completed parties from participants + completedParties := 0 + for _, p := range output.Participants { + if p.Status == "completed" { + completedParties++ + } + } + + return &GetSessionStatusResponse{ + Status: output.Status, + CompletedParties: int32(completedParties), + TotalParties: int32(len(output.Participants)), + PublicKey: output.PublicKey, + Signature: output.Signature, + }, nil +} + +// ReportCompletion reports that a participant has completed +func (s *SessionCoordinatorServer) ReportCompletion( + ctx context.Context, + req *ReportCompletionRequest, +) (*ReportCompletionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + inputData := input.ReportCompletionInput{ + SessionID: sessionID, + PartyID: req.PartyId, + PublicKey: req.PublicKey, + Signature: req.Signature, + } + + output, err := s.reportCompletionUC.Execute(ctx, inputData) + if err != nil { + return nil, toGRPCError(err) + } + + return &ReportCompletionResponse{ + Success: output.Success, + AllCompleted: output.AllCompleted, + }, nil +} + +// CloseSession closes a session +func (s *SessionCoordinatorServer) CloseSession( + ctx context.Context, + req *CloseSessionRequest, +) (*CloseSessionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + err = s.closeSessionUC.Execute(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + + return &CloseSessionResponse{ + Success: true, + }, nil +} + +// toGRPCError converts domain errors to gRPC errors +func toGRPCError(err error) error { + switch err { + case entities.ErrSessionExpired: + return status.Error(codes.DeadlineExceeded, err.Error()) + case entities.ErrSessionFull: + return status.Error(codes.ResourceExhausted, err.Error()) + case entities.ErrParticipantNotFound: + return status.Error(codes.NotFound, err.Error()) + case entities.ErrSessionNotInProgress: + return status.Error(codes.FailedPrecondition, err.Error()) + case entities.ErrInvalidSessionType: + return status.Error(codes.InvalidArgument, err.Error()) + default: + return status.Error(codes.Internal, err.Error()) + } +} + +// Request/Response types (normally generated from proto) +// These are simplified versions - actual implementation would use generated proto types + +type CreateSessionRequest struct { + SessionType string + ThresholdN int32 + ThresholdT int32 + Participants []*ParticipantInfoProto + MessageHash []byte + ExpiresInSeconds int64 +} + +type CreateSessionResponse struct { + SessionId string + JoinTokens map[string]string + ExpiresAt int64 +} + +type ParticipantInfoProto struct { + PartyId string + DeviceInfo *DeviceInfo +} + +type DeviceInfo struct { + DeviceType string + DeviceId string + Platform string + AppVersion string +} + +type JoinSessionRequest struct { + SessionId string + PartyId string + JoinToken string + DeviceInfo *DeviceInfo +} + +type JoinSessionResponse struct { + Success bool + SessionInfo *SessionInfo + OtherParties []*PartyInfo +} + +type SessionInfo struct { + SessionId string + SessionType string + ThresholdN int32 + ThresholdT int32 + MessageHash []byte + Status string +} + +type PartyInfo struct { + PartyId string + PartyIndex int32 + DeviceInfo *DeviceInfo +} + +type GetSessionStatusRequest struct { + SessionId string +} + +type GetSessionStatusResponse struct { + Status string + CompletedParties int32 + TotalParties int32 + PublicKey []byte + Signature []byte +} + +type ReportCompletionRequest struct { + SessionId string + PartyId string + PublicKey []byte + Signature []byte +} + +type ReportCompletionResponse struct { + Success bool + AllCompleted bool +} + +type CloseSessionRequest struct { + SessionId string +} + +type CloseSessionResponse struct { + Success bool +} diff --git a/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go b/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go new file mode 100644 index 00000000..5b359c6a --- /dev/null +++ b/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go @@ -0,0 +1,543 @@ +package http + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/rwadurian/mpc-system/pkg/logger" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" + "go.uber.org/zap" +) + +// SessionHTTPHandler handles HTTP requests for session management +type SessionHTTPHandler struct { + createSessionUC *use_cases.CreateSessionUseCase + joinSessionUC *use_cases.JoinSessionUseCase + getSessionStatusUC *use_cases.GetSessionStatusUseCase + reportCompletionUC *use_cases.ReportCompletionUseCase + closeSessionUC *use_cases.CloseSessionUseCase + sessionRepo repositories.SessionRepository +} + +// NewSessionHTTPHandler creates a new HTTP handler +func NewSessionHTTPHandler( + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, + sessionRepo repositories.SessionRepository, +) *SessionHTTPHandler { + return &SessionHTTPHandler{ + createSessionUC: createSessionUC, + joinSessionUC: joinSessionUC, + getSessionStatusUC: getSessionStatusUC, + reportCompletionUC: reportCompletionUC, + closeSessionUC: closeSessionUC, + sessionRepo: sessionRepo, + } +} + +// RegisterRoutes registers HTTP routes +func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) { + sessions := r.Group("/sessions") + { + sessions.POST("", h.CreateSession) + sessions.POST("/join", h.JoinSessionByToken) + sessions.POST("/:id/join", h.JoinSession) + sessions.GET("/:id", h.GetSessionStatus) + sessions.GET("/:id/status", h.GetSessionStatus) + sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady) + sessions.POST("/:id/start", h.StartSession) + sessions.POST("/:id/complete", h.ReportCompletion) + sessions.DELETE("/:id", h.CloseSession) + } +} + +// CreateSessionRequest is the HTTP request body for creating a session +type CreateSessionRequest struct { + SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"` + ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"` + ThresholdT int `json:"thresholdT" binding:"required,min=1"` + CreatedBy string `json:"createdBy" binding:"required"` + Participants []ParticipantInfoRequest `json:"participants,omitempty"` + MessageHash string `json:"messageHash,omitempty"` + ExpiresIn int64 `json:"expiresIn,omitempty"` +} + +// ParticipantInfoRequest represents a participant in the request +type ParticipantInfoRequest struct { + PartyID string `json:"party_id" binding:"required"` + DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` +} + +// DeviceInfoRequest represents device info in the request +type DeviceInfoRequest struct { + DeviceType string `json:"device_type" binding:"required"` + DeviceID string `json:"device_id,omitempty"` + Platform string `json:"platform,omitempty"` + AppVersion string `json:"app_version,omitempty"` +} + +// CreateSessionResponse is the HTTP response for creating a session +type CreateSessionResponse struct { + SessionID string `json:"sessionId"` + JoinToken string `json:"joinToken"` + Status string `json:"status"` + ExpiresAt int64 `json:"expiresAt,omitempty"` +} + +// CreateSession handles POST /sessions +func (h *SessionHTTPHandler) CreateSession(c *gin.Context) { + var req CreateSessionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Validate threshold + if req.ThresholdT > req.ThresholdN { + c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"}) + return + } + + // Convert request to input + participants := make([]input.ParticipantInfo, len(req.Participants)) + for i, p := range req.Participants { + participants[i] = input.ParticipantInfo{ + PartyID: p.PartyID, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + var messageHash []byte + if req.MessageHash != "" { + messageHash = []byte(req.MessageHash) + } + + expiresIn := time.Duration(req.ExpiresIn) * time.Second + if expiresIn == 0 { + expiresIn = 10 * time.Minute // Default + } + + inputData := input.CreateSessionInput{ + InitiatorID: req.CreatedBy, + SessionType: req.SessionType, + ThresholdN: req.ThresholdN, + ThresholdT: req.ThresholdT, + Participants: participants, + MessageHash: messageHash, + ExpiresIn: expiresIn, + } + + output, err := h.createSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to create session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Extract a single join token (for E2E compatibility) + // In a real scenario with pre-registered participants, we'd return all tokens + // For now, generate a universal join token + joinToken := output.SessionID.String() // Use session ID as join token for simplicity + if len(output.JoinTokens) > 0 { + // If there are participant-specific tokens, use the first one + for _, token := range output.JoinTokens { + joinToken = token + break + } + } + + c.JSON(http.StatusCreated, CreateSessionResponse{ + SessionID: output.SessionID.String(), + JoinToken: joinToken, + Status: "created", + ExpiresAt: output.ExpiresAt.UnixMilli(), + }) +} + +// JoinSessionRequest is the HTTP request body for joining a session +type JoinSessionRequest struct { + PartyID string `json:"party_id" binding:"required"` + JoinToken string `json:"join_token" binding:"required"` + DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` +} + +// JoinSessionResponse is the HTTP response for joining a session +type JoinSessionResponse struct { + Success bool `json:"success"` + SessionInfo SessionInfoDTO `json:"session_info"` + OtherParties []PartyInfoDTO `json:"other_parties"` +} + +// SessionInfoDTO represents session info in responses +type SessionInfoDTO struct { + SessionID string `json:"session_id"` + SessionType string `json:"session_type"` + ThresholdN int `json:"threshold_n"` + ThresholdT int `json:"threshold_t"` + MessageHash string `json:"message_hash,omitempty"` + Status string `json:"status"` +} + +// PartyInfoDTO represents party info in responses +type PartyInfoDTO struct { + PartyID string `json:"party_id"` + PartyIndex int `json:"party_index"` + DeviceInfo DeviceInfoRequest `json:"device_info"` +} + +// JoinSessionByTokenRequest is the HTTP request body for joining by token +type JoinSessionByTokenRequest struct { + JoinToken string `json:"joinToken" binding:"required"` + PartyID string `json:"partyId" binding:"required"` + DeviceType string `json:"deviceType" binding:"required"` + DeviceID string `json:"deviceId,omitempty"` +} + +// JoinSessionByTokenResponse is the HTTP response for joining by token +type JoinSessionByTokenResponse struct { + SessionID string `json:"sessionId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` + Participants []ParticipantStatusDTO `json:"participants"` +} + +// ParticipantStatusDTO represents participant status in responses +type ParticipantStatusDTO struct { + PartyID string `json:"partyId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` +} + +// JoinSessionByToken handles POST /sessions/join (join by token without session ID) +func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) { + var req JoinSessionByTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + logger.Error("failed to bind join session request", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Pass empty UUID - the use case will extract session ID from the JWT + inputData := input.JoinSessionInput{ + SessionID: uuid.Nil, + PartyID: req.PartyID, + JoinToken: req.JoinToken, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceType), + DeviceID: req.DeviceID, + }, + } + + output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to join session", zap.Error(err)) + // Return 401 for authentication/token errors + if err.Error() == "invalid token" || err.Error() == "token expired" { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Build participant status list + myPartyIndex := output.PartyIndex + + participants := make([]ParticipantStatusDTO, 0) + participants = append(participants, ParticipantStatusDTO{ + PartyID: req.PartyID, + Status: "joined", + }) + for _, p := range output.OtherParties { + participants = append(participants, ParticipantStatusDTO{ + PartyID: p.PartyID, + Status: "joined", + }) + } + + c.JSON(http.StatusOK, JoinSessionByTokenResponse{ + SessionID: output.SessionInfo.SessionID.String(), + PartyIndex: myPartyIndex, + Status: output.SessionInfo.Status, + Participants: participants, + }) +} + +// JoinSession handles POST /sessions/:id/join +func (h *SessionHTTPHandler) JoinSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + var req JoinSessionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + inputData := input.JoinSessionInput{ + SessionID: sessionID, + PartyID: req.PartyID, + JoinToken: req.JoinToken, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), + DeviceID: req.DeviceInfo.DeviceID, + Platform: req.DeviceInfo.Platform, + AppVersion: req.DeviceInfo.AppVersion, + }, + } + + output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to join session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + otherParties := make([]PartyInfoDTO, len(output.OtherParties)) + for i, p := range output.OtherParties { + otherParties[i] = PartyInfoDTO{ + PartyID: p.PartyID, + PartyIndex: p.PartyIndex, + DeviceInfo: DeviceInfoRequest{ + DeviceType: string(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + c.JSON(http.StatusOK, JoinSessionResponse{ + Success: output.Success, + SessionInfo: SessionInfoDTO{ + SessionID: output.SessionInfo.SessionID.String(), + SessionType: output.SessionInfo.SessionType, + ThresholdN: output.SessionInfo.ThresholdN, + ThresholdT: output.SessionInfo.ThresholdT, + MessageHash: string(output.SessionInfo.MessageHash), + Status: output.SessionInfo.Status, + }, + OtherParties: otherParties, + }) +} + +// SessionStatusResponse is the HTTP response for session status +type SessionStatusResponse struct { + SessionID string `json:"sessionId"` + Status string `json:"status"` + ThresholdT int `json:"thresholdT"` + ThresholdN int `json:"thresholdN"` + Participants []ParticipantStatusDTO `json:"participants"` + PublicKey string `json:"publicKey,omitempty"` +} + +// GetSessionStatus handles GET /sessions/:id/status +func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID) + if err != nil { + logger.Error("failed to get session status", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + // Convert participants to DTO + participants := make([]ParticipantStatusDTO, len(output.Participants)) + for i, p := range output.Participants { + participants[i] = ParticipantStatusDTO{ + PartyID: p.PartyID, + PartyIndex: p.PartyIndex, + Status: p.Status, + } + } + + c.JSON(http.StatusOK, SessionStatusResponse{ + SessionID: output.SessionID.String(), + Status: output.Status, + ThresholdT: output.ThresholdT, + ThresholdN: output.ThresholdN, + Participants: participants, + PublicKey: string(output.PublicKey), + }) +} + +// ReportCompletionRequest is the HTTP request for reporting completion +type ReportCompletionRequest struct { + PartyID string `json:"party_id" binding:"required"` + PublicKey string `json:"public_key,omitempty"` + Signature string `json:"signature,omitempty"` +} + +// ReportCompletion handles POST /sessions/:id/complete +func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + var req ReportCompletionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + inputData := input.ReportCompletionInput{ + SessionID: sessionID, + PartyID: req.PartyID, + PublicKey: []byte(req.PublicKey), + Signature: []byte(req.Signature), + } + + output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to report completion", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": output.Success, + "all_completed": output.AllCompleted, + }) +} + +// CloseSession handles DELETE /sessions/:id +func (h *SessionHTTPHandler) CloseSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + err = h.closeSessionUC.Execute(c.Request.Context(), sessionID) + if err != nil { + logger.Error("failed to close session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready +func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + partyID := c.Param("partyId") + if partyID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"}) + return + } + + logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID)) + + // Load session + session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) + if err != nil { + if err == entities.ErrSessionNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + logger.Error("failed to load session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) + return + } + + // Create party ID value object + partyIDVO, err := value_objects.NewPartyID(partyID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"}) + return + } + + // Update participant status to ready + if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil { + logger.Error("failed to mark party as ready", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Save session + if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { + logger.Error("failed to save session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// StartSession handles POST /sessions/:id/start +func (h *SessionHTTPHandler) StartSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + logger.Info("starting session", zap.String("session_id", sessionID.String())) + + // Load session + session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) + if err != nil { + if err == entities.ErrSessionNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + logger.Error("failed to load session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) + return + } + + // Start the session + if err := session.Start(); err != nil { + logger.Error("failed to start session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Save session + if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { + logger.Error("failed to save session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// HealthCheck handles GET /health +func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + "service": "session-coordinator", + "time": time.Now().UTC().Format(time.RFC3339), + }) +} diff --git a/backend/mpc-system/services/session-coordinator/cmd/server/main.go b/backend/mpc-system/services/session-coordinator/cmd/server/main.go new file mode 100644 index 00000000..87cb0029 --- /dev/null +++ b/backend/mpc-system/services/session-coordinator/cmd/server/main.go @@ -0,0 +1,306 @@ +package main + +import ( + "context" + "database/sql" + "flag" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gin-gonic/gin" + _ "github.com/lib/pq" + amqp "github.com/rabbitmq/amqp091-go" + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" + + "github.com/rwadurian/mpc-system/pkg/config" + "github.com/rwadurian/mpc-system/pkg/jwt" + "github.com/rwadurian/mpc-system/pkg/logger" + grpcadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/grpc" + httphandler "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/http" + "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/postgres" + "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/rabbitmq" + redisadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/redis" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" + "go.uber.org/zap" +) + +func main() { + // Parse flags + configPath := flag.String("config", "", "Path to config file") + flag.Parse() + + // Load configuration + cfg, err := config.Load(*configPath) + if err != nil { + fmt.Printf("Failed to load config: %v\n", err) + os.Exit(1) + } + + // Initialize logger + if err := logger.Init(&logger.Config{ + Level: cfg.Logger.Level, + Encoding: cfg.Logger.Encoding, + }); err != nil { + fmt.Printf("Failed to initialize logger: %v\n", err) + os.Exit(1) + } + defer logger.Sync() + + logger.Info("Starting Session Coordinator Service", + zap.String("environment", cfg.Server.Environment), + zap.Int("grpc_port", cfg.Server.GRPCPort), + zap.Int("http_port", cfg.Server.HTTPPort)) + + // Initialize database connection + db, err := initDatabase(cfg.Database) + if err != nil { + logger.Fatal("Failed to connect to database", zap.Error(err)) + } + defer db.Close() + + // Initialize Redis connection + redisClient := initRedis(cfg.Redis) + defer redisClient.Close() + + // Initialize RabbitMQ connection + rabbitConn, err := initRabbitMQ(cfg.RabbitMQ) + if err != nil { + logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err)) + } + defer rabbitConn.Close() + + // Initialize repositories and adapters + sessionRepo := postgres.NewSessionPostgresRepo(db) + messageRepo := postgres.NewMessagePostgresRepo(db) + sessionCache := redisadapter.NewSessionCacheAdapter(redisClient) + eventPublisher, err := rabbitmq.NewEventPublisherAdapter(rabbitConn) + if err != nil { + logger.Fatal("Failed to create event publisher", zap.Error(err)) + } + defer eventPublisher.Close() + + // Initialize JWT service + jwtService := jwt.NewJWTService( + cfg.JWT.SecretKey, + cfg.JWT.Issuer, + cfg.JWT.TokenExpiry, + cfg.JWT.RefreshExpiry, + ) + + // Initialize use cases + createSessionUC := use_cases.NewCreateSessionUseCase(sessionRepo, jwtService, eventPublisher) + joinSessionUC := use_cases.NewJoinSessionUseCase(sessionRepo, jwtService, eventPublisher) + getSessionStatusUC := use_cases.NewGetSessionStatusUseCase(sessionRepo) + reportCompletionUC := use_cases.NewReportCompletionUseCase(sessionRepo, eventPublisher) + closeSessionUC := use_cases.NewCloseSessionUseCase(sessionRepo, messageRepo, eventPublisher) + expireSessionsUC := use_cases.NewExpireSessionsUseCase(sessionRepo, eventPublisher) + + // Start session expiration background job + go runSessionExpiration(expireSessionsUC) + + // Create shutdown context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start servers + errChan := make(chan error, 2) + + // Start gRPC server + go func() { + if err := startGRPCServer( + cfg, + createSessionUC, + joinSessionUC, + getSessionStatusUC, + reportCompletionUC, + closeSessionUC, + ); err != nil { + errChan <- fmt.Errorf("gRPC server error: %w", err) + } + }() + + // Start HTTP server + go func() { + if err := startHTTPServer( + cfg, + createSessionUC, + joinSessionUC, + getSessionStatusUC, + reportCompletionUC, + closeSessionUC, + sessionRepo, + ); err != nil { + errChan <- fmt.Errorf("HTTP server error: %w", err) + } + }() + + // Wait for shutdown signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case sig := <-sigChan: + logger.Info("Received shutdown signal", zap.String("signal", sig.String())) + case err := <-errChan: + logger.Error("Server error", zap.Error(err)) + } + + // Graceful shutdown + logger.Info("Shutting down...") + cancel() + + // Give services time to shutdown gracefully + time.Sleep(5 * time.Second) + logger.Info("Shutdown complete") + + _ = ctx + _ = sessionCache +} + +func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) { + db, err := sql.Open("postgres", cfg.DSN()) + if err != nil { + return nil, err + } + + db.SetMaxOpenConns(cfg.MaxOpenConns) + db.SetMaxIdleConns(cfg.MaxIdleConns) + db.SetConnMaxLifetime(cfg.ConnMaxLife) + + // Test connection + if err := db.Ping(); err != nil { + return nil, err + } + + logger.Info("Connected to PostgreSQL") + return db, nil +} + +func initRedis(cfg config.RedisConfig) *redis.Client { + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr(), + Password: cfg.Password, + DB: cfg.DB, + }) + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + logger.Warn("Redis connection failed, continuing without cache", zap.Error(err)) + } else { + logger.Info("Connected to Redis") + } + + return client +} + +func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) { + conn, err := amqp.Dial(cfg.URL()) + if err != nil { + return nil, err + } + + logger.Info("Connected to RabbitMQ") + return conn, nil +} + +func startGRPCServer( + cfg *config.Config, + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, +) error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort)) + if err != nil { + return err + } + + grpcServer := grpc.NewServer() + + // Register services (using our custom handler, not generated proto) + // In production, you would register the generated proto service + _ = grpcadapter.NewSessionCoordinatorServer( + createSessionUC, + joinSessionUC, + getSessionStatusUC, + reportCompletionUC, + closeSessionUC, + ) + + // Enable reflection for debugging + reflection.Register(grpcServer) + + logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort)) + return grpcServer.Serve(listener) +} + +func startHTTPServer( + cfg *config.Config, + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, + sessionRepo repositories.SessionRepository, +) error { + // Set Gin mode + if cfg.Server.Environment == "production" { + gin.SetMode(gin.ReleaseMode) + } + + router := gin.New() + router.Use(gin.Recovery()) + router.Use(gin.Logger()) + + // Create HTTP handler + httpHandler := httphandler.NewSessionHTTPHandler( + createSessionUC, + joinSessionUC, + getSessionStatusUC, + reportCompletionUC, + closeSessionUC, + sessionRepo, + ) + + // Health check + router.GET("/health", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + "service": "session-coordinator", + }) + }) + + // Register API routes + api := router.Group("/api/v1") + httpHandler.RegisterRoutes(api) + + logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort)) + return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort)) +} + +func runSessionExpiration(expireSessionsUC *use_cases.ExpireSessionsUseCase) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + count, err := expireSessionsUC.Execute(ctx) + cancel() + + if err != nil { + logger.Error("Failed to expire sessions", zap.Error(err)) + } else if count > 0 { + logger.Info("Expired stale sessions", zap.Int("count", count)) + } + } +} diff --git a/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go b/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go new file mode 100644 index 00000000..23471771 --- /dev/null +++ b/backend/mpc-system/services/session-coordinator/domain/entities/mpc_session.go @@ -0,0 +1,342 @@ +package entities + +import ( + "errors" + "time" + + "github.com/google/uuid" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" +) + +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionFull = errors.New("session is full") + ErrSessionExpired = errors.New("session expired") + ErrSessionNotInProgress = errors.New("session not in progress") + ErrParticipantNotFound = errors.New("participant not found") + ErrInvalidSessionType = errors.New("invalid session type") + ErrInvalidStatusTransition = errors.New("invalid status transition") +) + +// SessionType represents the type of MPC session +type SessionType string + +const ( + SessionTypeKeygen SessionType = "keygen" + SessionTypeSign SessionType = "sign" +) + +// IsValid checks if the session type is valid +func (t SessionType) IsValid() bool { + return t == SessionTypeKeygen || t == SessionTypeSign +} + +// MPCSession represents an MPC session +// Coordinator only manages session metadata, does not participate in MPC computation +type MPCSession struct { + ID value_objects.SessionID + SessionType SessionType + Threshold value_objects.Threshold + Participants []*Participant + Status value_objects.SessionStatus + MessageHash []byte // Used for Sign sessions + PublicKey []byte // Group public key after Keygen completion + CreatedBy string + CreatedAt time.Time + UpdatedAt time.Time + ExpiresAt time.Time + CompletedAt *time.Time +} + +// NewMPCSession creates a new MPC session +func NewMPCSession( + sessionType SessionType, + threshold value_objects.Threshold, + createdBy string, + expiresIn time.Duration, + messageHash []byte, // Only for Sign sessions +) (*MPCSession, error) { + if !sessionType.IsValid() { + return nil, ErrInvalidSessionType + } + + if sessionType == SessionTypeSign && len(messageHash) == 0 { + return nil, errors.New("message hash required for sign session") + } + + now := time.Now().UTC() + return &MPCSession{ + ID: value_objects.NewSessionID(), + SessionType: sessionType, + Threshold: threshold, + Participants: make([]*Participant, 0, threshold.N()), + Status: value_objects.SessionStatusCreated, + MessageHash: messageHash, + CreatedBy: createdBy, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: now.Add(expiresIn), + }, nil +} + +// AddParticipant adds a participant to the session +func (s *MPCSession) AddParticipant(p *Participant) error { + if len(s.Participants) >= s.Threshold.N() { + return ErrSessionFull + } + s.Participants = append(s.Participants, p) + s.UpdatedAt = time.Now().UTC() + return nil +} + +// GetParticipant gets a participant by party ID +func (s *MPCSession) GetParticipant(partyID value_objects.PartyID) (*Participant, error) { + for _, p := range s.Participants { + if p.PartyID.Equals(partyID) { + return p, nil + } + } + return nil, ErrParticipantNotFound +} + +// UpdateParticipantStatus updates a participant's status +func (s *MPCSession) UpdateParticipantStatus(partyID value_objects.PartyID, status value_objects.ParticipantStatus) error { + for _, p := range s.Participants { + if p.PartyID.Equals(partyID) { + switch status { + case value_objects.ParticipantStatusJoined: + return p.Join() + case value_objects.ParticipantStatusReady: + return p.MarkReady() + case value_objects.ParticipantStatusCompleted: + return p.MarkCompleted() + case value_objects.ParticipantStatusFailed: + p.MarkFailed() + return nil + default: + return errors.New("invalid status") + } + } + } + return ErrParticipantNotFound +} + +// CanStart checks if all participants have joined and the session can start +func (s *MPCSession) CanStart() bool { + if len(s.Participants) != s.Threshold.N() { + return false + } + + readyCount := 0 + for _, p := range s.Participants { + // Accept participants in either joined or ready status + if p.IsJoined() || p.IsReady() { + readyCount++ + } + } + return readyCount == s.Threshold.N() +} + +// Start transitions the session to in_progress +func (s *MPCSession) Start() error { + // If already in progress, just return success (idempotent) + if s.Status == value_objects.SessionStatusInProgress { + return nil + } + if !s.Status.CanTransitionTo(value_objects.SessionStatusInProgress) { + return ErrInvalidStatusTransition + } + if !s.CanStart() { + return errors.New("not all participants have joined") + } + s.Status = value_objects.SessionStatusInProgress + s.UpdatedAt = time.Now().UTC() + return nil +} + +// Complete marks the session as completed +func (s *MPCSession) Complete(publicKey []byte) error { + if !s.Status.CanTransitionTo(value_objects.SessionStatusCompleted) { + return ErrInvalidStatusTransition + } + s.Status = value_objects.SessionStatusCompleted + s.PublicKey = publicKey + now := time.Now().UTC() + s.CompletedAt = &now + s.UpdatedAt = now + return nil +} + +// Fail marks the session as failed +func (s *MPCSession) Fail() error { + if !s.Status.CanTransitionTo(value_objects.SessionStatusFailed) { + return ErrInvalidStatusTransition + } + s.Status = value_objects.SessionStatusFailed + s.UpdatedAt = time.Now().UTC() + return nil +} + +// Expire marks the session as expired +func (s *MPCSession) Expire() error { + if !s.Status.CanTransitionTo(value_objects.SessionStatusExpired) { + return ErrInvalidStatusTransition + } + s.Status = value_objects.SessionStatusExpired + s.UpdatedAt = time.Now().UTC() + return nil +} + +// IsExpired checks if the session has expired +func (s *MPCSession) IsExpired() bool { + return time.Now().UTC().After(s.ExpiresAt) +} + +// IsActive checks if the session is active +func (s *MPCSession) IsActive() bool { + return s.Status.IsActive() && !s.IsExpired() +} + +// IsParticipant checks if a party is a participant +func (s *MPCSession) IsParticipant(partyID value_objects.PartyID) bool { + for _, p := range s.Participants { + if p.PartyID.Equals(partyID) { + return true + } + } + return false +} + +// AllCompleted checks if all participants have completed +func (s *MPCSession) AllCompleted() bool { + for _, p := range s.Participants { + if !p.IsCompleted() { + return false + } + } + return true +} + +// CompletedCount returns the number of completed participants +func (s *MPCSession) CompletedCount() int { + count := 0 + for _, p := range s.Participants { + if p.IsCompleted() { + count++ + } + } + return count +} + +// JoinedCount returns the number of joined participants +func (s *MPCSession) JoinedCount() int { + count := 0 + for _, p := range s.Participants { + if p.IsJoined() { + count++ + } + } + return count +} + +// GetPartyIDs returns all party IDs +func (s *MPCSession) GetPartyIDs() []string { + ids := make([]string, len(s.Participants)) + for i, p := range s.Participants { + ids[i] = p.PartyID.String() + } + return ids +} + +// GetOtherParties returns participants except the specified party +func (s *MPCSession) GetOtherParties(excludePartyID value_objects.PartyID) []*Participant { + others := make([]*Participant, 0, len(s.Participants)-1) + for _, p := range s.Participants { + if !p.PartyID.Equals(excludePartyID) { + others = append(others, p) + } + } + return others +} + +// ToDTO converts to a DTO for API responses +func (s *MPCSession) ToDTO() SessionDTO { + participants := make([]ParticipantDTO, len(s.Participants)) + for i, p := range s.Participants { + participants[i] = ParticipantDTO{ + PartyID: p.PartyID.String(), + PartyIndex: p.PartyIndex, + Status: p.Status.String(), + DeviceType: string(p.DeviceInfo.DeviceType), + } + } + + return SessionDTO{ + ID: s.ID.String(), + SessionType: string(s.SessionType), + ThresholdN: s.Threshold.N(), + ThresholdT: s.Threshold.T(), + Participants: participants, + Status: s.Status.String(), + CreatedAt: s.CreatedAt, + ExpiresAt: s.ExpiresAt, + } +} + +// SessionDTO is a data transfer object for sessions +type SessionDTO struct { + ID string `json:"id"` + SessionType string `json:"session_type"` + ThresholdN int `json:"threshold_n"` + ThresholdT int `json:"threshold_t"` + Participants []ParticipantDTO `json:"participants"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// ParticipantDTO is a data transfer object for participants +type ParticipantDTO struct { + PartyID string `json:"party_id"` + PartyIndex int `json:"party_index"` + Status string `json:"status"` + DeviceType string `json:"device_type"` +} + +// Reconstruct reconstructs an MPCSession from database +func ReconstructSession( + id uuid.UUID, + sessionType string, + thresholdT, thresholdN int, + status string, + messageHash, publicKey []byte, + createdBy string, + createdAt, updatedAt, expiresAt time.Time, + completedAt *time.Time, + participants []*Participant, +) (*MPCSession, error) { + sessionStatus, err := value_objects.NewSessionStatus(status) + if err != nil { + return nil, err + } + + threshold, err := value_objects.NewThreshold(thresholdT, thresholdN) + if err != nil { + return nil, err + } + + return &MPCSession{ + ID: value_objects.SessionIDFromUUID(id), + SessionType: SessionType(sessionType), + Threshold: threshold, + Participants: participants, + Status: sessionStatus, + MessageHash: messageHash, + PublicKey: publicKey, + CreatedBy: createdBy, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + CompletedAt: completedAt, + }, nil +} diff --git a/backend/mpc-system/tests/e2e/keygen_flow_test.go b/backend/mpc-system/tests/e2e/keygen_flow_test.go new file mode 100644 index 00000000..9d17486b --- /dev/null +++ b/backend/mpc-system/tests/e2e/keygen_flow_test.go @@ -0,0 +1,354 @@ +//go:build e2e + +package e2e_test + +import ( + "bytes" + "encoding/json" + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type KeygenFlowTestSuite struct { + suite.Suite + baseURL string + client *http.Client +} + +func TestKeygenFlowSuite(t *testing.T) { + if testing.Short() { + t.Skip("Skipping e2e test in short mode") + } + suite.Run(t, new(KeygenFlowTestSuite)) +} + +func (s *KeygenFlowTestSuite) SetupSuite() { + s.baseURL = os.Getenv("SESSION_COORDINATOR_URL") + if s.baseURL == "" { + s.baseURL = "http://localhost:8080" + } + + s.client = &http.Client{ + Timeout: 30 * time.Second, + } + + // Wait for service to be ready + s.waitForService() +} + +func (s *KeygenFlowTestSuite) waitForService() { + maxRetries := 30 + for i := 0; i < maxRetries; i++ { + resp, err := s.client.Get(s.baseURL + "/health") + if err == nil && resp.StatusCode == http.StatusOK { + resp.Body.Close() + return + } + if resp != nil { + resp.Body.Close() + } + time.Sleep(time.Second) + } + s.T().Fatal("Service not ready after waiting") +} + +type CreateSessionRequest struct { + SessionType string `json:"sessionType"` + ThresholdT int `json:"thresholdT"` + ThresholdN int `json:"thresholdN"` + CreatedBy string `json:"createdBy"` +} + +type CreateSessionResponse struct { + SessionID string `json:"sessionId"` + JoinToken string `json:"joinToken"` + Status string `json:"status"` +} + +type JoinSessionRequest struct { + JoinToken string `json:"joinToken"` + PartyID string `json:"partyId"` + DeviceType string `json:"deviceType"` + DeviceID string `json:"deviceId"` +} + +type JoinSessionResponse struct { + SessionID string `json:"sessionId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` + Participants []struct { + PartyID string `json:"partyId"` + Status string `json:"status"` + } `json:"participants"` +} + +type SessionStatusResponse struct { + SessionID string `json:"sessionId"` + Status string `json:"status"` + ThresholdT int `json:"thresholdT"` + ThresholdN int `json:"thresholdN"` + Participants []struct { + PartyID string `json:"partyId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` + } `json:"participants"` +} + +func (s *KeygenFlowTestSuite) TestCompleteKeygenFlow() { + // Step 1: Create a new keygen session + createReq := CreateSessionRequest{ + SessionType: "keygen", + ThresholdT: 2, + ThresholdN: 3, + CreatedBy: "e2e_test_user", + } + + createResp := s.createSession(createReq) + require.NotEmpty(s.T(), createResp.SessionID) + require.NotEmpty(s.T(), createResp.JoinToken) + assert.Equal(s.T(), "created", createResp.Status) + + sessionID := createResp.SessionID + joinToken := createResp.JoinToken + + // Step 2: First party joins + joinReq1 := JoinSessionRequest{ + JoinToken: joinToken, + PartyID: "party_user_device", + DeviceType: "iOS", + DeviceID: "device_001", + } + joinResp1 := s.joinSession(joinReq1) + assert.Equal(s.T(), sessionID, joinResp1.SessionID) + assert.Equal(s.T(), 0, joinResp1.PartyIndex) + + // Step 3: Second party joins + joinReq2 := JoinSessionRequest{ + JoinToken: joinToken, + PartyID: "party_server", + DeviceType: "server", + DeviceID: "server_001", + } + joinResp2 := s.joinSession(joinReq2) + assert.Equal(s.T(), sessionID, joinResp2.SessionID) + assert.Equal(s.T(), 1, joinResp2.PartyIndex) + + // Step 4: Third party joins + joinReq3 := JoinSessionRequest{ + JoinToken: joinToken, + PartyID: "party_recovery", + DeviceType: "recovery", + DeviceID: "recovery_001", + } + joinResp3 := s.joinSession(joinReq3) + assert.Equal(s.T(), sessionID, joinResp3.SessionID) + assert.Equal(s.T(), 2, joinResp3.PartyIndex) + + // Step 5: Check session status - should have all participants + statusResp := s.getSessionStatus(sessionID) + assert.Equal(s.T(), 3, len(statusResp.Participants)) + + // Step 6: Mark parties as ready + s.markPartyReady(sessionID, "party_user_device") + s.markPartyReady(sessionID, "party_server") + s.markPartyReady(sessionID, "party_recovery") + + // Step 7: Start the session + s.startSession(sessionID) + + // Step 8: Verify session is in progress + statusResp = s.getSessionStatus(sessionID) + assert.Equal(s.T(), "in_progress", statusResp.Status) + + // Step 9: Report completion (simulate keygen completion) + publicKey := []byte("test-group-public-key-from-keygen") + s.reportCompletion(sessionID, "party_user_device", publicKey) + + // Step 10: Verify session is completed + statusResp = s.getSessionStatus(sessionID) + assert.Equal(s.T(), "completed", statusResp.Status) +} + +func (s *KeygenFlowTestSuite) TestJoinSessionWithInvalidToken() { + joinReq := JoinSessionRequest{ + JoinToken: "invalid-token", + PartyID: "party_test", + DeviceType: "iOS", + DeviceID: "device_test", + } + + body, _ := json.Marshal(joinReq) + resp, err := s.client.Post( + s.baseURL+"/api/v1/sessions/join", + "application/json", + bytes.NewReader(body), + ) + require.NoError(s.T(), err) + defer resp.Body.Close() + + assert.Equal(s.T(), http.StatusUnauthorized, resp.StatusCode) +} + +func (s *KeygenFlowTestSuite) TestGetNonExistentSession() { + resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/00000000-0000-0000-0000-000000000000") + require.NoError(s.T(), err) + defer resp.Body.Close() + + assert.Equal(s.T(), http.StatusNotFound, resp.StatusCode) +} + +func (s *KeygenFlowTestSuite) TestExceedParticipantLimit() { + // Create session with 2 participants max + createReq := CreateSessionRequest{ + SessionType: "keygen", + ThresholdT: 2, + ThresholdN: 2, + CreatedBy: "e2e_test_user_limit", + } + + createResp := s.createSession(createReq) + joinToken := createResp.JoinToken + + // Join with 2 parties (should succeed) + for i := 0; i < 2; i++ { + joinReq := JoinSessionRequest{ + JoinToken: joinToken, + PartyID: "party_" + string(rune('a'+i)), + DeviceType: "test", + DeviceID: "device_" + string(rune('a'+i)), + } + s.joinSession(joinReq) + } + + // Try to join with 3rd party (should fail) + joinReq := JoinSessionRequest{ + JoinToken: joinToken, + PartyID: "party_extra", + DeviceType: "test", + DeviceID: "device_extra", + } + + body, _ := json.Marshal(joinReq) + resp, err := s.client.Post( + s.baseURL+"/api/v1/sessions/join", + "application/json", + bytes.NewReader(body), + ) + require.NoError(s.T(), err) + defer resp.Body.Close() + + assert.Equal(s.T(), http.StatusBadRequest, resp.StatusCode) +} + +// Helper methods + +func (s *KeygenFlowTestSuite) createSession(req CreateSessionRequest) CreateSessionResponse { + body, _ := json.Marshal(req) + resp, err := s.client.Post( + s.baseURL+"/api/v1/sessions", + "application/json", + bytes.NewReader(body), + ) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusCreated, resp.StatusCode) + + var result CreateSessionResponse + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(s.T(), err) + + return result +} + +func (s *KeygenFlowTestSuite) joinSession(req JoinSessionRequest) JoinSessionResponse { + body, _ := json.Marshal(req) + resp, err := s.client.Post( + s.baseURL+"/api/v1/sessions/join", + "application/json", + bytes.NewReader(body), + ) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusOK, resp.StatusCode) + + var result JoinSessionResponse + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(s.T(), err) + + return result +} + +func (s *KeygenFlowTestSuite) getSessionStatus(sessionID string) SessionStatusResponse { + resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/" + sessionID) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusOK, resp.StatusCode) + + var result SessionStatusResponse + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(s.T(), err) + + return result +} + +func (s *KeygenFlowTestSuite) markPartyReady(sessionID, partyID string) { + req := map[string]string{"partyId": partyID} + body, _ := json.Marshal(req) + + httpReq, _ := http.NewRequest( + http.MethodPut, + s.baseURL+"/api/v1/sessions/"+sessionID+"/parties/"+partyID+"/ready", + bytes.NewReader(body), + ) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(httpReq) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusOK, resp.StatusCode) +} + +func (s *KeygenFlowTestSuite) startSession(sessionID string) { + httpReq, _ := http.NewRequest( + http.MethodPost, + s.baseURL+"/api/v1/sessions/"+sessionID+"/start", + nil, + ) + + resp, err := s.client.Do(httpReq) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusOK, resp.StatusCode) +} + +func (s *KeygenFlowTestSuite) reportCompletion(sessionID string, partyID string, publicKey []byte) { + req := map[string]interface{}{ + "party_id": partyID, + "public_key": string(publicKey), + } + body, _ := json.Marshal(req) + + httpReq, _ := http.NewRequest( + http.MethodPost, + s.baseURL+"/api/v1/sessions/"+sessionID+"/complete", + bytes.NewReader(body), + ) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(httpReq) + require.NoError(s.T(), err) + defer resp.Body.Close() + + require.Equal(s.T(), http.StatusOK, resp.StatusCode) +}