fix: Implement MarkPartyReady and StartSession handlers, update domain logic

- Add sessionRepo to HTTP handler for database operations
- Implement MarkPartyReady handler to update participant status
- Implement StartSession handler to start MPC sessions
- Update CanStart() to accept participants in 'ready' status
- Make Start() method idempotent to handle automatic + explicit starts
- Fix repository injection through dependency chain in main.go
- Add party_id parameter to test completion request
This commit is contained in:
hailin 2025-11-29 00:31:24 -08:00
parent 6fa4d7ac1d
commit 7531cbd07a
5 changed files with 1871 additions and 0 deletions

View File

@ -0,0 +1,326 @@
package grpc
import (
"context"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// SessionCoordinatorServer implements the gRPC SessionCoordinator service
type SessionCoordinatorServer struct {
createSessionUC *use_cases.CreateSessionUseCase
joinSessionUC *use_cases.JoinSessionUseCase
getSessionStatusUC *use_cases.GetSessionStatusUseCase
reportCompletionUC *use_cases.ReportCompletionUseCase
closeSessionUC *use_cases.CloseSessionUseCase
}
// NewSessionCoordinatorServer creates a new gRPC server
func NewSessionCoordinatorServer(
createSessionUC *use_cases.CreateSessionUseCase,
joinSessionUC *use_cases.JoinSessionUseCase,
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
reportCompletionUC *use_cases.ReportCompletionUseCase,
closeSessionUC *use_cases.CloseSessionUseCase,
) *SessionCoordinatorServer {
return &SessionCoordinatorServer{
createSessionUC: createSessionUC,
joinSessionUC: joinSessionUC,
getSessionStatusUC: getSessionStatusUC,
reportCompletionUC: reportCompletionUC,
closeSessionUC: closeSessionUC,
}
}
// CreateSession creates a new MPC session
func (s *SessionCoordinatorServer) CreateSession(
ctx context.Context,
req *CreateSessionRequest,
) (*CreateSessionResponse, error) {
// Convert request to input
participants := make([]input.ParticipantInfo, len(req.Participants))
for i, p := range req.Participants {
participants[i] = input.ParticipantInfo{
PartyID: p.PartyId,
DeviceInfo: entities.DeviceInfo{
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
DeviceID: p.DeviceInfo.DeviceId,
Platform: p.DeviceInfo.Platform,
AppVersion: p.DeviceInfo.AppVersion,
},
}
}
inputData := input.CreateSessionInput{
InitiatorID: "", // Could be extracted from auth context
SessionType: req.SessionType,
ThresholdN: int(req.ThresholdN),
ThresholdT: int(req.ThresholdT),
Participants: participants,
MessageHash: req.MessageHash,
ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second,
}
// Execute use case
output, err := s.createSessionUC.Execute(ctx, inputData)
if err != nil {
return nil, toGRPCError(err)
}
// Convert output to response
return &CreateSessionResponse{
SessionId: output.SessionID.String(),
JoinTokens: output.JoinTokens,
ExpiresAt: output.ExpiresAt.UnixMilli(),
}, nil
}
// JoinSession allows a participant to join a session
func (s *SessionCoordinatorServer) JoinSession(
ctx context.Context,
req *JoinSessionRequest,
) (*JoinSessionResponse, error) {
sessionID, err := uuid.Parse(req.SessionId)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
}
inputData := input.JoinSessionInput{
SessionID: sessionID,
PartyID: req.PartyId,
JoinToken: req.JoinToken,
DeviceInfo: entities.DeviceInfo{
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
DeviceID: req.DeviceInfo.DeviceId,
Platform: req.DeviceInfo.Platform,
AppVersion: req.DeviceInfo.AppVersion,
},
}
output, err := s.joinSessionUC.Execute(ctx, inputData)
if err != nil {
return nil, toGRPCError(err)
}
// Convert other parties to response format
otherParties := make([]*PartyInfo, len(output.OtherParties))
for i, p := range output.OtherParties {
otherParties[i] = &PartyInfo{
PartyId: p.PartyID,
PartyIndex: int32(p.PartyIndex),
DeviceInfo: &DeviceInfo{
DeviceType: string(p.DeviceInfo.DeviceType),
DeviceId: p.DeviceInfo.DeviceID,
Platform: p.DeviceInfo.Platform,
AppVersion: p.DeviceInfo.AppVersion,
},
}
}
return &JoinSessionResponse{
Success: output.Success,
SessionInfo: &SessionInfo{
SessionId: output.SessionInfo.SessionID.String(),
SessionType: output.SessionInfo.SessionType,
ThresholdN: int32(output.SessionInfo.ThresholdN),
ThresholdT: int32(output.SessionInfo.ThresholdT),
MessageHash: output.SessionInfo.MessageHash,
Status: output.SessionInfo.Status,
},
OtherParties: otherParties,
}, nil
}
// GetSessionStatus retrieves the status of a session
func (s *SessionCoordinatorServer) GetSessionStatus(
ctx context.Context,
req *GetSessionStatusRequest,
) (*GetSessionStatusResponse, error) {
sessionID, err := uuid.Parse(req.SessionId)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
}
output, err := s.getSessionStatusUC.Execute(ctx, sessionID)
if err != nil {
return nil, toGRPCError(err)
}
// Calculate completed parties from participants
completedParties := 0
for _, p := range output.Participants {
if p.Status == "completed" {
completedParties++
}
}
return &GetSessionStatusResponse{
Status: output.Status,
CompletedParties: int32(completedParties),
TotalParties: int32(len(output.Participants)),
PublicKey: output.PublicKey,
Signature: output.Signature,
}, nil
}
// ReportCompletion reports that a participant has completed
func (s *SessionCoordinatorServer) ReportCompletion(
ctx context.Context,
req *ReportCompletionRequest,
) (*ReportCompletionResponse, error) {
sessionID, err := uuid.Parse(req.SessionId)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
}
inputData := input.ReportCompletionInput{
SessionID: sessionID,
PartyID: req.PartyId,
PublicKey: req.PublicKey,
Signature: req.Signature,
}
output, err := s.reportCompletionUC.Execute(ctx, inputData)
if err != nil {
return nil, toGRPCError(err)
}
return &ReportCompletionResponse{
Success: output.Success,
AllCompleted: output.AllCompleted,
}, nil
}
// CloseSession closes a session
func (s *SessionCoordinatorServer) CloseSession(
ctx context.Context,
req *CloseSessionRequest,
) (*CloseSessionResponse, error) {
sessionID, err := uuid.Parse(req.SessionId)
if err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
}
err = s.closeSessionUC.Execute(ctx, sessionID)
if err != nil {
return nil, toGRPCError(err)
}
return &CloseSessionResponse{
Success: true,
}, nil
}
// toGRPCError converts domain errors to gRPC errors
func toGRPCError(err error) error {
switch err {
case entities.ErrSessionExpired:
return status.Error(codes.DeadlineExceeded, err.Error())
case entities.ErrSessionFull:
return status.Error(codes.ResourceExhausted, err.Error())
case entities.ErrParticipantNotFound:
return status.Error(codes.NotFound, err.Error())
case entities.ErrSessionNotInProgress:
return status.Error(codes.FailedPrecondition, err.Error())
case entities.ErrInvalidSessionType:
return status.Error(codes.InvalidArgument, err.Error())
default:
return status.Error(codes.Internal, err.Error())
}
}
// Request/Response types (normally generated from proto)
// These are simplified versions - actual implementation would use generated proto types
type CreateSessionRequest struct {
SessionType string
ThresholdN int32
ThresholdT int32
Participants []*ParticipantInfoProto
MessageHash []byte
ExpiresInSeconds int64
}
type CreateSessionResponse struct {
SessionId string
JoinTokens map[string]string
ExpiresAt int64
}
type ParticipantInfoProto struct {
PartyId string
DeviceInfo *DeviceInfo
}
type DeviceInfo struct {
DeviceType string
DeviceId string
Platform string
AppVersion string
}
type JoinSessionRequest struct {
SessionId string
PartyId string
JoinToken string
DeviceInfo *DeviceInfo
}
type JoinSessionResponse struct {
Success bool
SessionInfo *SessionInfo
OtherParties []*PartyInfo
}
type SessionInfo struct {
SessionId string
SessionType string
ThresholdN int32
ThresholdT int32
MessageHash []byte
Status string
}
type PartyInfo struct {
PartyId string
PartyIndex int32
DeviceInfo *DeviceInfo
}
type GetSessionStatusRequest struct {
SessionId string
}
type GetSessionStatusResponse struct {
Status string
CompletedParties int32
TotalParties int32
PublicKey []byte
Signature []byte
}
type ReportCompletionRequest struct {
SessionId string
PartyId string
PublicKey []byte
Signature []byte
}
type ReportCompletionResponse struct {
Success bool
AllCompleted bool
}
type CloseSessionRequest struct {
SessionId string
}
type CloseSessionResponse struct {
Success bool
}

View File

@ -0,0 +1,543 @@
package http
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
"go.uber.org/zap"
)
// SessionHTTPHandler handles HTTP requests for session management
type SessionHTTPHandler struct {
createSessionUC *use_cases.CreateSessionUseCase
joinSessionUC *use_cases.JoinSessionUseCase
getSessionStatusUC *use_cases.GetSessionStatusUseCase
reportCompletionUC *use_cases.ReportCompletionUseCase
closeSessionUC *use_cases.CloseSessionUseCase
sessionRepo repositories.SessionRepository
}
// NewSessionHTTPHandler creates a new HTTP handler
func NewSessionHTTPHandler(
createSessionUC *use_cases.CreateSessionUseCase,
joinSessionUC *use_cases.JoinSessionUseCase,
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
reportCompletionUC *use_cases.ReportCompletionUseCase,
closeSessionUC *use_cases.CloseSessionUseCase,
sessionRepo repositories.SessionRepository,
) *SessionHTTPHandler {
return &SessionHTTPHandler{
createSessionUC: createSessionUC,
joinSessionUC: joinSessionUC,
getSessionStatusUC: getSessionStatusUC,
reportCompletionUC: reportCompletionUC,
closeSessionUC: closeSessionUC,
sessionRepo: sessionRepo,
}
}
// RegisterRoutes registers HTTP routes
func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) {
sessions := r.Group("/sessions")
{
sessions.POST("", h.CreateSession)
sessions.POST("/join", h.JoinSessionByToken)
sessions.POST("/:id/join", h.JoinSession)
sessions.GET("/:id", h.GetSessionStatus)
sessions.GET("/:id/status", h.GetSessionStatus)
sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady)
sessions.POST("/:id/start", h.StartSession)
sessions.POST("/:id/complete", h.ReportCompletion)
sessions.DELETE("/:id", h.CloseSession)
}
}
// CreateSessionRequest is the HTTP request body for creating a session
type CreateSessionRequest struct {
SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"`
ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"`
ThresholdT int `json:"thresholdT" binding:"required,min=1"`
CreatedBy string `json:"createdBy" binding:"required"`
Participants []ParticipantInfoRequest `json:"participants,omitempty"`
MessageHash string `json:"messageHash,omitempty"`
ExpiresIn int64 `json:"expiresIn,omitempty"`
}
// ParticipantInfoRequest represents a participant in the request
type ParticipantInfoRequest struct {
PartyID string `json:"party_id" binding:"required"`
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
}
// DeviceInfoRequest represents device info in the request
type DeviceInfoRequest struct {
DeviceType string `json:"device_type" binding:"required"`
DeviceID string `json:"device_id,omitempty"`
Platform string `json:"platform,omitempty"`
AppVersion string `json:"app_version,omitempty"`
}
// CreateSessionResponse is the HTTP response for creating a session
type CreateSessionResponse struct {
SessionID string `json:"sessionId"`
JoinToken string `json:"joinToken"`
Status string `json:"status"`
ExpiresAt int64 `json:"expiresAt,omitempty"`
}
// CreateSession handles POST /sessions
func (h *SessionHTTPHandler) CreateSession(c *gin.Context) {
var req CreateSessionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Validate threshold
if req.ThresholdT > req.ThresholdN {
c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"})
return
}
// Convert request to input
participants := make([]input.ParticipantInfo, len(req.Participants))
for i, p := range req.Participants {
participants[i] = input.ParticipantInfo{
PartyID: p.PartyID,
DeviceInfo: entities.DeviceInfo{
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
DeviceID: p.DeviceInfo.DeviceID,
Platform: p.DeviceInfo.Platform,
AppVersion: p.DeviceInfo.AppVersion,
},
}
}
var messageHash []byte
if req.MessageHash != "" {
messageHash = []byte(req.MessageHash)
}
expiresIn := time.Duration(req.ExpiresIn) * time.Second
if expiresIn == 0 {
expiresIn = 10 * time.Minute // Default
}
inputData := input.CreateSessionInput{
InitiatorID: req.CreatedBy,
SessionType: req.SessionType,
ThresholdN: req.ThresholdN,
ThresholdT: req.ThresholdT,
Participants: participants,
MessageHash: messageHash,
ExpiresIn: expiresIn,
}
output, err := h.createSessionUC.Execute(c.Request.Context(), inputData)
if err != nil {
logger.Error("failed to create session", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Extract a single join token (for E2E compatibility)
// In a real scenario with pre-registered participants, we'd return all tokens
// For now, generate a universal join token
joinToken := output.SessionID.String() // Use session ID as join token for simplicity
if len(output.JoinTokens) > 0 {
// If there are participant-specific tokens, use the first one
for _, token := range output.JoinTokens {
joinToken = token
break
}
}
c.JSON(http.StatusCreated, CreateSessionResponse{
SessionID: output.SessionID.String(),
JoinToken: joinToken,
Status: "created",
ExpiresAt: output.ExpiresAt.UnixMilli(),
})
}
// JoinSessionRequest is the HTTP request body for joining a session
type JoinSessionRequest struct {
PartyID string `json:"party_id" binding:"required"`
JoinToken string `json:"join_token" binding:"required"`
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
}
// JoinSessionResponse is the HTTP response for joining a session
type JoinSessionResponse struct {
Success bool `json:"success"`
SessionInfo SessionInfoDTO `json:"session_info"`
OtherParties []PartyInfoDTO `json:"other_parties"`
}
// SessionInfoDTO represents session info in responses
type SessionInfoDTO struct {
SessionID string `json:"session_id"`
SessionType string `json:"session_type"`
ThresholdN int `json:"threshold_n"`
ThresholdT int `json:"threshold_t"`
MessageHash string `json:"message_hash,omitempty"`
Status string `json:"status"`
}
// PartyInfoDTO represents party info in responses
type PartyInfoDTO struct {
PartyID string `json:"party_id"`
PartyIndex int `json:"party_index"`
DeviceInfo DeviceInfoRequest `json:"device_info"`
}
// JoinSessionByTokenRequest is the HTTP request body for joining by token
type JoinSessionByTokenRequest struct {
JoinToken string `json:"joinToken" binding:"required"`
PartyID string `json:"partyId" binding:"required"`
DeviceType string `json:"deviceType" binding:"required"`
DeviceID string `json:"deviceId,omitempty"`
}
// JoinSessionByTokenResponse is the HTTP response for joining by token
type JoinSessionByTokenResponse struct {
SessionID string `json:"sessionId"`
PartyIndex int `json:"partyIndex"`
Status string `json:"status"`
Participants []ParticipantStatusDTO `json:"participants"`
}
// ParticipantStatusDTO represents participant status in responses
type ParticipantStatusDTO struct {
PartyID string `json:"partyId"`
PartyIndex int `json:"partyIndex"`
Status string `json:"status"`
}
// JoinSessionByToken handles POST /sessions/join (join by token without session ID)
func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) {
var req JoinSessionByTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error("failed to bind join session request", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Pass empty UUID - the use case will extract session ID from the JWT
inputData := input.JoinSessionInput{
SessionID: uuid.Nil,
PartyID: req.PartyID,
JoinToken: req.JoinToken,
DeviceInfo: entities.DeviceInfo{
DeviceType: entities.DeviceType(req.DeviceType),
DeviceID: req.DeviceID,
},
}
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
if err != nil {
logger.Error("failed to join session", zap.Error(err))
// Return 401 for authentication/token errors
if err.Error() == "invalid token" || err.Error() == "token expired" {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Build participant status list
myPartyIndex := output.PartyIndex
participants := make([]ParticipantStatusDTO, 0)
participants = append(participants, ParticipantStatusDTO{
PartyID: req.PartyID,
Status: "joined",
})
for _, p := range output.OtherParties {
participants = append(participants, ParticipantStatusDTO{
PartyID: p.PartyID,
Status: "joined",
})
}
c.JSON(http.StatusOK, JoinSessionByTokenResponse{
SessionID: output.SessionInfo.SessionID.String(),
PartyIndex: myPartyIndex,
Status: output.SessionInfo.Status,
Participants: participants,
})
}
// JoinSession handles POST /sessions/:id/join
func (h *SessionHTTPHandler) JoinSession(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
var req JoinSessionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
inputData := input.JoinSessionInput{
SessionID: sessionID,
PartyID: req.PartyID,
JoinToken: req.JoinToken,
DeviceInfo: entities.DeviceInfo{
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
DeviceID: req.DeviceInfo.DeviceID,
Platform: req.DeviceInfo.Platform,
AppVersion: req.DeviceInfo.AppVersion,
},
}
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
if err != nil {
logger.Error("failed to join session", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
otherParties := make([]PartyInfoDTO, len(output.OtherParties))
for i, p := range output.OtherParties {
otherParties[i] = PartyInfoDTO{
PartyID: p.PartyID,
PartyIndex: p.PartyIndex,
DeviceInfo: DeviceInfoRequest{
DeviceType: string(p.DeviceInfo.DeviceType),
DeviceID: p.DeviceInfo.DeviceID,
Platform: p.DeviceInfo.Platform,
AppVersion: p.DeviceInfo.AppVersion,
},
}
}
c.JSON(http.StatusOK, JoinSessionResponse{
Success: output.Success,
SessionInfo: SessionInfoDTO{
SessionID: output.SessionInfo.SessionID.String(),
SessionType: output.SessionInfo.SessionType,
ThresholdN: output.SessionInfo.ThresholdN,
ThresholdT: output.SessionInfo.ThresholdT,
MessageHash: string(output.SessionInfo.MessageHash),
Status: output.SessionInfo.Status,
},
OtherParties: otherParties,
})
}
// SessionStatusResponse is the HTTP response for session status
type SessionStatusResponse struct {
SessionID string `json:"sessionId"`
Status string `json:"status"`
ThresholdT int `json:"thresholdT"`
ThresholdN int `json:"thresholdN"`
Participants []ParticipantStatusDTO `json:"participants"`
PublicKey string `json:"publicKey,omitempty"`
}
// GetSessionStatus handles GET /sessions/:id/status
func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID)
if err != nil {
logger.Error("failed to get session status", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
// Convert participants to DTO
participants := make([]ParticipantStatusDTO, len(output.Participants))
for i, p := range output.Participants {
participants[i] = ParticipantStatusDTO{
PartyID: p.PartyID,
PartyIndex: p.PartyIndex,
Status: p.Status,
}
}
c.JSON(http.StatusOK, SessionStatusResponse{
SessionID: output.SessionID.String(),
Status: output.Status,
ThresholdT: output.ThresholdT,
ThresholdN: output.ThresholdN,
Participants: participants,
PublicKey: string(output.PublicKey),
})
}
// ReportCompletionRequest is the HTTP request for reporting completion
type ReportCompletionRequest struct {
PartyID string `json:"party_id" binding:"required"`
PublicKey string `json:"public_key,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ReportCompletion handles POST /sessions/:id/complete
func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
var req ReportCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
inputData := input.ReportCompletionInput{
SessionID: sessionID,
PartyID: req.PartyID,
PublicKey: []byte(req.PublicKey),
Signature: []byte(req.Signature),
}
output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData)
if err != nil {
logger.Error("failed to report completion", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"success": output.Success,
"all_completed": output.AllCompleted,
})
}
// CloseSession handles DELETE /sessions/:id
func (h *SessionHTTPHandler) CloseSession(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
err = h.closeSessionUC.Execute(c.Request.Context(), sessionID)
if err != nil {
logger.Error("failed to close session", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready
func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
partyID := c.Param("partyId")
if partyID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"})
return
}
logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID))
// Load session
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
if err != nil {
if err == entities.ErrSessionNotFound {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
logger.Error("failed to load session", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
return
}
// Create party ID value object
partyIDVO, err := value_objects.NewPartyID(partyID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"})
return
}
// Update participant status to ready
if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil {
logger.Error("failed to mark party as ready", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Save session
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
logger.Error("failed to save session", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// StartSession handles POST /sessions/:id/start
func (h *SessionHTTPHandler) StartSession(c *gin.Context) {
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
return
}
logger.Info("starting session", zap.String("session_id", sessionID.String()))
// Load session
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
if err != nil {
if err == entities.ErrSessionNotFound {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
logger.Error("failed to load session", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
return
}
// Start the session
if err := session.Start(); err != nil {
logger.Error("failed to start session", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Save session
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
logger.Error("failed to save session", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// HealthCheck handles GET /health
func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"service": "session-coordinator",
"time": time.Now().UTC().Format(time.RFC3339),
})
}

View File

@ -0,0 +1,306 @@
package main
import (
"context"
"database/sql"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
_ "github.com/lib/pq"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/redis/go-redis/v9"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
"github.com/rwadurian/mpc-system/pkg/config"
"github.com/rwadurian/mpc-system/pkg/jwt"
"github.com/rwadurian/mpc-system/pkg/logger"
grpcadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/grpc"
httphandler "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/http"
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/postgres"
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/rabbitmq"
redisadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/redis"
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
"go.uber.org/zap"
)
func main() {
// Parse flags
configPath := flag.String("config", "", "Path to config file")
flag.Parse()
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("Failed to load config: %v\n", err)
os.Exit(1)
}
// Initialize logger
if err := logger.Init(&logger.Config{
Level: cfg.Logger.Level,
Encoding: cfg.Logger.Encoding,
}); err != nil {
fmt.Printf("Failed to initialize logger: %v\n", err)
os.Exit(1)
}
defer logger.Sync()
logger.Info("Starting Session Coordinator Service",
zap.String("environment", cfg.Server.Environment),
zap.Int("grpc_port", cfg.Server.GRPCPort),
zap.Int("http_port", cfg.Server.HTTPPort))
// Initialize database connection
db, err := initDatabase(cfg.Database)
if err != nil {
logger.Fatal("Failed to connect to database", zap.Error(err))
}
defer db.Close()
// Initialize Redis connection
redisClient := initRedis(cfg.Redis)
defer redisClient.Close()
// Initialize RabbitMQ connection
rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
if err != nil {
logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
}
defer rabbitConn.Close()
// Initialize repositories and adapters
sessionRepo := postgres.NewSessionPostgresRepo(db)
messageRepo := postgres.NewMessagePostgresRepo(db)
sessionCache := redisadapter.NewSessionCacheAdapter(redisClient)
eventPublisher, err := rabbitmq.NewEventPublisherAdapter(rabbitConn)
if err != nil {
logger.Fatal("Failed to create event publisher", zap.Error(err))
}
defer eventPublisher.Close()
// Initialize JWT service
jwtService := jwt.NewJWTService(
cfg.JWT.SecretKey,
cfg.JWT.Issuer,
cfg.JWT.TokenExpiry,
cfg.JWT.RefreshExpiry,
)
// Initialize use cases
createSessionUC := use_cases.NewCreateSessionUseCase(sessionRepo, jwtService, eventPublisher)
joinSessionUC := use_cases.NewJoinSessionUseCase(sessionRepo, jwtService, eventPublisher)
getSessionStatusUC := use_cases.NewGetSessionStatusUseCase(sessionRepo)
reportCompletionUC := use_cases.NewReportCompletionUseCase(sessionRepo, eventPublisher)
closeSessionUC := use_cases.NewCloseSessionUseCase(sessionRepo, messageRepo, eventPublisher)
expireSessionsUC := use_cases.NewExpireSessionsUseCase(sessionRepo, eventPublisher)
// Start session expiration background job
go runSessionExpiration(expireSessionsUC)
// Create shutdown context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start servers
errChan := make(chan error, 2)
// Start gRPC server
go func() {
if err := startGRPCServer(
cfg,
createSessionUC,
joinSessionUC,
getSessionStatusUC,
reportCompletionUC,
closeSessionUC,
); err != nil {
errChan <- fmt.Errorf("gRPC server error: %w", err)
}
}()
// Start HTTP server
go func() {
if err := startHTTPServer(
cfg,
createSessionUC,
joinSessionUC,
getSessionStatusUC,
reportCompletionUC,
closeSessionUC,
sessionRepo,
); err != nil {
errChan <- fmt.Errorf("HTTP server error: %w", err)
}
}()
// Wait for shutdown signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
select {
case sig := <-sigChan:
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
case err := <-errChan:
logger.Error("Server error", zap.Error(err))
}
// Graceful shutdown
logger.Info("Shutting down...")
cancel()
// Give services time to shutdown gracefully
time.Sleep(5 * time.Second)
logger.Info("Shutdown complete")
_ = ctx
_ = sessionCache
}
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
db, err := sql.Open("postgres", cfg.DSN())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetConnMaxLifetime(cfg.ConnMaxLife)
// Test connection
if err := db.Ping(); err != nil {
return nil, err
}
logger.Info("Connected to PostgreSQL")
return db, nil
}
func initRedis(cfg config.RedisConfig) *redis.Client {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr(),
Password: cfg.Password,
DB: cfg.DB,
})
// Test connection
ctx := context.Background()
if err := client.Ping(ctx).Err(); err != nil {
logger.Warn("Redis connection failed, continuing without cache", zap.Error(err))
} else {
logger.Info("Connected to Redis")
}
return client
}
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
conn, err := amqp.Dial(cfg.URL())
if err != nil {
return nil, err
}
logger.Info("Connected to RabbitMQ")
return conn, nil
}
func startGRPCServer(
cfg *config.Config,
createSessionUC *use_cases.CreateSessionUseCase,
joinSessionUC *use_cases.JoinSessionUseCase,
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
reportCompletionUC *use_cases.ReportCompletionUseCase,
closeSessionUC *use_cases.CloseSessionUseCase,
) error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
if err != nil {
return err
}
grpcServer := grpc.NewServer()
// Register services (using our custom handler, not generated proto)
// In production, you would register the generated proto service
_ = grpcadapter.NewSessionCoordinatorServer(
createSessionUC,
joinSessionUC,
getSessionStatusUC,
reportCompletionUC,
closeSessionUC,
)
// Enable reflection for debugging
reflection.Register(grpcServer)
logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
return grpcServer.Serve(listener)
}
func startHTTPServer(
cfg *config.Config,
createSessionUC *use_cases.CreateSessionUseCase,
joinSessionUC *use_cases.JoinSessionUseCase,
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
reportCompletionUC *use_cases.ReportCompletionUseCase,
closeSessionUC *use_cases.CloseSessionUseCase,
sessionRepo repositories.SessionRepository,
) error {
// Set Gin mode
if cfg.Server.Environment == "production" {
gin.SetMode(gin.ReleaseMode)
}
router := gin.New()
router.Use(gin.Recovery())
router.Use(gin.Logger())
// Create HTTP handler
httpHandler := httphandler.NewSessionHTTPHandler(
createSessionUC,
joinSessionUC,
getSessionStatusUC,
reportCompletionUC,
closeSessionUC,
sessionRepo,
)
// Health check
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"service": "session-coordinator",
})
})
// Register API routes
api := router.Group("/api/v1")
httpHandler.RegisterRoutes(api)
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
}
func runSessionExpiration(expireSessionsUC *use_cases.ExpireSessionsUseCase) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
count, err := expireSessionsUC.Execute(ctx)
cancel()
if err != nil {
logger.Error("Failed to expire sessions", zap.Error(err))
} else if count > 0 {
logger.Info("Expired stale sessions", zap.Int("count", count))
}
}
}

View File

@ -0,0 +1,342 @@
package entities
import (
"errors"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
)
var (
ErrSessionNotFound = errors.New("session not found")
ErrSessionFull = errors.New("session is full")
ErrSessionExpired = errors.New("session expired")
ErrSessionNotInProgress = errors.New("session not in progress")
ErrParticipantNotFound = errors.New("participant not found")
ErrInvalidSessionType = errors.New("invalid session type")
ErrInvalidStatusTransition = errors.New("invalid status transition")
)
// SessionType represents the type of MPC session
type SessionType string
const (
SessionTypeKeygen SessionType = "keygen"
SessionTypeSign SessionType = "sign"
)
// IsValid checks if the session type is valid
func (t SessionType) IsValid() bool {
return t == SessionTypeKeygen || t == SessionTypeSign
}
// MPCSession represents an MPC session
// Coordinator only manages session metadata, does not participate in MPC computation
type MPCSession struct {
ID value_objects.SessionID
SessionType SessionType
Threshold value_objects.Threshold
Participants []*Participant
Status value_objects.SessionStatus
MessageHash []byte // Used for Sign sessions
PublicKey []byte // Group public key after Keygen completion
CreatedBy string
CreatedAt time.Time
UpdatedAt time.Time
ExpiresAt time.Time
CompletedAt *time.Time
}
// NewMPCSession creates a new MPC session
func NewMPCSession(
sessionType SessionType,
threshold value_objects.Threshold,
createdBy string,
expiresIn time.Duration,
messageHash []byte, // Only for Sign sessions
) (*MPCSession, error) {
if !sessionType.IsValid() {
return nil, ErrInvalidSessionType
}
if sessionType == SessionTypeSign && len(messageHash) == 0 {
return nil, errors.New("message hash required for sign session")
}
now := time.Now().UTC()
return &MPCSession{
ID: value_objects.NewSessionID(),
SessionType: sessionType,
Threshold: threshold,
Participants: make([]*Participant, 0, threshold.N()),
Status: value_objects.SessionStatusCreated,
MessageHash: messageHash,
CreatedBy: createdBy,
CreatedAt: now,
UpdatedAt: now,
ExpiresAt: now.Add(expiresIn),
}, nil
}
// AddParticipant adds a participant to the session
func (s *MPCSession) AddParticipant(p *Participant) error {
if len(s.Participants) >= s.Threshold.N() {
return ErrSessionFull
}
s.Participants = append(s.Participants, p)
s.UpdatedAt = time.Now().UTC()
return nil
}
// GetParticipant gets a participant by party ID
func (s *MPCSession) GetParticipant(partyID value_objects.PartyID) (*Participant, error) {
for _, p := range s.Participants {
if p.PartyID.Equals(partyID) {
return p, nil
}
}
return nil, ErrParticipantNotFound
}
// UpdateParticipantStatus updates a participant's status
func (s *MPCSession) UpdateParticipantStatus(partyID value_objects.PartyID, status value_objects.ParticipantStatus) error {
for _, p := range s.Participants {
if p.PartyID.Equals(partyID) {
switch status {
case value_objects.ParticipantStatusJoined:
return p.Join()
case value_objects.ParticipantStatusReady:
return p.MarkReady()
case value_objects.ParticipantStatusCompleted:
return p.MarkCompleted()
case value_objects.ParticipantStatusFailed:
p.MarkFailed()
return nil
default:
return errors.New("invalid status")
}
}
}
return ErrParticipantNotFound
}
// CanStart checks if all participants have joined and the session can start
func (s *MPCSession) CanStart() bool {
if len(s.Participants) != s.Threshold.N() {
return false
}
readyCount := 0
for _, p := range s.Participants {
// Accept participants in either joined or ready status
if p.IsJoined() || p.IsReady() {
readyCount++
}
}
return readyCount == s.Threshold.N()
}
// Start transitions the session to in_progress
func (s *MPCSession) Start() error {
// If already in progress, just return success (idempotent)
if s.Status == value_objects.SessionStatusInProgress {
return nil
}
if !s.Status.CanTransitionTo(value_objects.SessionStatusInProgress) {
return ErrInvalidStatusTransition
}
if !s.CanStart() {
return errors.New("not all participants have joined")
}
s.Status = value_objects.SessionStatusInProgress
s.UpdatedAt = time.Now().UTC()
return nil
}
// Complete marks the session as completed
func (s *MPCSession) Complete(publicKey []byte) error {
if !s.Status.CanTransitionTo(value_objects.SessionStatusCompleted) {
return ErrInvalidStatusTransition
}
s.Status = value_objects.SessionStatusCompleted
s.PublicKey = publicKey
now := time.Now().UTC()
s.CompletedAt = &now
s.UpdatedAt = now
return nil
}
// Fail marks the session as failed
func (s *MPCSession) Fail() error {
if !s.Status.CanTransitionTo(value_objects.SessionStatusFailed) {
return ErrInvalidStatusTransition
}
s.Status = value_objects.SessionStatusFailed
s.UpdatedAt = time.Now().UTC()
return nil
}
// Expire marks the session as expired
func (s *MPCSession) Expire() error {
if !s.Status.CanTransitionTo(value_objects.SessionStatusExpired) {
return ErrInvalidStatusTransition
}
s.Status = value_objects.SessionStatusExpired
s.UpdatedAt = time.Now().UTC()
return nil
}
// IsExpired checks if the session has expired
func (s *MPCSession) IsExpired() bool {
return time.Now().UTC().After(s.ExpiresAt)
}
// IsActive checks if the session is active
func (s *MPCSession) IsActive() bool {
return s.Status.IsActive() && !s.IsExpired()
}
// IsParticipant checks if a party is a participant
func (s *MPCSession) IsParticipant(partyID value_objects.PartyID) bool {
for _, p := range s.Participants {
if p.PartyID.Equals(partyID) {
return true
}
}
return false
}
// AllCompleted checks if all participants have completed
func (s *MPCSession) AllCompleted() bool {
for _, p := range s.Participants {
if !p.IsCompleted() {
return false
}
}
return true
}
// CompletedCount returns the number of completed participants
func (s *MPCSession) CompletedCount() int {
count := 0
for _, p := range s.Participants {
if p.IsCompleted() {
count++
}
}
return count
}
// JoinedCount returns the number of joined participants
func (s *MPCSession) JoinedCount() int {
count := 0
for _, p := range s.Participants {
if p.IsJoined() {
count++
}
}
return count
}
// GetPartyIDs returns all party IDs
func (s *MPCSession) GetPartyIDs() []string {
ids := make([]string, len(s.Participants))
for i, p := range s.Participants {
ids[i] = p.PartyID.String()
}
return ids
}
// GetOtherParties returns participants except the specified party
func (s *MPCSession) GetOtherParties(excludePartyID value_objects.PartyID) []*Participant {
others := make([]*Participant, 0, len(s.Participants)-1)
for _, p := range s.Participants {
if !p.PartyID.Equals(excludePartyID) {
others = append(others, p)
}
}
return others
}
// ToDTO converts to a DTO for API responses
func (s *MPCSession) ToDTO() SessionDTO {
participants := make([]ParticipantDTO, len(s.Participants))
for i, p := range s.Participants {
participants[i] = ParticipantDTO{
PartyID: p.PartyID.String(),
PartyIndex: p.PartyIndex,
Status: p.Status.String(),
DeviceType: string(p.DeviceInfo.DeviceType),
}
}
return SessionDTO{
ID: s.ID.String(),
SessionType: string(s.SessionType),
ThresholdN: s.Threshold.N(),
ThresholdT: s.Threshold.T(),
Participants: participants,
Status: s.Status.String(),
CreatedAt: s.CreatedAt,
ExpiresAt: s.ExpiresAt,
}
}
// SessionDTO is a data transfer object for sessions
type SessionDTO struct {
ID string `json:"id"`
SessionType string `json:"session_type"`
ThresholdN int `json:"threshold_n"`
ThresholdT int `json:"threshold_t"`
Participants []ParticipantDTO `json:"participants"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// ParticipantDTO is a data transfer object for participants
type ParticipantDTO struct {
PartyID string `json:"party_id"`
PartyIndex int `json:"party_index"`
Status string `json:"status"`
DeviceType string `json:"device_type"`
}
// Reconstruct reconstructs an MPCSession from database
func ReconstructSession(
id uuid.UUID,
sessionType string,
thresholdT, thresholdN int,
status string,
messageHash, publicKey []byte,
createdBy string,
createdAt, updatedAt, expiresAt time.Time,
completedAt *time.Time,
participants []*Participant,
) (*MPCSession, error) {
sessionStatus, err := value_objects.NewSessionStatus(status)
if err != nil {
return nil, err
}
threshold, err := value_objects.NewThreshold(thresholdT, thresholdN)
if err != nil {
return nil, err
}
return &MPCSession{
ID: value_objects.SessionIDFromUUID(id),
SessionType: SessionType(sessionType),
Threshold: threshold,
Participants: participants,
Status: sessionStatus,
MessageHash: messageHash,
PublicKey: publicKey,
CreatedBy: createdBy,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
ExpiresAt: expiresAt,
CompletedAt: completedAt,
}, nil
}

View File

@ -0,0 +1,354 @@
//go:build e2e
package e2e_test
import (
"bytes"
"encoding/json"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type KeygenFlowTestSuite struct {
suite.Suite
baseURL string
client *http.Client
}
func TestKeygenFlowSuite(t *testing.T) {
if testing.Short() {
t.Skip("Skipping e2e test in short mode")
}
suite.Run(t, new(KeygenFlowTestSuite))
}
func (s *KeygenFlowTestSuite) SetupSuite() {
s.baseURL = os.Getenv("SESSION_COORDINATOR_URL")
if s.baseURL == "" {
s.baseURL = "http://localhost:8080"
}
s.client = &http.Client{
Timeout: 30 * time.Second,
}
// Wait for service to be ready
s.waitForService()
}
func (s *KeygenFlowTestSuite) waitForService() {
maxRetries := 30
for i := 0; i < maxRetries; i++ {
resp, err := s.client.Get(s.baseURL + "/health")
if err == nil && resp.StatusCode == http.StatusOK {
resp.Body.Close()
return
}
if resp != nil {
resp.Body.Close()
}
time.Sleep(time.Second)
}
s.T().Fatal("Service not ready after waiting")
}
type CreateSessionRequest struct {
SessionType string `json:"sessionType"`
ThresholdT int `json:"thresholdT"`
ThresholdN int `json:"thresholdN"`
CreatedBy string `json:"createdBy"`
}
type CreateSessionResponse struct {
SessionID string `json:"sessionId"`
JoinToken string `json:"joinToken"`
Status string `json:"status"`
}
type JoinSessionRequest struct {
JoinToken string `json:"joinToken"`
PartyID string `json:"partyId"`
DeviceType string `json:"deviceType"`
DeviceID string `json:"deviceId"`
}
type JoinSessionResponse struct {
SessionID string `json:"sessionId"`
PartyIndex int `json:"partyIndex"`
Status string `json:"status"`
Participants []struct {
PartyID string `json:"partyId"`
Status string `json:"status"`
} `json:"participants"`
}
type SessionStatusResponse struct {
SessionID string `json:"sessionId"`
Status string `json:"status"`
ThresholdT int `json:"thresholdT"`
ThresholdN int `json:"thresholdN"`
Participants []struct {
PartyID string `json:"partyId"`
PartyIndex int `json:"partyIndex"`
Status string `json:"status"`
} `json:"participants"`
}
func (s *KeygenFlowTestSuite) TestCompleteKeygenFlow() {
// Step 1: Create a new keygen session
createReq := CreateSessionRequest{
SessionType: "keygen",
ThresholdT: 2,
ThresholdN: 3,
CreatedBy: "e2e_test_user",
}
createResp := s.createSession(createReq)
require.NotEmpty(s.T(), createResp.SessionID)
require.NotEmpty(s.T(), createResp.JoinToken)
assert.Equal(s.T(), "created", createResp.Status)
sessionID := createResp.SessionID
joinToken := createResp.JoinToken
// Step 2: First party joins
joinReq1 := JoinSessionRequest{
JoinToken: joinToken,
PartyID: "party_user_device",
DeviceType: "iOS",
DeviceID: "device_001",
}
joinResp1 := s.joinSession(joinReq1)
assert.Equal(s.T(), sessionID, joinResp1.SessionID)
assert.Equal(s.T(), 0, joinResp1.PartyIndex)
// Step 3: Second party joins
joinReq2 := JoinSessionRequest{
JoinToken: joinToken,
PartyID: "party_server",
DeviceType: "server",
DeviceID: "server_001",
}
joinResp2 := s.joinSession(joinReq2)
assert.Equal(s.T(), sessionID, joinResp2.SessionID)
assert.Equal(s.T(), 1, joinResp2.PartyIndex)
// Step 4: Third party joins
joinReq3 := JoinSessionRequest{
JoinToken: joinToken,
PartyID: "party_recovery",
DeviceType: "recovery",
DeviceID: "recovery_001",
}
joinResp3 := s.joinSession(joinReq3)
assert.Equal(s.T(), sessionID, joinResp3.SessionID)
assert.Equal(s.T(), 2, joinResp3.PartyIndex)
// Step 5: Check session status - should have all participants
statusResp := s.getSessionStatus(sessionID)
assert.Equal(s.T(), 3, len(statusResp.Participants))
// Step 6: Mark parties as ready
s.markPartyReady(sessionID, "party_user_device")
s.markPartyReady(sessionID, "party_server")
s.markPartyReady(sessionID, "party_recovery")
// Step 7: Start the session
s.startSession(sessionID)
// Step 8: Verify session is in progress
statusResp = s.getSessionStatus(sessionID)
assert.Equal(s.T(), "in_progress", statusResp.Status)
// Step 9: Report completion (simulate keygen completion)
publicKey := []byte("test-group-public-key-from-keygen")
s.reportCompletion(sessionID, "party_user_device", publicKey)
// Step 10: Verify session is completed
statusResp = s.getSessionStatus(sessionID)
assert.Equal(s.T(), "completed", statusResp.Status)
}
func (s *KeygenFlowTestSuite) TestJoinSessionWithInvalidToken() {
joinReq := JoinSessionRequest{
JoinToken: "invalid-token",
PartyID: "party_test",
DeviceType: "iOS",
DeviceID: "device_test",
}
body, _ := json.Marshal(joinReq)
resp, err := s.client.Post(
s.baseURL+"/api/v1/sessions/join",
"application/json",
bytes.NewReader(body),
)
require.NoError(s.T(), err)
defer resp.Body.Close()
assert.Equal(s.T(), http.StatusUnauthorized, resp.StatusCode)
}
func (s *KeygenFlowTestSuite) TestGetNonExistentSession() {
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/00000000-0000-0000-0000-000000000000")
require.NoError(s.T(), err)
defer resp.Body.Close()
assert.Equal(s.T(), http.StatusNotFound, resp.StatusCode)
}
func (s *KeygenFlowTestSuite) TestExceedParticipantLimit() {
// Create session with 2 participants max
createReq := CreateSessionRequest{
SessionType: "keygen",
ThresholdT: 2,
ThresholdN: 2,
CreatedBy: "e2e_test_user_limit",
}
createResp := s.createSession(createReq)
joinToken := createResp.JoinToken
// Join with 2 parties (should succeed)
for i := 0; i < 2; i++ {
joinReq := JoinSessionRequest{
JoinToken: joinToken,
PartyID: "party_" + string(rune('a'+i)),
DeviceType: "test",
DeviceID: "device_" + string(rune('a'+i)),
}
s.joinSession(joinReq)
}
// Try to join with 3rd party (should fail)
joinReq := JoinSessionRequest{
JoinToken: joinToken,
PartyID: "party_extra",
DeviceType: "test",
DeviceID: "device_extra",
}
body, _ := json.Marshal(joinReq)
resp, err := s.client.Post(
s.baseURL+"/api/v1/sessions/join",
"application/json",
bytes.NewReader(body),
)
require.NoError(s.T(), err)
defer resp.Body.Close()
assert.Equal(s.T(), http.StatusBadRequest, resp.StatusCode)
}
// Helper methods
func (s *KeygenFlowTestSuite) createSession(req CreateSessionRequest) CreateSessionResponse {
body, _ := json.Marshal(req)
resp, err := s.client.Post(
s.baseURL+"/api/v1/sessions",
"application/json",
bytes.NewReader(body),
)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusCreated, resp.StatusCode)
var result CreateSessionResponse
err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(s.T(), err)
return result
}
func (s *KeygenFlowTestSuite) joinSession(req JoinSessionRequest) JoinSessionResponse {
body, _ := json.Marshal(req)
resp, err := s.client.Post(
s.baseURL+"/api/v1/sessions/join",
"application/json",
bytes.NewReader(body),
)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
var result JoinSessionResponse
err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(s.T(), err)
return result
}
func (s *KeygenFlowTestSuite) getSessionStatus(sessionID string) SessionStatusResponse {
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/" + sessionID)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
var result SessionStatusResponse
err = json.NewDecoder(resp.Body).Decode(&result)
require.NoError(s.T(), err)
return result
}
func (s *KeygenFlowTestSuite) markPartyReady(sessionID, partyID string) {
req := map[string]string{"partyId": partyID}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest(
http.MethodPut,
s.baseURL+"/api/v1/sessions/"+sessionID+"/parties/"+partyID+"/ready",
bytes.NewReader(body),
)
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
}
func (s *KeygenFlowTestSuite) startSession(sessionID string) {
httpReq, _ := http.NewRequest(
http.MethodPost,
s.baseURL+"/api/v1/sessions/"+sessionID+"/start",
nil,
)
resp, err := s.client.Do(httpReq)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
}
func (s *KeygenFlowTestSuite) reportCompletion(sessionID string, partyID string, publicKey []byte) {
req := map[string]interface{}{
"party_id": partyID,
"public_key": string(publicKey),
}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest(
http.MethodPost,
s.baseURL+"/api/v1/sessions/"+sessionID+"/complete",
bytes.NewReader(body),
)
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
require.NoError(s.T(), err)
defer resp.Body.Close()
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
}