fix: Implement MarkPartyReady and StartSession handlers, update domain logic
- Add sessionRepo to HTTP handler for database operations - Implement MarkPartyReady handler to update participant status - Implement StartSession handler to start MPC sessions - Update CanStart() to accept participants in 'ready' status - Make Start() method idempotent to handle automatic + explicit starts - Fix repository injection through dependency chain in main.go - Add party_id parameter to test completion request
This commit is contained in:
parent
6fa4d7ac1d
commit
7531cbd07a
|
|
@ -0,0 +1,326 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// SessionCoordinatorServer implements the gRPC SessionCoordinator service
|
||||
type SessionCoordinatorServer struct {
|
||||
createSessionUC *use_cases.CreateSessionUseCase
|
||||
joinSessionUC *use_cases.JoinSessionUseCase
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase
|
||||
closeSessionUC *use_cases.CloseSessionUseCase
|
||||
}
|
||||
|
||||
// NewSessionCoordinatorServer creates a new gRPC server
|
||||
func NewSessionCoordinatorServer(
|
||||
createSessionUC *use_cases.CreateSessionUseCase,
|
||||
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||
) *SessionCoordinatorServer {
|
||||
return &SessionCoordinatorServer{
|
||||
createSessionUC: createSessionUC,
|
||||
joinSessionUC: joinSessionUC,
|
||||
getSessionStatusUC: getSessionStatusUC,
|
||||
reportCompletionUC: reportCompletionUC,
|
||||
closeSessionUC: closeSessionUC,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new MPC session
|
||||
func (s *SessionCoordinatorServer) CreateSession(
|
||||
ctx context.Context,
|
||||
req *CreateSessionRequest,
|
||||
) (*CreateSessionResponse, error) {
|
||||
// Convert request to input
|
||||
participants := make([]input.ParticipantInfo, len(req.Participants))
|
||||
for i, p := range req.Participants {
|
||||
participants[i] = input.ParticipantInfo{
|
||||
PartyID: p.PartyId,
|
||||
DeviceInfo: entities.DeviceInfo{
|
||||
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
|
||||
DeviceID: p.DeviceInfo.DeviceId,
|
||||
Platform: p.DeviceInfo.Platform,
|
||||
AppVersion: p.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
inputData := input.CreateSessionInput{
|
||||
InitiatorID: "", // Could be extracted from auth context
|
||||
SessionType: req.SessionType,
|
||||
ThresholdN: int(req.ThresholdN),
|
||||
ThresholdT: int(req.ThresholdT),
|
||||
Participants: participants,
|
||||
MessageHash: req.MessageHash,
|
||||
ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second,
|
||||
}
|
||||
|
||||
// Execute use case
|
||||
output, err := s.createSessionUC.Execute(ctx, inputData)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err)
|
||||
}
|
||||
|
||||
// Convert output to response
|
||||
return &CreateSessionResponse{
|
||||
SessionId: output.SessionID.String(),
|
||||
JoinTokens: output.JoinTokens,
|
||||
ExpiresAt: output.ExpiresAt.UnixMilli(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JoinSession allows a participant to join a session
|
||||
func (s *SessionCoordinatorServer) JoinSession(
|
||||
ctx context.Context,
|
||||
req *JoinSessionRequest,
|
||||
) (*JoinSessionResponse, error) {
|
||||
sessionID, err := uuid.Parse(req.SessionId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||
}
|
||||
|
||||
inputData := input.JoinSessionInput{
|
||||
SessionID: sessionID,
|
||||
PartyID: req.PartyId,
|
||||
JoinToken: req.JoinToken,
|
||||
DeviceInfo: entities.DeviceInfo{
|
||||
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
|
||||
DeviceID: req.DeviceInfo.DeviceId,
|
||||
Platform: req.DeviceInfo.Platform,
|
||||
AppVersion: req.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
|
||||
output, err := s.joinSessionUC.Execute(ctx, inputData)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err)
|
||||
}
|
||||
|
||||
// Convert other parties to response format
|
||||
otherParties := make([]*PartyInfo, len(output.OtherParties))
|
||||
for i, p := range output.OtherParties {
|
||||
otherParties[i] = &PartyInfo{
|
||||
PartyId: p.PartyID,
|
||||
PartyIndex: int32(p.PartyIndex),
|
||||
DeviceInfo: &DeviceInfo{
|
||||
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||
DeviceId: p.DeviceInfo.DeviceID,
|
||||
Platform: p.DeviceInfo.Platform,
|
||||
AppVersion: p.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &JoinSessionResponse{
|
||||
Success: output.Success,
|
||||
SessionInfo: &SessionInfo{
|
||||
SessionId: output.SessionInfo.SessionID.String(),
|
||||
SessionType: output.SessionInfo.SessionType,
|
||||
ThresholdN: int32(output.SessionInfo.ThresholdN),
|
||||
ThresholdT: int32(output.SessionInfo.ThresholdT),
|
||||
MessageHash: output.SessionInfo.MessageHash,
|
||||
Status: output.SessionInfo.Status,
|
||||
},
|
||||
OtherParties: otherParties,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSessionStatus retrieves the status of a session
|
||||
func (s *SessionCoordinatorServer) GetSessionStatus(
|
||||
ctx context.Context,
|
||||
req *GetSessionStatusRequest,
|
||||
) (*GetSessionStatusResponse, error) {
|
||||
sessionID, err := uuid.Parse(req.SessionId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||
}
|
||||
|
||||
output, err := s.getSessionStatusUC.Execute(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err)
|
||||
}
|
||||
|
||||
// Calculate completed parties from participants
|
||||
completedParties := 0
|
||||
for _, p := range output.Participants {
|
||||
if p.Status == "completed" {
|
||||
completedParties++
|
||||
}
|
||||
}
|
||||
|
||||
return &GetSessionStatusResponse{
|
||||
Status: output.Status,
|
||||
CompletedParties: int32(completedParties),
|
||||
TotalParties: int32(len(output.Participants)),
|
||||
PublicKey: output.PublicKey,
|
||||
Signature: output.Signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReportCompletion reports that a participant has completed
|
||||
func (s *SessionCoordinatorServer) ReportCompletion(
|
||||
ctx context.Context,
|
||||
req *ReportCompletionRequest,
|
||||
) (*ReportCompletionResponse, error) {
|
||||
sessionID, err := uuid.Parse(req.SessionId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||
}
|
||||
|
||||
inputData := input.ReportCompletionInput{
|
||||
SessionID: sessionID,
|
||||
PartyID: req.PartyId,
|
||||
PublicKey: req.PublicKey,
|
||||
Signature: req.Signature,
|
||||
}
|
||||
|
||||
output, err := s.reportCompletionUC.Execute(ctx, inputData)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err)
|
||||
}
|
||||
|
||||
return &ReportCompletionResponse{
|
||||
Success: output.Success,
|
||||
AllCompleted: output.AllCompleted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CloseSession closes a session
|
||||
func (s *SessionCoordinatorServer) CloseSession(
|
||||
ctx context.Context,
|
||||
req *CloseSessionRequest,
|
||||
) (*CloseSessionResponse, error) {
|
||||
sessionID, err := uuid.Parse(req.SessionId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid session ID")
|
||||
}
|
||||
|
||||
err = s.closeSessionUC.Execute(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err)
|
||||
}
|
||||
|
||||
return &CloseSessionResponse{
|
||||
Success: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toGRPCError converts domain errors to gRPC errors
|
||||
func toGRPCError(err error) error {
|
||||
switch err {
|
||||
case entities.ErrSessionExpired:
|
||||
return status.Error(codes.DeadlineExceeded, err.Error())
|
||||
case entities.ErrSessionFull:
|
||||
return status.Error(codes.ResourceExhausted, err.Error())
|
||||
case entities.ErrParticipantNotFound:
|
||||
return status.Error(codes.NotFound, err.Error())
|
||||
case entities.ErrSessionNotInProgress:
|
||||
return status.Error(codes.FailedPrecondition, err.Error())
|
||||
case entities.ErrInvalidSessionType:
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
default:
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Request/Response types (normally generated from proto)
|
||||
// These are simplified versions - actual implementation would use generated proto types
|
||||
|
||||
type CreateSessionRequest struct {
|
||||
SessionType string
|
||||
ThresholdN int32
|
||||
ThresholdT int32
|
||||
Participants []*ParticipantInfoProto
|
||||
MessageHash []byte
|
||||
ExpiresInSeconds int64
|
||||
}
|
||||
|
||||
type CreateSessionResponse struct {
|
||||
SessionId string
|
||||
JoinTokens map[string]string
|
||||
ExpiresAt int64
|
||||
}
|
||||
|
||||
type ParticipantInfoProto struct {
|
||||
PartyId string
|
||||
DeviceInfo *DeviceInfo
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
DeviceType string
|
||||
DeviceId string
|
||||
Platform string
|
||||
AppVersion string
|
||||
}
|
||||
|
||||
type JoinSessionRequest struct {
|
||||
SessionId string
|
||||
PartyId string
|
||||
JoinToken string
|
||||
DeviceInfo *DeviceInfo
|
||||
}
|
||||
|
||||
type JoinSessionResponse struct {
|
||||
Success bool
|
||||
SessionInfo *SessionInfo
|
||||
OtherParties []*PartyInfo
|
||||
}
|
||||
|
||||
type SessionInfo struct {
|
||||
SessionId string
|
||||
SessionType string
|
||||
ThresholdN int32
|
||||
ThresholdT int32
|
||||
MessageHash []byte
|
||||
Status string
|
||||
}
|
||||
|
||||
type PartyInfo struct {
|
||||
PartyId string
|
||||
PartyIndex int32
|
||||
DeviceInfo *DeviceInfo
|
||||
}
|
||||
|
||||
type GetSessionStatusRequest struct {
|
||||
SessionId string
|
||||
}
|
||||
|
||||
type GetSessionStatusResponse struct {
|
||||
Status string
|
||||
CompletedParties int32
|
||||
TotalParties int32
|
||||
PublicKey []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type ReportCompletionRequest struct {
|
||||
SessionId string
|
||||
PartyId string
|
||||
PublicKey []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type ReportCompletionResponse struct {
|
||||
Success bool
|
||||
AllCompleted bool
|
||||
}
|
||||
|
||||
type CloseSessionRequest struct {
|
||||
SessionId string
|
||||
}
|
||||
|
||||
type CloseSessionResponse struct {
|
||||
Success bool
|
||||
}
|
||||
|
|
@ -0,0 +1,543 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionHTTPHandler handles HTTP requests for session management
|
||||
type SessionHTTPHandler struct {
|
||||
createSessionUC *use_cases.CreateSessionUseCase
|
||||
joinSessionUC *use_cases.JoinSessionUseCase
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase
|
||||
closeSessionUC *use_cases.CloseSessionUseCase
|
||||
sessionRepo repositories.SessionRepository
|
||||
}
|
||||
|
||||
// NewSessionHTTPHandler creates a new HTTP handler
|
||||
func NewSessionHTTPHandler(
|
||||
createSessionUC *use_cases.CreateSessionUseCase,
|
||||
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||
sessionRepo repositories.SessionRepository,
|
||||
) *SessionHTTPHandler {
|
||||
return &SessionHTTPHandler{
|
||||
createSessionUC: createSessionUC,
|
||||
joinSessionUC: joinSessionUC,
|
||||
getSessionStatusUC: getSessionStatusUC,
|
||||
reportCompletionUC: reportCompletionUC,
|
||||
closeSessionUC: closeSessionUC,
|
||||
sessionRepo: sessionRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers HTTP routes
|
||||
func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) {
|
||||
sessions := r.Group("/sessions")
|
||||
{
|
||||
sessions.POST("", h.CreateSession)
|
||||
sessions.POST("/join", h.JoinSessionByToken)
|
||||
sessions.POST("/:id/join", h.JoinSession)
|
||||
sessions.GET("/:id", h.GetSessionStatus)
|
||||
sessions.GET("/:id/status", h.GetSessionStatus)
|
||||
sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady)
|
||||
sessions.POST("/:id/start", h.StartSession)
|
||||
sessions.POST("/:id/complete", h.ReportCompletion)
|
||||
sessions.DELETE("/:id", h.CloseSession)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSessionRequest is the HTTP request body for creating a session
|
||||
type CreateSessionRequest struct {
|
||||
SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"`
|
||||
ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"`
|
||||
ThresholdT int `json:"thresholdT" binding:"required,min=1"`
|
||||
CreatedBy string `json:"createdBy" binding:"required"`
|
||||
Participants []ParticipantInfoRequest `json:"participants,omitempty"`
|
||||
MessageHash string `json:"messageHash,omitempty"`
|
||||
ExpiresIn int64 `json:"expiresIn,omitempty"`
|
||||
}
|
||||
|
||||
// ParticipantInfoRequest represents a participant in the request
|
||||
type ParticipantInfoRequest struct {
|
||||
PartyID string `json:"party_id" binding:"required"`
|
||||
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
|
||||
}
|
||||
|
||||
// DeviceInfoRequest represents device info in the request
|
||||
type DeviceInfoRequest struct {
|
||||
DeviceType string `json:"device_type" binding:"required"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
AppVersion string `json:"app_version,omitempty"`
|
||||
}
|
||||
|
||||
// CreateSessionResponse is the HTTP response for creating a session
|
||||
type CreateSessionResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
JoinToken string `json:"joinToken"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt int64 `json:"expiresAt,omitempty"`
|
||||
}
|
||||
|
||||
// CreateSession handles POST /sessions
|
||||
func (h *SessionHTTPHandler) CreateSession(c *gin.Context) {
|
||||
var req CreateSessionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate threshold
|
||||
if req.ThresholdT > req.ThresholdN {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert request to input
|
||||
participants := make([]input.ParticipantInfo, len(req.Participants))
|
||||
for i, p := range req.Participants {
|
||||
participants[i] = input.ParticipantInfo{
|
||||
PartyID: p.PartyID,
|
||||
DeviceInfo: entities.DeviceInfo{
|
||||
DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType),
|
||||
DeviceID: p.DeviceInfo.DeviceID,
|
||||
Platform: p.DeviceInfo.Platform,
|
||||
AppVersion: p.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var messageHash []byte
|
||||
if req.MessageHash != "" {
|
||||
messageHash = []byte(req.MessageHash)
|
||||
}
|
||||
|
||||
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||
if expiresIn == 0 {
|
||||
expiresIn = 10 * time.Minute // Default
|
||||
}
|
||||
|
||||
inputData := input.CreateSessionInput{
|
||||
InitiatorID: req.CreatedBy,
|
||||
SessionType: req.SessionType,
|
||||
ThresholdN: req.ThresholdN,
|
||||
ThresholdT: req.ThresholdT,
|
||||
Participants: participants,
|
||||
MessageHash: messageHash,
|
||||
ExpiresIn: expiresIn,
|
||||
}
|
||||
|
||||
output, err := h.createSessionUC.Execute(c.Request.Context(), inputData)
|
||||
if err != nil {
|
||||
logger.Error("failed to create session", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract a single join token (for E2E compatibility)
|
||||
// In a real scenario with pre-registered participants, we'd return all tokens
|
||||
// For now, generate a universal join token
|
||||
joinToken := output.SessionID.String() // Use session ID as join token for simplicity
|
||||
if len(output.JoinTokens) > 0 {
|
||||
// If there are participant-specific tokens, use the first one
|
||||
for _, token := range output.JoinTokens {
|
||||
joinToken = token
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, CreateSessionResponse{
|
||||
SessionID: output.SessionID.String(),
|
||||
JoinToken: joinToken,
|
||||
Status: "created",
|
||||
ExpiresAt: output.ExpiresAt.UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
// JoinSessionRequest is the HTTP request body for joining a session
|
||||
type JoinSessionRequest struct {
|
||||
PartyID string `json:"party_id" binding:"required"`
|
||||
JoinToken string `json:"join_token" binding:"required"`
|
||||
DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"`
|
||||
}
|
||||
|
||||
// JoinSessionResponse is the HTTP response for joining a session
|
||||
type JoinSessionResponse struct {
|
||||
Success bool `json:"success"`
|
||||
SessionInfo SessionInfoDTO `json:"session_info"`
|
||||
OtherParties []PartyInfoDTO `json:"other_parties"`
|
||||
}
|
||||
|
||||
// SessionInfoDTO represents session info in responses
|
||||
type SessionInfoDTO struct {
|
||||
SessionID string `json:"session_id"`
|
||||
SessionType string `json:"session_type"`
|
||||
ThresholdN int `json:"threshold_n"`
|
||||
ThresholdT int `json:"threshold_t"`
|
||||
MessageHash string `json:"message_hash,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// PartyInfoDTO represents party info in responses
|
||||
type PartyInfoDTO struct {
|
||||
PartyID string `json:"party_id"`
|
||||
PartyIndex int `json:"party_index"`
|
||||
DeviceInfo DeviceInfoRequest `json:"device_info"`
|
||||
}
|
||||
|
||||
// JoinSessionByTokenRequest is the HTTP request body for joining by token
|
||||
type JoinSessionByTokenRequest struct {
|
||||
JoinToken string `json:"joinToken" binding:"required"`
|
||||
PartyID string `json:"partyId" binding:"required"`
|
||||
DeviceType string `json:"deviceType" binding:"required"`
|
||||
DeviceID string `json:"deviceId,omitempty"`
|
||||
}
|
||||
|
||||
// JoinSessionByTokenResponse is the HTTP response for joining by token
|
||||
type JoinSessionByTokenResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
PartyIndex int `json:"partyIndex"`
|
||||
Status string `json:"status"`
|
||||
Participants []ParticipantStatusDTO `json:"participants"`
|
||||
}
|
||||
|
||||
// ParticipantStatusDTO represents participant status in responses
|
||||
type ParticipantStatusDTO struct {
|
||||
PartyID string `json:"partyId"`
|
||||
PartyIndex int `json:"partyIndex"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// JoinSessionByToken handles POST /sessions/join (join by token without session ID)
|
||||
func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) {
|
||||
var req JoinSessionByTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error("failed to bind join session request", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Pass empty UUID - the use case will extract session ID from the JWT
|
||||
inputData := input.JoinSessionInput{
|
||||
SessionID: uuid.Nil,
|
||||
PartyID: req.PartyID,
|
||||
JoinToken: req.JoinToken,
|
||||
DeviceInfo: entities.DeviceInfo{
|
||||
DeviceType: entities.DeviceType(req.DeviceType),
|
||||
DeviceID: req.DeviceID,
|
||||
},
|
||||
}
|
||||
|
||||
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
|
||||
if err != nil {
|
||||
logger.Error("failed to join session", zap.Error(err))
|
||||
// Return 401 for authentication/token errors
|
||||
if err.Error() == "invalid token" || err.Error() == "token expired" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build participant status list
|
||||
myPartyIndex := output.PartyIndex
|
||||
|
||||
participants := make([]ParticipantStatusDTO, 0)
|
||||
participants = append(participants, ParticipantStatusDTO{
|
||||
PartyID: req.PartyID,
|
||||
Status: "joined",
|
||||
})
|
||||
for _, p := range output.OtherParties {
|
||||
participants = append(participants, ParticipantStatusDTO{
|
||||
PartyID: p.PartyID,
|
||||
Status: "joined",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, JoinSessionByTokenResponse{
|
||||
SessionID: output.SessionInfo.SessionID.String(),
|
||||
PartyIndex: myPartyIndex,
|
||||
Status: output.SessionInfo.Status,
|
||||
Participants: participants,
|
||||
})
|
||||
}
|
||||
|
||||
// JoinSession handles POST /sessions/:id/join
|
||||
func (h *SessionHTTPHandler) JoinSession(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req JoinSessionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
inputData := input.JoinSessionInput{
|
||||
SessionID: sessionID,
|
||||
PartyID: req.PartyID,
|
||||
JoinToken: req.JoinToken,
|
||||
DeviceInfo: entities.DeviceInfo{
|
||||
DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType),
|
||||
DeviceID: req.DeviceInfo.DeviceID,
|
||||
Platform: req.DeviceInfo.Platform,
|
||||
AppVersion: req.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
|
||||
output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData)
|
||||
if err != nil {
|
||||
logger.Error("failed to join session", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
otherParties := make([]PartyInfoDTO, len(output.OtherParties))
|
||||
for i, p := range output.OtherParties {
|
||||
otherParties[i] = PartyInfoDTO{
|
||||
PartyID: p.PartyID,
|
||||
PartyIndex: p.PartyIndex,
|
||||
DeviceInfo: DeviceInfoRequest{
|
||||
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||
DeviceID: p.DeviceInfo.DeviceID,
|
||||
Platform: p.DeviceInfo.Platform,
|
||||
AppVersion: p.DeviceInfo.AppVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, JoinSessionResponse{
|
||||
Success: output.Success,
|
||||
SessionInfo: SessionInfoDTO{
|
||||
SessionID: output.SessionInfo.SessionID.String(),
|
||||
SessionType: output.SessionInfo.SessionType,
|
||||
ThresholdN: output.SessionInfo.ThresholdN,
|
||||
ThresholdT: output.SessionInfo.ThresholdT,
|
||||
MessageHash: string(output.SessionInfo.MessageHash),
|
||||
Status: output.SessionInfo.Status,
|
||||
},
|
||||
OtherParties: otherParties,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionStatusResponse is the HTTP response for session status
|
||||
type SessionStatusResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Status string `json:"status"`
|
||||
ThresholdT int `json:"thresholdT"`
|
||||
ThresholdN int `json:"thresholdN"`
|
||||
Participants []ParticipantStatusDTO `json:"participants"`
|
||||
PublicKey string `json:"publicKey,omitempty"`
|
||||
}
|
||||
|
||||
// GetSessionStatus handles GET /sessions/:id/status
|
||||
func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID)
|
||||
if err != nil {
|
||||
logger.Error("failed to get session status", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert participants to DTO
|
||||
participants := make([]ParticipantStatusDTO, len(output.Participants))
|
||||
for i, p := range output.Participants {
|
||||
participants[i] = ParticipantStatusDTO{
|
||||
PartyID: p.PartyID,
|
||||
PartyIndex: p.PartyIndex,
|
||||
Status: p.Status,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SessionStatusResponse{
|
||||
SessionID: output.SessionID.String(),
|
||||
Status: output.Status,
|
||||
ThresholdT: output.ThresholdT,
|
||||
ThresholdN: output.ThresholdN,
|
||||
Participants: participants,
|
||||
PublicKey: string(output.PublicKey),
|
||||
})
|
||||
}
|
||||
|
||||
// ReportCompletionRequest is the HTTP request for reporting completion
|
||||
type ReportCompletionRequest struct {
|
||||
PartyID string `json:"party_id" binding:"required"`
|
||||
PublicKey string `json:"public_key,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ReportCompletion handles POST /sessions/:id/complete
|
||||
func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req ReportCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
inputData := input.ReportCompletionInput{
|
||||
SessionID: sessionID,
|
||||
PartyID: req.PartyID,
|
||||
PublicKey: []byte(req.PublicKey),
|
||||
Signature: []byte(req.Signature),
|
||||
}
|
||||
|
||||
output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData)
|
||||
if err != nil {
|
||||
logger.Error("failed to report completion", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": output.Success,
|
||||
"all_completed": output.AllCompleted,
|
||||
})
|
||||
}
|
||||
|
||||
// CloseSession handles DELETE /sessions/:id
|
||||
func (h *SessionHTTPHandler) CloseSession(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.closeSessionUC.Execute(c.Request.Context(), sessionID)
|
||||
if err != nil {
|
||||
logger.Error("failed to close session", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready
|
||||
func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
partyID := c.Param("partyId")
|
||||
if partyID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID))
|
||||
|
||||
// Load session
|
||||
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
|
||||
if err != nil {
|
||||
if err == entities.ErrSessionNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
logger.Error("failed to load session", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create party ID value object
|
||||
partyIDVO, err := value_objects.NewPartyID(partyID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update participant status to ready
|
||||
if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil {
|
||||
logger.Error("failed to mark party as ready", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Save session
|
||||
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
|
||||
logger.Error("failed to save session", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// StartSession handles POST /sessions/:id/start
|
||||
func (h *SessionHTTPHandler) StartSession(c *gin.Context) {
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("starting session", zap.String("session_id", sessionID.String()))
|
||||
|
||||
// Load session
|
||||
session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID)
|
||||
if err != nil {
|
||||
if err == entities.ErrSessionNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
logger.Error("failed to load session", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"})
|
||||
return
|
||||
}
|
||||
|
||||
// Start the session
|
||||
if err := session.Start(); err != nil {
|
||||
logger.Error("failed to start session", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Save session
|
||||
if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil {
|
||||
logger.Error("failed to save session", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// HealthCheck handles GET /health
|
||||
func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"service": "session-coordinator",
|
||||
"time": time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/lib/pq"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
||||
"github.com/rwadurian/mpc-system/pkg/config"
|
||||
"github.com/rwadurian/mpc-system/pkg/jwt"
|
||||
"github.com/rwadurian/mpc-system/pkg/logger"
|
||||
grpcadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/grpc"
|
||||
httphandler "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/input/http"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/postgres"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/rabbitmq"
|
||||
redisadapter "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/redis"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Parse flags
|
||||
configPath := flag.String("config", "", "Path to config file")
|
||||
flag.Parse()
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := logger.Init(&logger.Config{
|
||||
Level: cfg.Logger.Level,
|
||||
Encoding: cfg.Logger.Encoding,
|
||||
}); err != nil {
|
||||
fmt.Printf("Failed to initialize logger: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
logger.Info("Starting Session Coordinator Service",
|
||||
zap.String("environment", cfg.Server.Environment),
|
||||
zap.Int("grpc_port", cfg.Server.GRPCPort),
|
||||
zap.Int("http_port", cfg.Server.HTTPPort))
|
||||
|
||||
// Initialize database connection
|
||||
db, err := initDatabase(cfg.Database)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to connect to database", zap.Error(err))
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Initialize Redis connection
|
||||
redisClient := initRedis(cfg.Redis)
|
||||
defer redisClient.Close()
|
||||
|
||||
// Initialize RabbitMQ connection
|
||||
rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
|
||||
}
|
||||
defer rabbitConn.Close()
|
||||
|
||||
// Initialize repositories and adapters
|
||||
sessionRepo := postgres.NewSessionPostgresRepo(db)
|
||||
messageRepo := postgres.NewMessagePostgresRepo(db)
|
||||
sessionCache := redisadapter.NewSessionCacheAdapter(redisClient)
|
||||
eventPublisher, err := rabbitmq.NewEventPublisherAdapter(rabbitConn)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create event publisher", zap.Error(err))
|
||||
}
|
||||
defer eventPublisher.Close()
|
||||
|
||||
// Initialize JWT service
|
||||
jwtService := jwt.NewJWTService(
|
||||
cfg.JWT.SecretKey,
|
||||
cfg.JWT.Issuer,
|
||||
cfg.JWT.TokenExpiry,
|
||||
cfg.JWT.RefreshExpiry,
|
||||
)
|
||||
|
||||
// Initialize use cases
|
||||
createSessionUC := use_cases.NewCreateSessionUseCase(sessionRepo, jwtService, eventPublisher)
|
||||
joinSessionUC := use_cases.NewJoinSessionUseCase(sessionRepo, jwtService, eventPublisher)
|
||||
getSessionStatusUC := use_cases.NewGetSessionStatusUseCase(sessionRepo)
|
||||
reportCompletionUC := use_cases.NewReportCompletionUseCase(sessionRepo, eventPublisher)
|
||||
closeSessionUC := use_cases.NewCloseSessionUseCase(sessionRepo, messageRepo, eventPublisher)
|
||||
expireSessionsUC := use_cases.NewExpireSessionsUseCase(sessionRepo, eventPublisher)
|
||||
|
||||
// Start session expiration background job
|
||||
go runSessionExpiration(expireSessionsUC)
|
||||
|
||||
// Create shutdown context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start servers
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
// Start gRPC server
|
||||
go func() {
|
||||
if err := startGRPCServer(
|
||||
cfg,
|
||||
createSessionUC,
|
||||
joinSessionUC,
|
||||
getSessionStatusUC,
|
||||
reportCompletionUC,
|
||||
closeSessionUC,
|
||||
); err != nil {
|
||||
errChan <- fmt.Errorf("gRPC server error: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start HTTP server
|
||||
go func() {
|
||||
if err := startHTTPServer(
|
||||
cfg,
|
||||
createSessionUC,
|
||||
joinSessionUC,
|
||||
getSessionStatusUC,
|
||||
reportCompletionUC,
|
||||
closeSessionUC,
|
||||
sessionRepo,
|
||||
); err != nil {
|
||||
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for shutdown signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case sig := <-sigChan:
|
||||
logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
||||
case err := <-errChan:
|
||||
logger.Error("Server error", zap.Error(err))
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
logger.Info("Shutting down...")
|
||||
cancel()
|
||||
|
||||
// Give services time to shutdown gracefully
|
||||
time.Sleep(5 * time.Second)
|
||||
logger.Info("Shutdown complete")
|
||||
|
||||
_ = ctx
|
||||
_ = sessionCache
|
||||
}
|
||||
|
||||
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", cfg.DSN())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Connected to PostgreSQL")
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initRedis(cfg config.RedisConfig) *redis.Client {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr(),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx := context.Background()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
logger.Warn("Redis connection failed, continuing without cache", zap.Error(err))
|
||||
} else {
|
||||
logger.Info("Connected to Redis")
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
|
||||
conn, err := amqp.Dial(cfg.URL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Connected to RabbitMQ")
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func startGRPCServer(
|
||||
cfg *config.Config,
|
||||
createSessionUC *use_cases.CreateSessionUseCase,
|
||||
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||
) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
|
||||
// Register services (using our custom handler, not generated proto)
|
||||
// In production, you would register the generated proto service
|
||||
_ = grpcadapter.NewSessionCoordinatorServer(
|
||||
createSessionUC,
|
||||
joinSessionUC,
|
||||
getSessionStatusUC,
|
||||
reportCompletionUC,
|
||||
closeSessionUC,
|
||||
)
|
||||
|
||||
// Enable reflection for debugging
|
||||
reflection.Register(grpcServer)
|
||||
|
||||
logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
|
||||
return grpcServer.Serve(listener)
|
||||
}
|
||||
|
||||
func startHTTPServer(
|
||||
cfg *config.Config,
|
||||
createSessionUC *use_cases.CreateSessionUseCase,
|
||||
joinSessionUC *use_cases.JoinSessionUseCase,
|
||||
getSessionStatusUC *use_cases.GetSessionStatusUseCase,
|
||||
reportCompletionUC *use_cases.ReportCompletionUseCase,
|
||||
closeSessionUC *use_cases.CloseSessionUseCase,
|
||||
sessionRepo repositories.SessionRepository,
|
||||
) error {
|
||||
// Set Gin mode
|
||||
if cfg.Server.Environment == "production" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(gin.Recovery())
|
||||
router.Use(gin.Logger())
|
||||
|
||||
// Create HTTP handler
|
||||
httpHandler := httphandler.NewSessionHTTPHandler(
|
||||
createSessionUC,
|
||||
joinSessionUC,
|
||||
getSessionStatusUC,
|
||||
reportCompletionUC,
|
||||
closeSessionUC,
|
||||
sessionRepo,
|
||||
)
|
||||
|
||||
// Health check
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"service": "session-coordinator",
|
||||
})
|
||||
})
|
||||
|
||||
// Register API routes
|
||||
api := router.Group("/api/v1")
|
||||
httpHandler.RegisterRoutes(api)
|
||||
|
||||
logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
|
||||
return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
|
||||
}
|
||||
|
||||
func runSessionExpiration(expireSessionsUC *use_cases.ExpireSessionsUseCase) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
count, err := expireSessionsUC.Execute(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to expire sessions", zap.Error(err))
|
||||
} else if count > 0 {
|
||||
logger.Info("Expired stale sessions", zap.Int("count", count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
package entities
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionFull = errors.New("session is full")
|
||||
ErrSessionExpired = errors.New("session expired")
|
||||
ErrSessionNotInProgress = errors.New("session not in progress")
|
||||
ErrParticipantNotFound = errors.New("participant not found")
|
||||
ErrInvalidSessionType = errors.New("invalid session type")
|
||||
ErrInvalidStatusTransition = errors.New("invalid status transition")
|
||||
)
|
||||
|
||||
// SessionType represents the type of MPC session
|
||||
type SessionType string
|
||||
|
||||
const (
|
||||
SessionTypeKeygen SessionType = "keygen"
|
||||
SessionTypeSign SessionType = "sign"
|
||||
)
|
||||
|
||||
// IsValid checks if the session type is valid
|
||||
func (t SessionType) IsValid() bool {
|
||||
return t == SessionTypeKeygen || t == SessionTypeSign
|
||||
}
|
||||
|
||||
// MPCSession represents an MPC session
|
||||
// Coordinator only manages session metadata, does not participate in MPC computation
|
||||
type MPCSession struct {
|
||||
ID value_objects.SessionID
|
||||
SessionType SessionType
|
||||
Threshold value_objects.Threshold
|
||||
Participants []*Participant
|
||||
Status value_objects.SessionStatus
|
||||
MessageHash []byte // Used for Sign sessions
|
||||
PublicKey []byte // Group public key after Keygen completion
|
||||
CreatedBy string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
CompletedAt *time.Time
|
||||
}
|
||||
|
||||
// NewMPCSession creates a new MPC session
|
||||
func NewMPCSession(
|
||||
sessionType SessionType,
|
||||
threshold value_objects.Threshold,
|
||||
createdBy string,
|
||||
expiresIn time.Duration,
|
||||
messageHash []byte, // Only for Sign sessions
|
||||
) (*MPCSession, error) {
|
||||
if !sessionType.IsValid() {
|
||||
return nil, ErrInvalidSessionType
|
||||
}
|
||||
|
||||
if sessionType == SessionTypeSign && len(messageHash) == 0 {
|
||||
return nil, errors.New("message hash required for sign session")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
return &MPCSession{
|
||||
ID: value_objects.NewSessionID(),
|
||||
SessionType: sessionType,
|
||||
Threshold: threshold,
|
||||
Participants: make([]*Participant, 0, threshold.N()),
|
||||
Status: value_objects.SessionStatusCreated,
|
||||
MessageHash: messageHash,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ExpiresAt: now.Add(expiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddParticipant adds a participant to the session
|
||||
func (s *MPCSession) AddParticipant(p *Participant) error {
|
||||
if len(s.Participants) >= s.Threshold.N() {
|
||||
return ErrSessionFull
|
||||
}
|
||||
s.Participants = append(s.Participants, p)
|
||||
s.UpdatedAt = time.Now().UTC()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetParticipant gets a participant by party ID
|
||||
func (s *MPCSession) GetParticipant(partyID value_objects.PartyID) (*Participant, error) {
|
||||
for _, p := range s.Participants {
|
||||
if p.PartyID.Equals(partyID) {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrParticipantNotFound
|
||||
}
|
||||
|
||||
// UpdateParticipantStatus updates a participant's status
|
||||
func (s *MPCSession) UpdateParticipantStatus(partyID value_objects.PartyID, status value_objects.ParticipantStatus) error {
|
||||
for _, p := range s.Participants {
|
||||
if p.PartyID.Equals(partyID) {
|
||||
switch status {
|
||||
case value_objects.ParticipantStatusJoined:
|
||||
return p.Join()
|
||||
case value_objects.ParticipantStatusReady:
|
||||
return p.MarkReady()
|
||||
case value_objects.ParticipantStatusCompleted:
|
||||
return p.MarkCompleted()
|
||||
case value_objects.ParticipantStatusFailed:
|
||||
p.MarkFailed()
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid status")
|
||||
}
|
||||
}
|
||||
}
|
||||
return ErrParticipantNotFound
|
||||
}
|
||||
|
||||
// CanStart checks if all participants have joined and the session can start
|
||||
func (s *MPCSession) CanStart() bool {
|
||||
if len(s.Participants) != s.Threshold.N() {
|
||||
return false
|
||||
}
|
||||
|
||||
readyCount := 0
|
||||
for _, p := range s.Participants {
|
||||
// Accept participants in either joined or ready status
|
||||
if p.IsJoined() || p.IsReady() {
|
||||
readyCount++
|
||||
}
|
||||
}
|
||||
return readyCount == s.Threshold.N()
|
||||
}
|
||||
|
||||
// Start transitions the session to in_progress
|
||||
func (s *MPCSession) Start() error {
|
||||
// If already in progress, just return success (idempotent)
|
||||
if s.Status == value_objects.SessionStatusInProgress {
|
||||
return nil
|
||||
}
|
||||
if !s.Status.CanTransitionTo(value_objects.SessionStatusInProgress) {
|
||||
return ErrInvalidStatusTransition
|
||||
}
|
||||
if !s.CanStart() {
|
||||
return errors.New("not all participants have joined")
|
||||
}
|
||||
s.Status = value_objects.SessionStatusInProgress
|
||||
s.UpdatedAt = time.Now().UTC()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Complete marks the session as completed
|
||||
func (s *MPCSession) Complete(publicKey []byte) error {
|
||||
if !s.Status.CanTransitionTo(value_objects.SessionStatusCompleted) {
|
||||
return ErrInvalidStatusTransition
|
||||
}
|
||||
s.Status = value_objects.SessionStatusCompleted
|
||||
s.PublicKey = publicKey
|
||||
now := time.Now().UTC()
|
||||
s.CompletedAt = &now
|
||||
s.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fail marks the session as failed
|
||||
func (s *MPCSession) Fail() error {
|
||||
if !s.Status.CanTransitionTo(value_objects.SessionStatusFailed) {
|
||||
return ErrInvalidStatusTransition
|
||||
}
|
||||
s.Status = value_objects.SessionStatusFailed
|
||||
s.UpdatedAt = time.Now().UTC()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Expire marks the session as expired
|
||||
func (s *MPCSession) Expire() error {
|
||||
if !s.Status.CanTransitionTo(value_objects.SessionStatusExpired) {
|
||||
return ErrInvalidStatusTransition
|
||||
}
|
||||
s.Status = value_objects.SessionStatusExpired
|
||||
s.UpdatedAt = time.Now().UTC()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *MPCSession) IsExpired() bool {
|
||||
return time.Now().UTC().After(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsActive checks if the session is active
|
||||
func (s *MPCSession) IsActive() bool {
|
||||
return s.Status.IsActive() && !s.IsExpired()
|
||||
}
|
||||
|
||||
// IsParticipant checks if a party is a participant
|
||||
func (s *MPCSession) IsParticipant(partyID value_objects.PartyID) bool {
|
||||
for _, p := range s.Participants {
|
||||
if p.PartyID.Equals(partyID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AllCompleted checks if all participants have completed
|
||||
func (s *MPCSession) AllCompleted() bool {
|
||||
for _, p := range s.Participants {
|
||||
if !p.IsCompleted() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CompletedCount returns the number of completed participants
|
||||
func (s *MPCSession) CompletedCount() int {
|
||||
count := 0
|
||||
for _, p := range s.Participants {
|
||||
if p.IsCompleted() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// JoinedCount returns the number of joined participants
|
||||
func (s *MPCSession) JoinedCount() int {
|
||||
count := 0
|
||||
for _, p := range s.Participants {
|
||||
if p.IsJoined() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GetPartyIDs returns all party IDs
|
||||
func (s *MPCSession) GetPartyIDs() []string {
|
||||
ids := make([]string, len(s.Participants))
|
||||
for i, p := range s.Participants {
|
||||
ids[i] = p.PartyID.String()
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetOtherParties returns participants except the specified party
|
||||
func (s *MPCSession) GetOtherParties(excludePartyID value_objects.PartyID) []*Participant {
|
||||
others := make([]*Participant, 0, len(s.Participants)-1)
|
||||
for _, p := range s.Participants {
|
||||
if !p.PartyID.Equals(excludePartyID) {
|
||||
others = append(others, p)
|
||||
}
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
// ToDTO converts to a DTO for API responses
|
||||
func (s *MPCSession) ToDTO() SessionDTO {
|
||||
participants := make([]ParticipantDTO, len(s.Participants))
|
||||
for i, p := range s.Participants {
|
||||
participants[i] = ParticipantDTO{
|
||||
PartyID: p.PartyID.String(),
|
||||
PartyIndex: p.PartyIndex,
|
||||
Status: p.Status.String(),
|
||||
DeviceType: string(p.DeviceInfo.DeviceType),
|
||||
}
|
||||
}
|
||||
|
||||
return SessionDTO{
|
||||
ID: s.ID.String(),
|
||||
SessionType: string(s.SessionType),
|
||||
ThresholdN: s.Threshold.N(),
|
||||
ThresholdT: s.Threshold.T(),
|
||||
Participants: participants,
|
||||
Status: s.Status.String(),
|
||||
CreatedAt: s.CreatedAt,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
// SessionDTO is a data transfer object for sessions
|
||||
type SessionDTO struct {
|
||||
ID string `json:"id"`
|
||||
SessionType string `json:"session_type"`
|
||||
ThresholdN int `json:"threshold_n"`
|
||||
ThresholdT int `json:"threshold_t"`
|
||||
Participants []ParticipantDTO `json:"participants"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// ParticipantDTO is a data transfer object for participants
|
||||
type ParticipantDTO struct {
|
||||
PartyID string `json:"party_id"`
|
||||
PartyIndex int `json:"party_index"`
|
||||
Status string `json:"status"`
|
||||
DeviceType string `json:"device_type"`
|
||||
}
|
||||
|
||||
// Reconstruct reconstructs an MPCSession from database
|
||||
func ReconstructSession(
|
||||
id uuid.UUID,
|
||||
sessionType string,
|
||||
thresholdT, thresholdN int,
|
||||
status string,
|
||||
messageHash, publicKey []byte,
|
||||
createdBy string,
|
||||
createdAt, updatedAt, expiresAt time.Time,
|
||||
completedAt *time.Time,
|
||||
participants []*Participant,
|
||||
) (*MPCSession, error) {
|
||||
sessionStatus, err := value_objects.NewSessionStatus(status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
threshold, err := value_objects.NewThreshold(thresholdT, thresholdN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MPCSession{
|
||||
ID: value_objects.SessionIDFromUUID(id),
|
||||
SessionType: SessionType(sessionType),
|
||||
Threshold: threshold,
|
||||
Participants: participants,
|
||||
Status: sessionStatus,
|
||||
MessageHash: messageHash,
|
||||
PublicKey: publicKey,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
CompletedAt: completedAt,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,354 @@
|
|||
//go:build e2e
|
||||
|
||||
package e2e_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type KeygenFlowTestSuite struct {
|
||||
suite.Suite
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func TestKeygenFlowSuite(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping e2e test in short mode")
|
||||
}
|
||||
suite.Run(t, new(KeygenFlowTestSuite))
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) SetupSuite() {
|
||||
s.baseURL = os.Getenv("SESSION_COORDINATOR_URL")
|
||||
if s.baseURL == "" {
|
||||
s.baseURL = "http://localhost:8080"
|
||||
}
|
||||
|
||||
s.client = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Wait for service to be ready
|
||||
s.waitForService()
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) waitForService() {
|
||||
maxRetries := 30
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
resp, err := s.client.Get(s.baseURL + "/health")
|
||||
if err == nil && resp.StatusCode == http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
s.T().Fatal("Service not ready after waiting")
|
||||
}
|
||||
|
||||
type CreateSessionRequest struct {
|
||||
SessionType string `json:"sessionType"`
|
||||
ThresholdT int `json:"thresholdT"`
|
||||
ThresholdN int `json:"thresholdN"`
|
||||
CreatedBy string `json:"createdBy"`
|
||||
}
|
||||
|
||||
type CreateSessionResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
JoinToken string `json:"joinToken"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type JoinSessionRequest struct {
|
||||
JoinToken string `json:"joinToken"`
|
||||
PartyID string `json:"partyId"`
|
||||
DeviceType string `json:"deviceType"`
|
||||
DeviceID string `json:"deviceId"`
|
||||
}
|
||||
|
||||
type JoinSessionResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
PartyIndex int `json:"partyIndex"`
|
||||
Status string `json:"status"`
|
||||
Participants []struct {
|
||||
PartyID string `json:"partyId"`
|
||||
Status string `json:"status"`
|
||||
} `json:"participants"`
|
||||
}
|
||||
|
||||
type SessionStatusResponse struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Status string `json:"status"`
|
||||
ThresholdT int `json:"thresholdT"`
|
||||
ThresholdN int `json:"thresholdN"`
|
||||
Participants []struct {
|
||||
PartyID string `json:"partyId"`
|
||||
PartyIndex int `json:"partyIndex"`
|
||||
Status string `json:"status"`
|
||||
} `json:"participants"`
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) TestCompleteKeygenFlow() {
|
||||
// Step 1: Create a new keygen session
|
||||
createReq := CreateSessionRequest{
|
||||
SessionType: "keygen",
|
||||
ThresholdT: 2,
|
||||
ThresholdN: 3,
|
||||
CreatedBy: "e2e_test_user",
|
||||
}
|
||||
|
||||
createResp := s.createSession(createReq)
|
||||
require.NotEmpty(s.T(), createResp.SessionID)
|
||||
require.NotEmpty(s.T(), createResp.JoinToken)
|
||||
assert.Equal(s.T(), "created", createResp.Status)
|
||||
|
||||
sessionID := createResp.SessionID
|
||||
joinToken := createResp.JoinToken
|
||||
|
||||
// Step 2: First party joins
|
||||
joinReq1 := JoinSessionRequest{
|
||||
JoinToken: joinToken,
|
||||
PartyID: "party_user_device",
|
||||
DeviceType: "iOS",
|
||||
DeviceID: "device_001",
|
||||
}
|
||||
joinResp1 := s.joinSession(joinReq1)
|
||||
assert.Equal(s.T(), sessionID, joinResp1.SessionID)
|
||||
assert.Equal(s.T(), 0, joinResp1.PartyIndex)
|
||||
|
||||
// Step 3: Second party joins
|
||||
joinReq2 := JoinSessionRequest{
|
||||
JoinToken: joinToken,
|
||||
PartyID: "party_server",
|
||||
DeviceType: "server",
|
||||
DeviceID: "server_001",
|
||||
}
|
||||
joinResp2 := s.joinSession(joinReq2)
|
||||
assert.Equal(s.T(), sessionID, joinResp2.SessionID)
|
||||
assert.Equal(s.T(), 1, joinResp2.PartyIndex)
|
||||
|
||||
// Step 4: Third party joins
|
||||
joinReq3 := JoinSessionRequest{
|
||||
JoinToken: joinToken,
|
||||
PartyID: "party_recovery",
|
||||
DeviceType: "recovery",
|
||||
DeviceID: "recovery_001",
|
||||
}
|
||||
joinResp3 := s.joinSession(joinReq3)
|
||||
assert.Equal(s.T(), sessionID, joinResp3.SessionID)
|
||||
assert.Equal(s.T(), 2, joinResp3.PartyIndex)
|
||||
|
||||
// Step 5: Check session status - should have all participants
|
||||
statusResp := s.getSessionStatus(sessionID)
|
||||
assert.Equal(s.T(), 3, len(statusResp.Participants))
|
||||
|
||||
// Step 6: Mark parties as ready
|
||||
s.markPartyReady(sessionID, "party_user_device")
|
||||
s.markPartyReady(sessionID, "party_server")
|
||||
s.markPartyReady(sessionID, "party_recovery")
|
||||
|
||||
// Step 7: Start the session
|
||||
s.startSession(sessionID)
|
||||
|
||||
// Step 8: Verify session is in progress
|
||||
statusResp = s.getSessionStatus(sessionID)
|
||||
assert.Equal(s.T(), "in_progress", statusResp.Status)
|
||||
|
||||
// Step 9: Report completion (simulate keygen completion)
|
||||
publicKey := []byte("test-group-public-key-from-keygen")
|
||||
s.reportCompletion(sessionID, "party_user_device", publicKey)
|
||||
|
||||
// Step 10: Verify session is completed
|
||||
statusResp = s.getSessionStatus(sessionID)
|
||||
assert.Equal(s.T(), "completed", statusResp.Status)
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) TestJoinSessionWithInvalidToken() {
|
||||
joinReq := JoinSessionRequest{
|
||||
JoinToken: "invalid-token",
|
||||
PartyID: "party_test",
|
||||
DeviceType: "iOS",
|
||||
DeviceID: "device_test",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(joinReq)
|
||||
resp, err := s.client.Post(
|
||||
s.baseURL+"/api/v1/sessions/join",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(s.T(), http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) TestGetNonExistentSession() {
|
||||
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/00000000-0000-0000-0000-000000000000")
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(s.T(), http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) TestExceedParticipantLimit() {
|
||||
// Create session with 2 participants max
|
||||
createReq := CreateSessionRequest{
|
||||
SessionType: "keygen",
|
||||
ThresholdT: 2,
|
||||
ThresholdN: 2,
|
||||
CreatedBy: "e2e_test_user_limit",
|
||||
}
|
||||
|
||||
createResp := s.createSession(createReq)
|
||||
joinToken := createResp.JoinToken
|
||||
|
||||
// Join with 2 parties (should succeed)
|
||||
for i := 0; i < 2; i++ {
|
||||
joinReq := JoinSessionRequest{
|
||||
JoinToken: joinToken,
|
||||
PartyID: "party_" + string(rune('a'+i)),
|
||||
DeviceType: "test",
|
||||
DeviceID: "device_" + string(rune('a'+i)),
|
||||
}
|
||||
s.joinSession(joinReq)
|
||||
}
|
||||
|
||||
// Try to join with 3rd party (should fail)
|
||||
joinReq := JoinSessionRequest{
|
||||
JoinToken: joinToken,
|
||||
PartyID: "party_extra",
|
||||
DeviceType: "test",
|
||||
DeviceID: "device_extra",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(joinReq)
|
||||
resp, err := s.client.Post(
|
||||
s.baseURL+"/api/v1/sessions/join",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(s.T(), http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *KeygenFlowTestSuite) createSession(req CreateSessionRequest) CreateSessionResponse {
|
||||
body, _ := json.Marshal(req)
|
||||
resp, err := s.client.Post(
|
||||
s.baseURL+"/api/v1/sessions",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusCreated, resp.StatusCode)
|
||||
|
||||
var result CreateSessionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) joinSession(req JoinSessionRequest) JoinSessionResponse {
|
||||
body, _ := json.Marshal(req)
|
||||
resp, err := s.client.Post(
|
||||
s.baseURL+"/api/v1/sessions/join",
|
||||
"application/json",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||
|
||||
var result JoinSessionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) getSessionStatus(sessionID string) SessionStatusResponse {
|
||||
resp, err := s.client.Get(s.baseURL + "/api/v1/sessions/" + sessionID)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||
|
||||
var result SessionStatusResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) markPartyReady(sessionID, partyID string) {
|
||||
req := map[string]string{"partyId": partyID}
|
||||
body, _ := json.Marshal(req)
|
||||
|
||||
httpReq, _ := http.NewRequest(
|
||||
http.MethodPut,
|
||||
s.baseURL+"/api/v1/sessions/"+sessionID+"/parties/"+partyID+"/ready",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) startSession(sessionID string) {
|
||||
httpReq, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
s.baseURL+"/api/v1/sessions/"+sessionID+"/start",
|
||||
nil,
|
||||
)
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (s *KeygenFlowTestSuite) reportCompletion(sessionID string, partyID string, publicKey []byte) {
|
||||
req := map[string]interface{}{
|
||||
"party_id": partyID,
|
||||
"public_key": string(publicKey),
|
||||
}
|
||||
body, _ := json.Marshal(req)
|
||||
|
||||
httpReq, _ := http.NewRequest(
|
||||
http.MethodPost,
|
||||
s.baseURL+"/api/v1/sessions/"+sessionID+"/complete",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
require.NoError(s.T(), err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(s.T(), http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
Loading…
Reference in New Issue