fix: Implement MarkPartyReady and StartSession handlers, update domain logic
- Add sessionRepo to HTTP handler for database operations - Implement MarkPartyReady handler to update participant status - Implement StartSession handler to start MPC sessions - Update CanStart() to accept participants in 'ready' status - Make Start() method idempotent to handle automatic + explicit starts - Fix repository injection through dependency chain in main.go - Add party_id parameter to test completion request
This commit is contained in:
parent
6fa4d7ac1d
commit
7531cbd07a
|
|
@ -0,0 +1,326 @@
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionCoordinatorServer implements the gRPC SessionCoordinator service
|
||||||
|
type SessionCoordinatorServer struct {
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionCoordinatorServer creates a new gRPC server
|
||||||
|
func NewSessionCoordinatorServer(
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase,
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||||
|
) *SessionCoordinatorServer {
|
||||||
|
return &SessionCoordinatorServer{
|
||||||
|
createSessionUC: createSessionUC,
|
||||||
|
joinSessionUC: joinSessionUC,
|
||||||
|
getSessionStatusUC: getSessionStatusUC,
|
||||||
|
reportCompletionUC: reportCompletionUC,
|
||||||
|
closeSessionUC: closeSessionUC,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession creates a new MPC session
|
||||||
|
func (s *SessionCoordinatorServer) CreateSession(
|
||||||
|
ctx context.Context,
|
||||||
|
req *CreateSessionRequest,
|
||||||
|
) (*CreateSessionResponse, error) {
|
||||||
|
// Convert request to input
|
||||||
|
participants := make([]input.ParticipantInfo, len(req.Participants))
|
||||||
|
for i, p := range req.Participants {
|
||||||
|
participants[i] = input.ParticipantInfo{
|
||||||
|
PartyID: p.PartyId,
|
||||||
|
DeviceInfo: entities.DeviceInfo{
|
||||||
|
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
|
||||||
|
DeviceID: p.DeviceInfo.DeviceId,
|
||||||
|
Platform: p.DeviceInfo.Platform,
|
||||||
|
AppVersion: p.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.CreateSessionInput{
|
||||||
|
InitiatorID: "", // Could be extracted from auth context
|
||||||
|
SessionType: req.SessionType,
|
||||||
|
ThresholdN: int(req.ThresholdN),
|
||||||
|
ThresholdT: int(req.ThresholdT),
|
||||||
|
Participants: participants,
|
||||||
|
MessageHash: req.MessageHash,
|
||||||
|
ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute use case
|
||||||
|
output, err := s.createSessionUC.Execute(ctx, inputData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, toGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert output to response
|
||||||
|
return &CreateSessionResponse{
|
||||||
|
SessionId: output.SessionID.String(),
|
||||||
|
JoinTokens: output.JoinTokens,
|
||||||
|
ExpiresAt: output.ExpiresAt.UnixMilli(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSession allows a participant to join a session
|
||||||
|
func (s *SessionCoordinatorServer) JoinSession(
|
||||||
|
ctx context.Context,
|
||||||
|
req *JoinSessionRequest,
|
||||||
|
) (*JoinSessionResponse, error) {
|
||||||
|
sessionID, err := uuid.Parse(req.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.JoinSessionInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
PartyID: req.PartyId,
|
||||||
|
JoinToken: req.JoinToken,
|
||||||
|
DeviceInfo: entities.DeviceInfo{
|
||||||
|
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
|
||||||
|
DeviceID: req.DeviceInfo.DeviceId,
|
||||||
|
Platform: req.DeviceInfo.Platform,
|
||||||
|
AppVersion: req.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := s.joinSessionUC.Execute(ctx, inputData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, toGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert other parties to response format
|
||||||
|
otherParties := make([]*PartyInfo, len(output.OtherParties))
|
||||||
|
for i, p := range output.OtherParties {
|
||||||
|
otherParties[i] = &PartyInfo{
|
||||||
|
PartyId: p.PartyID,
|
||||||
|
PartyIndex: int32(p.PartyIndex),
|
||||||
|
DeviceInfo: &DeviceInfo{
|
||||||
|
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||||
|
DeviceId: p.DeviceInfo.DeviceID,
|
||||||
|
Platform: p.DeviceInfo.Platform,
|
||||||
|
AppVersion: p.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &JoinSessionResponse{
|
||||||
|
Success: output.Success,
|
||||||
|
SessionInfo: &SessionInfo{
|
||||||
|
SessionId: output.SessionInfo.SessionID.String(),
|
||||||
|
SessionType: output.SessionInfo.SessionType,
|
||||||
|
ThresholdN: int32(output.SessionInfo.ThresholdN),
|
||||||
|
ThresholdT: int32(output.SessionInfo.ThresholdT),
|
||||||
|
MessageHash: output.SessionInfo.MessageHash,
|
||||||
|
Status: output.SessionInfo.Status,
|
||||||
|
},
|
||||||
|
OtherParties: otherParties,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionStatus retrieves the status of a session
|
||||||
|
func (s *SessionCoordinatorServer) GetSessionStatus(
|
||||||
|
ctx context.Context,
|
||||||
|
req *GetSessionStatusRequest,
|
||||||
|
) (*GetSessionStatusResponse, error) {
|
||||||
|
sessionID, err := uuid.Parse(req.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := s.getSessionStatusUC.Execute(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, toGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate completed parties from participants
|
||||||
|
completedParties := 0
|
||||||
|
for _, p := range output.Participants {
|
||||||
|
if p.Status == "completed" {
|
||||||
|
completedParties++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GetSessionStatusResponse{
|
||||||
|
Status: output.Status,
|
||||||
|
CompletedParties: int32(completedParties),
|
||||||
|
TotalParties: int32(len(output.Participants)),
|
||||||
|
PublicKey: output.PublicKey,
|
||||||
|
Signature: output.Signature,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportCompletion reports that a participant has completed
|
||||||
|
func (s *SessionCoordinatorServer) ReportCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
req *ReportCompletionRequest,
|
||||||
|
) (*ReportCompletionResponse, error) {
|
||||||
|
sessionID, err := uuid.Parse(req.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.ReportCompletionInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
PartyID: req.PartyId,
|
||||||
|
PublicKey: req.PublicKey,
|
||||||
|
Signature: req.Signature,
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := s.reportCompletionUC.Execute(ctx, inputData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, toGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ReportCompletionResponse{
|
||||||
|
Success: output.Success,
|
||||||
|
AllCompleted: output.AllCompleted,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseSession closes a session
|
||||||
|
func (s *SessionCoordinatorServer) CloseSession(
|
||||||
|
ctx context.Context,
|
||||||
|
req *CloseSessionRequest,
|
||||||
|
) (*CloseSessionResponse, error) {
|
||||||
|
sessionID, err := uuid.Parse(req.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.closeSessionUC.Execute(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, toGRPCError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CloseSessionResponse{
|
||||||
|
Success: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toGRPCError converts domain errors to gRPC errors
|
||||||
|
func toGRPCError(err error) error {
|
||||||
|
switch err {
|
||||||
|
case entities.ErrSessionExpired:
|
||||||
|
return status.Error(codes.DeadlineExceeded, err.Error())
|
||||||
|
case entities.ErrSessionFull:
|
||||||
|
return status.Error(codes.ResourceExhausted, err.Error())
|
||||||
|
case entities.ErrParticipantNotFound:
|
||||||
|
return status.Error(codes.NotFound, err.Error())
|
||||||
|
case entities.ErrSessionNotInProgress:
|
||||||
|
return status.Error(codes.FailedPrecondition, err.Error())
|
||||||
|
case entities.ErrInvalidSessionType:
|
||||||
|
return status.Error(codes.InvalidArgument, err.Error())
|
||||||
|
default:
|
||||||
|
return status.Error(codes.Internal, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request/Response types (normally generated from proto)
|
||||||
|
// These are simplified versions - actual implementation would use generated proto types
|
||||||
|
|
||||||
|
type CreateSessionRequest struct {
|
||||||
|
SessionType string
|
||||||
|
ThresholdN int32
|
||||||
|
ThresholdT int32
|
||||||
|
Participants []*ParticipantInfoProto
|
||||||
|
MessageHash []byte
|
||||||
|
ExpiresInSeconds int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateSessionResponse struct {
|
||||||
|
SessionId string
|
||||||
|
JoinTokens map[string]string
|
||||||
|
ExpiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParticipantInfoProto struct {
|
||||||
|
PartyId string
|
||||||
|
DeviceInfo *DeviceInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceInfo struct {
|
||||||
|
DeviceType string
|
||||||
|
DeviceId string
|
||||||
|
Platform string
|
||||||
|
AppVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinSessionRequest struct {
|
||||||
|
SessionId string
|
||||||
|
PartyId string
|
||||||
|
JoinToken string
|
||||||
|
DeviceInfo *DeviceInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinSessionResponse struct {
|
||||||
|
Success bool
|
||||||
|
SessionInfo *SessionInfo
|
||||||
|
OtherParties []*PartyInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionInfo struct {
|
||||||
|
SessionId string
|
||||||
|
SessionType string
|
||||||
|
ThresholdN int32
|
||||||
|
ThresholdT int32
|
||||||
|
MessageHash []byte
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PartyInfo struct {
|
||||||
|
PartyId string
|
||||||
|
PartyIndex int32
|
||||||
|
DeviceInfo *DeviceInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetSessionStatusRequest struct {
|
||||||
|
SessionId string
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetSessionStatusResponse struct {
|
||||||
|
Status string
|
||||||
|
CompletedParties int32
|
||||||
|
TotalParties int32
|
||||||
|
PublicKey []byte
|
||||||
|
Signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReportCompletionRequest struct {
|
||||||
|
SessionId string
|
||||||
|
PartyId string
|
||||||
|
PublicKey []byte
|
||||||
|
Signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReportCompletionResponse struct {
|
||||||
|
Success bool
|
||||||
|
AllCompleted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type CloseSessionRequest struct {
|
||||||
|
SessionId string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CloseSessionResponse struct {
|
||||||
|
Success bool
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,543 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionHTTPHandler handles HTTP requests for session management
|
||||||
|
type SessionHTTPHandler struct {
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase
|
||||||
|
sessionRepo repositories.SessionRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionHTTPHandler creates a new HTTP handler
|
||||||
|
func NewSessionHTTPHandler(
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase,
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||||
|
sessionRepo repositories.SessionRepository,
|
||||||
|
) *SessionHTTPHandler {
|
||||||
|
return &SessionHTTPHandler{
|
||||||
|
createSessionUC: createSessionUC,
|
||||||
|
joinSessionUC: joinSessionUC,
|
||||||
|
getSessionStatusUC: getSessionStatusUC,
|
||||||
|
reportCompletionUC: reportCompletionUC,
|
||||||
|
closeSessionUC: closeSessionUC,
|
||||||
|
sessionRepo: sessionRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes registers HTTP routes
|
||||||
|
func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) {
|
||||||
|
sessions := r.Group("/sessions")
|
||||||
|
{
|
||||||
|
sessions.POST("", h.CreateSession)
|
||||||
|
sessions.POST("/join", h.JoinSessionByToken)
|
||||||
|
sessions.POST("/:id/join", h.JoinSession)
|
||||||
|
sessions.GET("/:id", h.GetSessionStatus)
|
||||||
|
sessions.GET("/:id/status", h.GetSessionStatus)
|
||||||
|
sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady)
|
||||||
|
sessions.POST("/:id/start", h.StartSession)
|
||||||
|
sessions.POST("/:id/complete", h.ReportCompletion)
|
||||||
|
sessions.DELETE("/:id", h.CloseSession)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSessionRequest is the HTTP request body for creating a session
|
||||||
|
type CreateSessionRequest struct {
|
||||||
|
SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"`
|
||||||
|
ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"`
|
||||||
|
ThresholdT int `json:"thresholdT" binding:"required,min=1"`
|
||||||
|
CreatedBy string `json:"createdBy" binding:"required"`
|
||||||
|
Participants []ParticipantInfoRequest `json:"participants,omitempty"`
|
||||||
|
MessageHash string `json:"messageHash,omitempty"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantInfoRequest represents a participant in the request
|
||||||
|
type ParticipantInfoRequest struct {
|
||||||
|
PartyID string `json:"party_id" binding:"required"`
|
||||||
|
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceInfoRequest represents device info in the request
|
||||||
|
type DeviceInfoRequest struct {
|
||||||
|
DeviceType string `json:"device_type" binding:"required"`
|
||||||
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
|
Platform string `json:"platform,omitempty"`
|
||||||
|
AppVersion string `json:"app_version,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSessionResponse is the HTTP response for creating a session
|
||||||
|
type CreateSessionResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
JoinToken string `json:"joinToken"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ExpiresAt int64 `json:"expiresAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession handles POST /sessions
|
||||||
|
func (h *SessionHTTPHandler) CreateSession(c *gin.Context) {
|
||||||
|
var req CreateSessionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate threshold
|
||||||
|
if req.ThresholdT > req.ThresholdN {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert request to input
|
||||||
|
participants := make([]input.ParticipantInfo, len(req.Participants))
|
||||||
|
for i, p := range req.Participants {
|
||||||
|
participants[i] = input.ParticipantInfo{
|
||||||
|
PartyID: p.PartyID,
|
||||||
|
DeviceInfo: entities.DeviceInfo{
|
||||||
|
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
|
||||||
|
DeviceID: p.DeviceInfo.DeviceID,
|
||||||
|
Platform: p.DeviceInfo.Platform,
|
||||||
|
AppVersion: p.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var messageHash []byte
|
||||||
|
if req.MessageHash != "" {
|
||||||
|
messageHash = []byte(req.MessageHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||||
|
if expiresIn == 0 {
|
||||||
|
expiresIn = 10 * time.Minute // Default
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.CreateSessionInput{
|
||||||
|
InitiatorID: req.CreatedBy,
|
||||||
|
SessionType: req.SessionType,
|
||||||
|
ThresholdN: req.ThresholdN,
|
||||||
|
ThresholdT: req.ThresholdT,
|
||||||
|
Participants: participants,
|
||||||
|
MessageHash: messageHash,
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := h.createSessionUC.Execute(c.Request.Context(), inputData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to create session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract a single join token (for E2E compatibility)
|
||||||
|
// In a real scenario with pre-registered participants, we'd return all tokens
|
||||||
|
// For now, generate a universal join token
|
||||||
|
joinToken := output.SessionID.String() // Use session ID as join token for simplicity
|
||||||
|
if len(output.JoinTokens) > 0 {
|
||||||
|
// If there are participant-specific tokens, use the first one
|
||||||
|
for _, token := range output.JoinTokens {
|
||||||
|
joinToken = token
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, CreateSessionResponse{
|
||||||
|
SessionID: output.SessionID.String(),
|
||||||
|
JoinToken: joinToken,
|
||||||
|
Status: "created",
|
||||||
|
ExpiresAt: output.ExpiresAt.UnixMilli(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSessionRequest is the HTTP request body for joining a session
|
||||||
|
type JoinSessionRequest struct {
|
||||||
|
PartyID string `json:"party_id" binding:"required"`
|
||||||
|
JoinToken string `json:"join_token" binding:"required"`
|
||||||
|
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSessionResponse is the HTTP response for joining a session
|
||||||
|
type JoinSessionResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
SessionInfo SessionInfoDTO `json:"session_info"`
|
||||||
|
OtherParties []PartyInfoDTO `json:"other_parties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionInfoDTO represents session info in responses
|
||||||
|
type SessionInfoDTO struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
SessionType string `json:"session_type"`
|
||||||
|
ThresholdN int `json:"threshold_n"`
|
||||||
|
ThresholdT int `json:"threshold_t"`
|
||||||
|
MessageHash string `json:"message_hash,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PartyInfoDTO represents party info in responses
|
||||||
|
type PartyInfoDTO struct {
|
||||||
|
PartyID string `json:"party_id"`
|
||||||
|
PartyIndex int `json:"party_index"`
|
||||||
|
DeviceInfo DeviceInfoRequest `json:"device_info"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSessionByTokenRequest is the HTTP request body for joining by token
|
||||||
|
type JoinSessionByTokenRequest struct {
|
||||||
|
JoinToken string `json:"joinToken" binding:"required"`
|
||||||
|
PartyID string `json:"partyId" binding:"required"`
|
||||||
|
DeviceType string `json:"deviceType" binding:"required"`
|
||||||
|
DeviceID string `json:"deviceId,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSessionByTokenResponse is the HTTP response for joining by token
|
||||||
|
type JoinSessionByTokenResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
PartyIndex int `json:"partyIndex"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Participants []ParticipantStatusDTO `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantStatusDTO represents participant status in responses
|
||||||
|
type ParticipantStatusDTO struct {
|
||||||
|
PartyID string `json:"partyId"`
|
||||||
|
PartyIndex int `json:"partyIndex"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSessionByToken handles POST /sessions/join (join by token without session ID)
|
||||||
|
func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) {
|
||||||
|
var req JoinSessionByTokenRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
logger.Error("failed to bind join session request", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass empty UUID - the use case will extract session ID from the JWT
|
||||||
|
inputData := input.JoinSessionInput{
|
||||||
|
SessionID: uuid.Nil,
|
||||||
|
PartyID: req.PartyID,
|
||||||
|
JoinToken: req.JoinToken,
|
||||||
|
DeviceInfo: entities.DeviceInfo{
|
||||||
|
DeviceType: entities.DeviceType(req.DeviceType),
|
||||||
|
DeviceID: req.DeviceID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to join session", zap.Error(err))
|
||||||
|
// Return 401 for authentication/token errors
|
||||||
|
if err.Error() == "invalid token" || err.Error() == "token expired" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build participant status list
|
||||||
|
myPartyIndex := output.PartyIndex
|
||||||
|
|
||||||
|
participants := make([]ParticipantStatusDTO, 0)
|
||||||
|
participants = append(participants, ParticipantStatusDTO{
|
||||||
|
PartyID: req.PartyID,
|
||||||
|
Status: "joined",
|
||||||
|
})
|
||||||
|
for _, p := range output.OtherParties {
|
||||||
|
participants = append(participants, ParticipantStatusDTO{
|
||||||
|
PartyID: p.PartyID,
|
||||||
|
Status: "joined",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, JoinSessionByTokenResponse{
|
||||||
|
SessionID: output.SessionInfo.SessionID.String(),
|
||||||
|
PartyIndex: myPartyIndex,
|
||||||
|
Status: output.SessionInfo.Status,
|
||||||
|
Participants: participants,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinSession handles POST /sessions/:id/join
|
||||||
|
func (h *SessionHTTPHandler) JoinSession(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req JoinSessionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.JoinSessionInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
PartyID: req.PartyID,
|
||||||
|
JoinToken: req.JoinToken,
|
||||||
|
DeviceInfo: entities.DeviceInfo{
|
||||||
|
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
|
||||||
|
DeviceID: req.DeviceInfo.DeviceID,
|
||||||
|
Platform: req.DeviceInfo.Platform,
|
||||||
|
AppVersion: req.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to join session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
otherParties := make([]PartyInfoDTO, len(output.OtherParties))
|
||||||
|
for i, p := range output.OtherParties {
|
||||||
|
otherParties[i] = PartyInfoDTO{
|
||||||
|
PartyID: p.PartyID,
|
||||||
|
PartyIndex: p.PartyIndex,
|
||||||
|
DeviceInfo: DeviceInfoRequest{
|
||||||
|
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||||
|
DeviceID: p.DeviceInfo.DeviceID,
|
||||||
|
Platform: p.DeviceInfo.Platform,
|
||||||
|
AppVersion: p.DeviceInfo.AppVersion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, JoinSessionResponse{
|
||||||
|
Success: output.Success,
|
||||||
|
SessionInfo: SessionInfoDTO{
|
||||||
|
SessionID: output.SessionInfo.SessionID.String(),
|
||||||
|
SessionType: output.SessionInfo.SessionType,
|
||||||
|
ThresholdN: output.SessionInfo.ThresholdN,
|
||||||
|
ThresholdT: output.SessionInfo.ThresholdT,
|
||||||
|
MessageHash: string(output.SessionInfo.MessageHash),
|
||||||
|
Status: output.SessionInfo.Status,
|
||||||
|
},
|
||||||
|
OtherParties: otherParties,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStatusResponse is the HTTP response for session status
|
||||||
|
type SessionStatusResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ThresholdT int `json:"thresholdT"`
|
||||||
|
ThresholdN int `json:"thresholdN"`
|
||||||
|
Participants []ParticipantStatusDTO `json:"participants"`
|
||||||
|
PublicKey string `json:"publicKey,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionStatus handles GET /sessions/:id/status
|
||||||
|
func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to get session status", zap.Error(err))
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert participants to DTO
|
||||||
|
participants := make([]ParticipantStatusDTO, len(output.Participants))
|
||||||
|
for i, p := range output.Participants {
|
||||||
|
participants[i] = ParticipantStatusDTO{
|
||||||
|
PartyID: p.PartyID,
|
||||||
|
PartyIndex: p.PartyIndex,
|
||||||
|
Status: p.Status,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, SessionStatusResponse{
|
||||||
|
SessionID: output.SessionID.String(),
|
||||||
|
Status: output.Status,
|
||||||
|
ThresholdT: output.ThresholdT,
|
||||||
|
ThresholdN: output.ThresholdN,
|
||||||
|
Participants: participants,
|
||||||
|
PublicKey: string(output.PublicKey),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportCompletionRequest is the HTTP request for reporting completion
|
||||||
|
type ReportCompletionRequest struct {
|
||||||
|
PartyID string `json:"party_id" binding:"required"`
|
||||||
|
PublicKey string `json:"public_key,omitempty"`
|
||||||
|
Signature string `json:"signature,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportCompletion handles POST /sessions/:id/complete
|
||||||
|
func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ReportCompletionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputData := input.ReportCompletionInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
PartyID: req.PartyID,
|
||||||
|
PublicKey: []byte(req.PublicKey),
|
||||||
|
Signature: []byte(req.Signature),
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to report completion", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": output.Success,
|
||||||
|
"all_completed": output.AllCompleted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseSession handles DELETE /sessions/:id
|
||||||
|
func (h *SessionHTTPHandler) CloseSession(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.closeSessionUC.Execute(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to close session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready
|
||||||
|
func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
partyID := c.Param("partyId")
|
||||||
|
if partyID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID))
|
||||||
|
|
||||||
|
// Load session
|
||||||
|
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
if err == entities.ErrSessionNotFound {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error("failed to load session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create party ID value object
|
||||||
|
partyIDVO, err := value_objects.NewPartyID(partyID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update participant status to ready
|
||||||
|
if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil {
|
||||||
|
logger.Error("failed to mark party as ready", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save session
|
||||||
|
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
|
||||||
|
logger.Error("failed to save session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSession handles POST /sessions/:id/start
|
||||||
|
func (h *SessionHTTPHandler) StartSession(c *gin.Context) {
|
||||||
|
sessionID, err := uuid.Parse(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("starting session", zap.String("session_id", sessionID.String()))
|
||||||
|
|
||||||
|
// Load session
|
||||||
|
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
|
||||||
|
if err != nil {
|
||||||
|
if err == entities.ErrSessionNotFound {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error("failed to load session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the session
|
||||||
|
if err := session.Start(); err != nil {
|
||||||
|
logger.Error("failed to start session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save session
|
||||||
|
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
|
||||||
|
logger.Error("failed to save session", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck handles GET /health
|
||||||
|
func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "session-coordinator",
|
||||||
|
"time": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,306 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/reflection"
|
||||||
|
|
||||||
|
"github.com/rwadurian/mpc-system/pkg/config"
|
||||||
|
"github.com/rwadurian/mpc-system/pkg/jwt"
|
||||||
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||||
|
grpcadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/grpc"
|
||||||
|
httphandler "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/http"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/postgres"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/rabbitmq"
|
||||||
|
redisadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/redis"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Parse flags
|
||||||
|
configPath := flag.String("config", "", "Path to config file")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
cfg, err := config.Load(*configPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load config: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize logger
|
||||||
|
if err := logger.Init(&logger.Config{
|
||||||
|
Level: cfg.Logger.Level,
|
||||||
|
Encoding: cfg.Logger.Encoding,
|
||||||
|
}); err != nil {
|
||||||
|
fmt.Printf("Failed to initialize logger: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer logger.Sync()
|
||||||
|
|
||||||
|
logger.Info("Starting Session Coordinator Service",
|
||||||
|
zap.String("environment", cfg.Server.Environment),
|
||||||
|
zap.Int("grpc_port", cfg.Server.GRPCPort),
|
||||||
|
zap.Int("http_port", cfg.Server.HTTPPort))
|
||||||
|
|
||||||
|
// Initialize database connection
|
||||||
|
db, err := initDatabase(cfg.Database)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to connect to database", zap.Error(err))
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Initialize Redis connection
|
||||||
|
redisClient := initRedis(cfg.Redis)
|
||||||
|
defer redisClient.Close()
|
||||||
|
|
||||||
|
// Initialize RabbitMQ connection
|
||||||
|
rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
|
||||||
|
}
|
||||||
|
defer rabbitConn.Close()
|
||||||
|
|
||||||
|
// Initialize repositories and adapters
|
||||||
|
sessionRepo := postgres.NewSessionPostgresRepo(db)
|
||||||
|
messageRepo := postgres.NewMessagePostgresRepo(db)
|
||||||
|
sessionCache := redisadapter.NewSessionCacheAdapter(redisClient)
|
||||||
|
eventPublisher, err := rabbitmq.NewEventPublisherAdapter(rabbitConn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to create event publisher", zap.Error(err))
|
||||||
|
}
|
||||||
|
defer eventPublisher.Close()
|
||||||
|
|
||||||
|
// Initialize JWT service
|
||||||
|
jwtService := jwt.NewJWTService(
|
||||||
|
cfg.JWT.SecretKey,
|
||||||
|
cfg.JWT.Issuer,
|
||||||
|
cfg.JWT.TokenExpiry,
|
||||||
|
cfg.JWT.RefreshExpiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize use cases
|
||||||
|
createSessionUC := use_cases.NewCreateSessionUseCase(sessionRepo, jwtService, eventPublisher)
|
||||||
|
joinSessionUC := use_cases.NewJoinSessionUseCase(sessionRepo, jwtService, eventPublisher)
|
||||||
|
getSessionStatusUC := use_cases.NewGetSessionStatusUseCase(sessionRepo)
|
||||||
|
reportCompletionUC := use_cases.NewReportCompletionUseCase(sessionRepo, eventPublisher)
|
||||||
|
closeSessionUC := use_cases.NewCloseSessionUseCase(sessionRepo, messageRepo, eventPublisher)
|
||||||
|
expireSessionsUC := use_cases.NewExpireSessionsUseCase(sessionRepo, eventPublisher)
|
||||||
|
|
||||||
|
// Start session expiration background job
|
||||||
|
go runSessionExpiration(expireSessionsUC)
|
||||||
|
|
||||||
|
// Create shutdown context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start servers
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
// Start gRPC server
|
||||||
|
go func() {
|
||||||
|
if err := startGRPCServer(
|
||||||
|
cfg,
|
||||||
|
createSessionUC,
|
||||||
|
joinSessionUC,
|
||||||
|
getSessionStatusUC,
|
||||||
|
reportCompletionUC,
|
||||||
|
closeSessionUC,
|
||||||
|
); err != nil {
|
||||||
|
errChan <- fmt.Errorf("gRPC server error: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start HTTP server
|
||||||
|
go func() {
|
||||||
|
if err := startHTTPServer(
|
||||||
|
cfg,
|
||||||
|
createSessionUC,
|
||||||
|
joinSessionUC,
|
||||||
|
getSessionStatusUC,
|
||||||
|
reportCompletionUC,
|
||||||
|
closeSessionUC,
|
||||||
|
sessionRepo,
|
||||||
|
); err != nil {
|
||||||
|
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for shutdown signal
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sig := <-sigChan:
|
||||||
|
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
||||||
|
case err := <-errChan:
|
||||||
|
logger.Error("Server error", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown
|
||||||
|
logger.Info("Shutting down...")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Give services time to shutdown gracefully
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
logger.Info("Shutdown complete")
|
||||||
|
|
||||||
|
_ = ctx
|
||||||
|
_ = sessionCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
||||||
|
db, err := sql.Open("postgres", cfg.DSN())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
|
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Connected to PostgreSQL")
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initRedis(cfg config.RedisConfig) *redis.Client {
|
||||||
|
client := redis.NewClient(&redis.Options{
|
||||||
|
Addr: cfg.Addr(),
|
||||||
|
Password: cfg.Password,
|
||||||
|
DB: cfg.DB,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
|
logger.Warn("Redis connection failed, continuing without cache", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
logger.Info("Connected to Redis")
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
||||||
|
conn, err := amqp.Dial(cfg.URL())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Connected to RabbitMQ")
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startGRPCServer(
|
||||||
|
cfg *config.Config,
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase,
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||||
|
) error {
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcServer := grpc.NewServer()
|
||||||
|
|
||||||
|
// Register services (using our custom handler, not generated proto)
|
||||||
|
// In production, you would register the generated proto service
|
||||||
|
_ = grpcadapter.NewSessionCoordinatorServer(
|
||||||
|
createSessionUC,
|
||||||
|
joinSessionUC,
|
||||||
|
getSessionStatusUC,
|
||||||
|
reportCompletionUC,
|
||||||
|
closeSessionUC,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Enable reflection for debugging
|
||||||
|
reflection.Register(grpcServer)
|
||||||
|
|
||||||
|
logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
|
||||||
|
return grpcServer.Serve(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startHTTPServer(
|
||||||
|
cfg *config.Config,
|
||||||
|
createSessionUC *use_cases.CreateSessionUseCase,
|
||||||
|
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||||
|
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||||
|
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||||
|
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||||
|
sessionRepo repositories.SessionRepository,
|
||||||
|
) error {
|
||||||
|
// Set Gin mode
|
||||||
|
if cfg.Server.Environment == "production" {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(gin.Recovery())
|
||||||
|
router.Use(gin.Logger())
|
||||||
|
|
||||||
|
// Create HTTP handler
|
||||||
|
httpHandler := httphandler.NewSessionHTTPHandler(
|
||||||
|
createSessionUC,
|
||||||
|
joinSessionUC,
|
||||||
|
getSessionStatusUC,
|
||||||
|
reportCompletionUC,
|
||||||
|
closeSessionUC,
|
||||||
|
sessionRepo,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Health check
|
||||||
|
router.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "session-coordinator",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register API routes
|
||||||
|
api := router.Group("/api/v1")
|
||||||
|
httpHandler.RegisterRoutes(api)
|
||||||
|
|
||||||
|
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
||||||
|
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSessionExpiration(expireSessionsUC *use_cases.ExpireSessionsUseCase) {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
count, err := expireSessionsUC.Execute(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to expire sessions", zap.Error(err))
|
||||||
|
} else if count > 0 {
|
||||||
|
logger.Info("Expired stale sessions", zap.Int("count", count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,342 @@
|
||||||
|
package entities
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionNotFound = errors.New("session not found")
|
||||||
|
ErrSessionFull = errors.New("session is full")
|
||||||
|
ErrSessionExpired = errors.New("session expired")
|
||||||
|
ErrSessionNotInProgress = errors.New("session not in progress")
|
||||||
|
ErrParticipantNotFound = errors.New("participant not found")
|
||||||
|
ErrInvalidSessionType = errors.New("invalid session type")
|
||||||
|
ErrInvalidStatusTransition = errors.New("invalid status transition")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionType represents the type of MPC session
|
||||||
|
type SessionType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionTypeKeygen SessionType = "keygen"
|
||||||
|
SessionTypeSign SessionType = "sign"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsValid checks if the session type is valid
|
||||||
|
func (t SessionType) IsValid() bool {
|
||||||
|
return t == SessionTypeKeygen || t == SessionTypeSign
|
||||||
|
}
|
||||||
|
|
||||||
|
// MPCSession represents an MPC session
|
||||||
|
// Coordinator only manages session metadata, does not participate in MPC computation
|
||||||
|
type MPCSession struct {
|
||||||
|
ID value_objects.SessionID
|
||||||
|
SessionType SessionType
|
||||||
|
Threshold value_objects.Threshold
|
||||||
|
Participants []*Participant
|
||||||
|
Status value_objects.SessionStatus
|
||||||
|
MessageHash []byte // Used for Sign sessions
|
||||||
|
PublicKey []byte // Group public key after Keygen completion
|
||||||
|
CreatedBy string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
CompletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMPCSession creates a new MPC session
|
||||||
|
func NewMPCSession(
|
||||||
|
sessionType SessionType,
|
||||||
|
threshold value_objects.Threshold,
|
||||||
|
createdBy string,
|
||||||
|
expiresIn time.Duration,
|
||||||
|
messageHash []byte, // Only for Sign sessions
|
||||||
|
) (*MPCSession, error) {
|
||||||
|
if !sessionType.IsValid() {
|
||||||
|
return nil, ErrInvalidSessionType
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionType == SessionTypeSign && len(messageHash) == 0 {
|
||||||
|
return nil, errors.New("message hash required for sign session")
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
return &MPCSession{
|
||||||
|
ID: value_objects.NewSessionID(),
|
||||||
|
SessionType: sessionType,
|
||||||
|
Threshold: threshold,
|
||||||
|
Participants: make([]*Participant, 0, threshold.N()),
|
||||||
|
Status: value_objects.SessionStatusCreated,
|
||||||
|
MessageHash: messageHash,
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
ExpiresAt: now.Add(expiresIn),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddParticipant adds a participant to the session
|
||||||
|
func (s *MPCSession) AddParticipant(p *Participant) error {
|
||||||
|
if len(s.Participants) >= s.Threshold.N() {
|
||||||
|
return ErrSessionFull
|
||||||
|
}
|
||||||
|
s.Participants = append(s.Participants, p)
|
||||||
|
s.UpdatedAt = time.Now().UTC()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipant gets a participant by party ID
|
||||||
|
func (s *MPCSession) GetParticipant(partyID value_objects.PartyID) (*Participant, error) {
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if p.PartyID.Equals(partyID) {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrParticipantNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateParticipantStatus updates a participant's status
|
||||||
|
func (s *MPCSession) UpdateParticipantStatus(partyID value_objects.PartyID, status value_objects.ParticipantStatus) error {
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if p.PartyID.Equals(partyID) {
|
||||||
|
switch status {
|
||||||
|
case value_objects.ParticipantStatusJoined:
|
||||||
|
return p.Join()
|
||||||
|
case value_objects.ParticipantStatusReady:
|
||||||
|
return p.MarkReady()
|
||||||
|
case value_objects.ParticipantStatusCompleted:
|
||||||
|
return p.MarkCompleted()
|
||||||
|
case value_objects.ParticipantStatusFailed:
|
||||||
|
p.MarkFailed()
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("invalid status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ErrParticipantNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanStart checks if all participants have joined and the session can start
|
||||||
|
func (s *MPCSession) CanStart() bool {
|
||||||
|
if len(s.Participants) != s.Threshold.N() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
readyCount := 0
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
// Accept participants in either joined or ready status
|
||||||
|
if p.IsJoined() || p.IsReady() {
|
||||||
|
readyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return readyCount == s.Threshold.N()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start transitions the session to in_progress
|
||||||
|
func (s *MPCSession) Start() error {
|
||||||
|
// If already in progress, just return success (idempotent)
|
||||||
|
if s.Status == value_objects.SessionStatusInProgress {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !s.Status.CanTransitionTo(value_objects.SessionStatusInProgress) {
|
||||||
|
return ErrInvalidStatusTransition
|
||||||
|
}
|
||||||
|
if !s.CanStart() {
|
||||||
|
return errors.New("not all participants have joined")
|
||||||
|
}
|
||||||
|
s.Status = value_objects.SessionStatusInProgress
|
||||||
|
s.UpdatedAt = time.Now().UTC()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete marks the session as completed
|
||||||
|
func (s *MPCSession) Complete(publicKey []byte) error {
|
||||||
|
if !s.Status.CanTransitionTo(value_objects.SessionStatusCompleted) {
|
||||||
|
return ErrInvalidStatusTransition
|
||||||
|
}
|
||||||
|
s.Status = value_objects.SessionStatusCompleted
|
||||||
|
s.PublicKey = publicKey
|
||||||
|
now := time.Now().UTC()
|
||||||
|
s.CompletedAt = &now
|
||||||
|
s.UpdatedAt = now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fail marks the session as failed
|
||||||
|
func (s *MPCSession) Fail() error {
|
||||||
|
if !s.Status.CanTransitionTo(value_objects.SessionStatusFailed) {
|
||||||
|
return ErrInvalidStatusTransition
|
||||||
|
}
|
||||||
|
s.Status = value_objects.SessionStatusFailed
|
||||||
|
s.UpdatedAt = time.Now().UTC()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire marks the session as expired
|
||||||
|
func (s *MPCSession) Expire() error {
|
||||||
|
if !s.Status.CanTransitionTo(value_objects.SessionStatusExpired) {
|
||||||
|
return ErrInvalidStatusTransition
|
||||||
|
}
|
||||||
|
s.Status = value_objects.SessionStatusExpired
|
||||||
|
s.UpdatedAt = time.Now().UTC()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired checks if the session has expired
|
||||||
|
func (s *MPCSession) IsExpired() bool {
|
||||||
|
return time.Now().UTC().After(s.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive checks if the session is active
|
||||||
|
func (s *MPCSession) IsActive() bool {
|
||||||
|
return s.Status.IsActive() && !s.IsExpired()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsParticipant checks if a party is a participant
|
||||||
|
func (s *MPCSession) IsParticipant(partyID value_objects.PartyID) bool {
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if p.PartyID.Equals(partyID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllCompleted checks if all participants have completed
|
||||||
|
func (s *MPCSession) AllCompleted() bool {
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if !p.IsCompleted() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletedCount returns the number of completed participants
|
||||||
|
func (s *MPCSession) CompletedCount() int {
|
||||||
|
count := 0
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if p.IsCompleted() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinedCount returns the number of joined participants
|
||||||
|
func (s *MPCSession) JoinedCount() int {
|
||||||
|
count := 0
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if p.IsJoined() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPartyIDs returns all party IDs
|
||||||
|
func (s *MPCSession) GetPartyIDs() []string {
|
||||||
|
ids := make([]string, len(s.Participants))
|
||||||
|
for i, p := range s.Participants {
|
||||||
|
ids[i] = p.PartyID.String()
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOtherParties returns participants except the specified party
|
||||||
|
func (s *MPCSession) GetOtherParties(excludePartyID value_objects.PartyID) []*Participant {
|
||||||
|
others := make([]*Participant, 0, len(s.Participants)-1)
|
||||||
|
for _, p := range s.Participants {
|
||||||
|
if !p.PartyID.Equals(excludePartyID) {
|
||||||
|
others = append(others, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return others
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDTO converts to a DTO for API responses
|
||||||
|
func (s *MPCSession) ToDTO() SessionDTO {
|
||||||
|
participants := make([]ParticipantDTO, len(s.Participants))
|
||||||
|
for i, p := range s.Participants {
|
||||||
|
participants[i] = ParticipantDTO{
|
||||||
|
PartyID: p.PartyID.String(),
|
||||||
|
PartyIndex: p.PartyIndex,
|
||||||
|
Status: p.Status.String(),
|
||||||
|
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SessionDTO{
|
||||||
|
ID: s.ID.String(),
|
||||||
|
SessionType: string(s.SessionType),
|
||||||
|
ThresholdN: s.Threshold.N(),
|
||||||
|
ThresholdT: s.Threshold.T(),
|
||||||
|
Participants: participants,
|
||||||
|
Status: s.Status.String(),
|
||||||
|
CreatedAt: s.CreatedAt,
|
||||||
|
ExpiresAt: s.ExpiresAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionDTO is a data transfer object for sessions
|
||||||
|
type SessionDTO struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
SessionType string `json:"session_type"`
|
||||||
|
ThresholdN int `json:"threshold_n"`
|
||||||
|
ThresholdT int `json:"threshold_t"`
|
||||||
|
Participants []ParticipantDTO `json:"participants"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantDTO is a data transfer object for participants
|
||||||
|
type ParticipantDTO struct {
|
||||||
|
PartyID string `json:"party_id"`
|
||||||
|
PartyIndex int `json:"party_index"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
DeviceType string `json:"device_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct reconstructs an MPCSession from database
|
||||||
|
func ReconstructSession(
|
||||||
|
id uuid.UUID,
|
||||||
|
sessionType string,
|
||||||
|
thresholdT, thresholdN int,
|
||||||
|
status string,
|
||||||
|
messageHash, publicKey []byte,
|
||||||
|
createdBy string,
|
||||||
|
createdAt, updatedAt, expiresAt time.Time,
|
||||||
|
completedAt *time.Time,
|
||||||
|
participants []*Participant,
|
||||||
|
) (*MPCSession, error) {
|
||||||
|
sessionStatus, err := value_objects.NewSessionStatus(status)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold, err := value_objects.NewThreshold(thresholdT, thresholdN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MPCSession{
|
||||||
|
ID: value_objects.SessionIDFromUUID(id),
|
||||||
|
SessionType: SessionType(sessionType),
|
||||||
|
Threshold: threshold,
|
||||||
|
Participants: participants,
|
||||||
|
Status: sessionStatus,
|
||||||
|
MessageHash: messageHash,
|
||||||
|
PublicKey: publicKey,
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
UpdatedAt: updatedAt,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
CompletedAt: completedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,354 @@
|
||||||
|
//go:build e2e
|
||||||
|
|
||||||
|
package e2e_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeygenFlowTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
baseURL string
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeygenFlowSuite(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping e2e test in short mode")
|
||||||
|
}
|
||||||
|
suite.Run(t, new(KeygenFlowTestSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) SetupSuite() {
|
||||||
|
s.baseURL = os.Getenv("SESSION_COORDINATOR_URL")
|
||||||
|
if s.baseURL == "" {
|
||||||
|
s.baseURL = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
|
||||||
|
s.client = &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for service to be ready
|
||||||
|
s.waitForService()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) waitForService() {
|
||||||
|
maxRetries := 30
|
||||||
|
for i := 0; i < maxRetries; i++ {
|
||||||
|
resp, err := s.client.Get(s.baseURL + "/health")
|
||||||
|
if err == nil && resp.StatusCode == http.StatusOK {
|
||||||
|
resp.Body.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
s.T().Fatal("Service not ready after waiting")
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateSessionRequest struct {
|
||||||
|
SessionType string `json:"sessionType"`
|
||||||
|
ThresholdT int `json:"thresholdT"`
|
||||||
|
ThresholdN int `json:"thresholdN"`
|
||||||
|
CreatedBy string `json:"createdBy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateSessionResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
JoinToken string `json:"joinToken"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinSessionRequest struct {
|
||||||
|
JoinToken string `json:"joinToken"`
|
||||||
|
PartyID string `json:"partyId"`
|
||||||
|
DeviceType string `json:"deviceType"`
|
||||||
|
DeviceID string `json:"deviceId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinSessionResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
PartyIndex int `json:"partyIndex"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Participants []struct {
|
||||||
|
PartyID string `json:"partyId"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
} `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionStatusResponse struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ThresholdT int `json:"thresholdT"`
|
||||||
|
ThresholdN int `json:"thresholdN"`
|
||||||
|
Participants []struct {
|
||||||
|
PartyID string `json:"partyId"`
|
||||||
|
PartyIndex int `json:"partyIndex"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
} `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) TestCompleteKeygenFlow() {
|
||||||
|
// Step 1: Create a new keygen session
|
||||||
|
createReq := CreateSessionRequest{
|
||||||
|
SessionType: "keygen",
|
||||||
|
ThresholdT: 2,
|
||||||
|
ThresholdN: 3,
|
||||||
|
CreatedBy: "e2e_test_user",
|
||||||
|
}
|
||||||
|
|
||||||
|
createResp := s.createSession(createReq)
|
||||||
|
require.NotEmpty(s.T(), createResp.SessionID)
|
||||||
|
require.NotEmpty(s.T(), createResp.JoinToken)
|
||||||
|
assert.Equal(s.T(), "created", createResp.Status)
|
||||||
|
|
||||||
|
sessionID := createResp.SessionID
|
||||||
|
joinToken := createResp.JoinToken
|
||||||
|
|
||||||
|
// Step 2: First party joins
|
||||||
|
joinReq1 := JoinSessionRequest{
|
||||||
|
JoinToken: joinToken,
|
||||||
|
PartyID: "party_user_device",
|
||||||
|
DeviceType: "iOS",
|
||||||
|
DeviceID: "device_001",
|
||||||
|
}
|
||||||
|
joinResp1 := s.joinSession(joinReq1)
|
||||||
|
assert.Equal(s.T(), sessionID, joinResp1.SessionID)
|
||||||
|
assert.Equal(s.T(), 0, joinResp1.PartyIndex)
|
||||||
|
|
||||||
|
// Step 3: Second party joins
|
||||||
|
joinReq2 := JoinSessionRequest{
|
||||||
|
JoinToken: joinToken,
|
||||||
|
PartyID: "party_server",
|
||||||
|
DeviceType: "server",
|
||||||
|
DeviceID: "server_001",
|
||||||
|
}
|
||||||
|
joinResp2 := s.joinSession(joinReq2)
|
||||||
|
assert.Equal(s.T(), sessionID, joinResp2.SessionID)
|
||||||
|
assert.Equal(s.T(), 1, joinResp2.PartyIndex)
|
||||||
|
|
||||||
|
// Step 4: Third party joins
|
||||||
|
joinReq3 := JoinSessionRequest{
|
||||||
|
JoinToken: joinToken,
|
||||||
|
PartyID: "party_recovery",
|
||||||
|
DeviceType: "recovery",
|
||||||
|
DeviceID: "recovery_001",
|
||||||
|
}
|
||||||
|
joinResp3 := s.joinSession(joinReq3)
|
||||||
|
assert.Equal(s.T(), sessionID, joinResp3.SessionID)
|
||||||
|
assert.Equal(s.T(), 2, joinResp3.PartyIndex)
|
||||||
|
|
||||||
|
// Step 5: Check session status - should have all participants
|
||||||
|
statusResp := s.getSessionStatus(sessionID)
|
||||||
|
assert.Equal(s.T(), 3, len(statusResp.Participants))
|
||||||
|
|
||||||
|
// Step 6: Mark parties as ready
|
||||||
|
s.markPartyReady(sessionID, "party_user_device")
|
||||||
|
s.markPartyReady(sessionID, "party_server")
|
||||||
|
s.markPartyReady(sessionID, "party_recovery")
|
||||||
|
|
||||||
|
// Step 7: Start the session
|
||||||
|
s.startSession(sessionID)
|
||||||
|
|
||||||
|
// Step 8: Verify session is in progress
|
||||||
|
statusResp = s.getSessionStatus(sessionID)
|
||||||
|
assert.Equal(s.T(), "in_progress", statusResp.Status)
|
||||||
|
|
||||||
|
// Step 9: Report completion (simulate keygen completion)
|
||||||
|
publicKey := []byte("test-group-public-key-from-keygen")
|
||||||
|
s.reportCompletion(sessionID, "party_user_device", publicKey)
|
||||||
|
|
||||||
|
// Step 10: Verify session is completed
|
||||||
|
statusResp = s.getSessionStatus(sessionID)
|
||||||
|
assert.Equal(s.T(), "completed", statusResp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) TestJoinSessionWithInvalidToken() {
|
||||||
|
joinReq := JoinSessionRequest{
|
||||||
|
JoinToken: "invalid-token",
|
||||||
|
PartyID: "party_test",
|
||||||
|
DeviceType: "iOS",
|
||||||
|
DeviceID: "device_test",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(joinReq)
|
||||||
|
resp, err := s.client.Post(
|
||||||
|
s.baseURL+"/api/v1/sessions/join",
|
||||||
|
"application/json",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) TestGetNonExistentSession() {
|
||||||
|
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/00000000-0000-0000-0000-000000000000")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), http.StatusNotFound, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) TestExceedParticipantLimit() {
|
||||||
|
// Create session with 2 participants max
|
||||||
|
createReq := CreateSessionRequest{
|
||||||
|
SessionType: "keygen",
|
||||||
|
ThresholdT: 2,
|
||||||
|
ThresholdN: 2,
|
||||||
|
CreatedBy: "e2e_test_user_limit",
|
||||||
|
}
|
||||||
|
|
||||||
|
createResp := s.createSession(createReq)
|
||||||
|
joinToken := createResp.JoinToken
|
||||||
|
|
||||||
|
// Join with 2 parties (should succeed)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
joinReq := JoinSessionRequest{
|
||||||
|
JoinToken: joinToken,
|
||||||
|
PartyID: "party_" + string(rune('a'+i)),
|
||||||
|
DeviceType: "test",
|
||||||
|
DeviceID: "device_" + string(rune('a'+i)),
|
||||||
|
}
|
||||||
|
s.joinSession(joinReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to join with 3rd party (should fail)
|
||||||
|
joinReq := JoinSessionRequest{
|
||||||
|
JoinToken: joinToken,
|
||||||
|
PartyID: "party_extra",
|
||||||
|
DeviceType: "test",
|
||||||
|
DeviceID: "device_extra",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(joinReq)
|
||||||
|
resp, err := s.client.Post(
|
||||||
|
s.baseURL+"/api/v1/sessions/join",
|
||||||
|
"application/json",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) createSession(req CreateSessionRequest) CreateSessionResponse {
|
||||||
|
body, _ := json.Marshal(req)
|
||||||
|
resp, err := s.client.Post(
|
||||||
|
s.baseURL+"/api/v1/sessions",
|
||||||
|
"application/json",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusCreated, resp.StatusCode)
|
||||||
|
|
||||||
|
var result CreateSessionResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) joinSession(req JoinSessionRequest) JoinSessionResponse {
|
||||||
|
body, _ := json.Marshal(req)
|
||||||
|
resp, err := s.client.Post(
|
||||||
|
s.baseURL+"/api/v1/sessions/join",
|
||||||
|
"application/json",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var result JoinSessionResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) getSessionStatus(sessionID string) SessionStatusResponse {
|
||||||
|
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/" + sessionID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var result SessionStatusResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) markPartyReady(sessionID, partyID string) {
|
||||||
|
req := map[string]string{"partyId": partyID}
|
||||||
|
body, _ := json.Marshal(req)
|
||||||
|
|
||||||
|
httpReq, _ := http.NewRequest(
|
||||||
|
http.MethodPut,
|
||||||
|
s.baseURL+"/api/v1/sessions/"+sessionID+"/parties/"+partyID+"/ready",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := s.client.Do(httpReq)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) startSession(sessionID string) {
|
||||||
|
httpReq, _ := http.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
s.baseURL+"/api/v1/sessions/"+sessionID+"/start",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := s.client.Do(httpReq)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *KeygenFlowTestSuite) reportCompletion(sessionID string, partyID string, publicKey []byte) {
|
||||||
|
req := map[string]interface{}{
|
||||||
|
"party_id": partyID,
|
||||||
|
"public_key": string(publicKey),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(req)
|
||||||
|
|
||||||
|
httpReq, _ := http.NewRequest(
|
||||||
|
http.MethodPost,
|
||||||
|
s.baseURL+"/api/v1/sessions/"+sessionID+"/complete",
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := s.client.Do(httpReq)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue