diff --git a/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go b/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go index 93a46685..54540a35 100644 --- a/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go +++ b/backend/mpc-system/services/session-coordinator/adapters/input/http/session_http_handler.go @@ -1,543 +1,545 @@ -package http - -import ( - "net/http" - "time" - - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "github.com/rwadurian/mpc-system/pkg/logger" - "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" - "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" - "go.uber.org/zap" -) - -// SessionHTTPHandler handles HTTP requests for session management -type SessionHTTPHandler struct { - createSessionUC *use_cases.CreateSessionUseCase - joinSessionUC *use_cases.JoinSessionUseCase - getSessionStatusUC *use_cases.GetSessionStatusUseCase - reportCompletionUC *use_cases.ReportCompletionUseCase - closeSessionUC *use_cases.CloseSessionUseCase - sessionRepo repositories.SessionRepository -} - -// NewSessionHTTPHandler creates a new HTTP handler -func NewSessionHTTPHandler( - createSessionUC *use_cases.CreateSessionUseCase, - joinSessionUC *use_cases.JoinSessionUseCase, - getSessionStatusUC *use_cases.GetSessionStatusUseCase, - reportCompletionUC *use_cases.ReportCompletionUseCase, - closeSessionUC *use_cases.CloseSessionUseCase, - sessionRepo repositories.SessionRepository, -) *SessionHTTPHandler { - return &SessionHTTPHandler{ - createSessionUC: createSessionUC, - joinSessionUC: joinSessionUC, - getSessionStatusUC: getSessionStatusUC, - reportCompletionUC: reportCompletionUC, - closeSessionUC: closeSessionUC, - sessionRepo: sessionRepo, - } -} - -// RegisterRoutes registers HTTP routes -func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) { - sessions := r.Group("/sessions") - { - sessions.POST("", h.CreateSession) - sessions.POST("/join", h.JoinSessionByToken) - sessions.POST("/:id/join", h.JoinSession) - sessions.GET("/:id", h.GetSessionStatus) - sessions.GET("/:id/status", h.GetSessionStatus) - sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady) - sessions.POST("/:id/start", h.StartSession) - sessions.POST("/:id/complete", h.ReportCompletion) - sessions.DELETE("/:id", h.CloseSession) - } -} - -// CreateSessionRequest is the HTTP request body for creating a session -type CreateSessionRequest struct { - SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"` - ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"` - ThresholdT int `json:"thresholdT" binding:"required,min=1"` - CreatedBy string `json:"createdBy" binding:"required"` - Participants []ParticipantInfoRequest `json:"participants,omitempty"` - MessageHash string `json:"messageHash,omitempty"` - ExpiresIn int64 `json:"expiresIn,omitempty"` -} - -// ParticipantInfoRequest represents a participant in the request -type ParticipantInfoRequest struct { - PartyID string `json:"party_id" binding:"required"` - DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` -} - -// DeviceInfoRequest represents device info in the request -type DeviceInfoRequest struct { - DeviceType string `json:"device_type" binding:"required"` - DeviceID string `json:"device_id,omitempty"` - Platform string `json:"platform,omitempty"` - AppVersion string `json:"app_version,omitempty"` -} - -// CreateSessionResponse is the HTTP response for creating a session -type CreateSessionResponse struct { - SessionID string `json:"sessionId"` - JoinToken string `json:"joinToken"` - Status string `json:"status"` - ExpiresAt int64 `json:"expiresAt,omitempty"` -} - -// CreateSession handles POST /sessions -func (h *SessionHTTPHandler) CreateSession(c *gin.Context) { - var req CreateSessionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Validate threshold - if req.ThresholdT > req.ThresholdN { - c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"}) - return - } - - // Convert request to input - participants := make([]input.ParticipantInfo, len(req.Participants)) - for i, p := range req.Participants { - participants[i] = input.ParticipantInfo{ - PartyID: p.PartyID, - DeviceInfo: entities.DeviceInfo{ - DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), - DeviceID: p.DeviceInfo.DeviceID, - Platform: p.DeviceInfo.Platform, - AppVersion: p.DeviceInfo.AppVersion, - }, - } - } - - var messageHash []byte - if req.MessageHash != "" { - messageHash = []byte(req.MessageHash) - } - - expiresIn := time.Duration(req.ExpiresIn) * time.Second - if expiresIn == 0 { - expiresIn = 24 * time.Hour // Default: 24-hour session validity - } - - inputData := input.CreateSessionInput{ - InitiatorID: req.CreatedBy, - SessionType: req.SessionType, - ThresholdN: req.ThresholdN, - ThresholdT: req.ThresholdT, - Participants: participants, - MessageHash: messageHash, - ExpiresIn: expiresIn, - } - - output, err := h.createSessionUC.Execute(c.Request.Context(), inputData) - if err != nil { - logger.Error("failed to create session", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // Extract a single join token (for E2E compatibility) - // In a real scenario with pre-registered participants, we'd return all tokens - // For now, generate a universal join token - joinToken := output.SessionID.String() // Use session ID as join token for simplicity - if len(output.JoinTokens) > 0 { - // If there are participant-specific tokens, use the first one - for _, token := range output.JoinTokens { - joinToken = token - break - } - } - - c.JSON(http.StatusCreated, CreateSessionResponse{ - SessionID: output.SessionID.String(), - JoinToken: joinToken, - Status: "created", - ExpiresAt: output.ExpiresAt.UnixMilli(), - }) -} - -// JoinSessionRequest is the HTTP request body for joining a session -type JoinSessionRequest struct { - PartyID string `json:"party_id" binding:"required"` - JoinToken string `json:"join_token" binding:"required"` - DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` -} - -// JoinSessionResponse is the HTTP response for joining a session -type JoinSessionResponse struct { - Success bool `json:"success"` - SessionInfo SessionInfoDTO `json:"session_info"` - OtherParties []PartyInfoDTO `json:"other_parties"` -} - -// SessionInfoDTO represents session info in responses -type SessionInfoDTO struct { - SessionID string `json:"session_id"` - SessionType string `json:"session_type"` - ThresholdN int `json:"threshold_n"` - ThresholdT int `json:"threshold_t"` - MessageHash string `json:"message_hash,omitempty"` - Status string `json:"status"` -} - -// PartyInfoDTO represents party info in responses -type PartyInfoDTO struct { - PartyID string `json:"party_id"` - PartyIndex int `json:"party_index"` - DeviceInfo DeviceInfoRequest `json:"device_info"` -} - -// JoinSessionByTokenRequest is the HTTP request body for joining by token -type JoinSessionByTokenRequest struct { - JoinToken string `json:"joinToken" binding:"required"` - PartyID string `json:"partyId" binding:"required"` - DeviceType string `json:"deviceType" binding:"required"` - DeviceID string `json:"deviceId,omitempty"` -} - -// JoinSessionByTokenResponse is the HTTP response for joining by token -type JoinSessionByTokenResponse struct { - SessionID string `json:"sessionId"` - PartyIndex int `json:"partyIndex"` - Status string `json:"status"` - Participants []ParticipantStatusDTO `json:"participants"` -} - -// ParticipantStatusDTO represents participant status in responses -type ParticipantStatusDTO struct { - PartyID string `json:"partyId"` - PartyIndex int `json:"partyIndex"` - Status string `json:"status"` -} - -// JoinSessionByToken handles POST /sessions/join (join by token without session ID) -func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) { - var req JoinSessionByTokenRequest - if err := c.ShouldBindJSON(&req); err != nil { - logger.Error("failed to bind join session request", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Pass empty UUID - the use case will extract session ID from the JWT - inputData := input.JoinSessionInput{ - SessionID: uuid.Nil, - PartyID: req.PartyID, - JoinToken: req.JoinToken, - DeviceInfo: entities.DeviceInfo{ - DeviceType: entities.DeviceType(req.DeviceType), - DeviceID: req.DeviceID, - }, - } - - output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) - if err != nil { - logger.Error("failed to join session", zap.Error(err)) - // Return 401 for authentication/token errors - if err.Error() == "invalid token" || err.Error() == "token expired" { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Build participant status list - myPartyIndex := output.PartyIndex - - participants := make([]ParticipantStatusDTO, 0) - participants = append(participants, ParticipantStatusDTO{ - PartyID: req.PartyID, - Status: "joined", - }) - for _, p := range output.OtherParties { - participants = append(participants, ParticipantStatusDTO{ - PartyID: p.PartyID, - Status: "joined", - }) - } - - c.JSON(http.StatusOK, JoinSessionByTokenResponse{ - SessionID: output.SessionInfo.SessionID.String(), - PartyIndex: myPartyIndex, - Status: output.SessionInfo.Status, - Participants: participants, - }) -} - -// JoinSession handles POST /sessions/:id/join -func (h *SessionHTTPHandler) JoinSession(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - var req JoinSessionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - inputData := input.JoinSessionInput{ - SessionID: sessionID, - PartyID: req.PartyID, - JoinToken: req.JoinToken, - DeviceInfo: entities.DeviceInfo{ - DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), - DeviceID: req.DeviceInfo.DeviceID, - Platform: req.DeviceInfo.Platform, - AppVersion: req.DeviceInfo.AppVersion, - }, - } - - output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) - if err != nil { - logger.Error("failed to join session", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - otherParties := make([]PartyInfoDTO, len(output.OtherParties)) - for i, p := range output.OtherParties { - otherParties[i] = PartyInfoDTO{ - PartyID: p.PartyID, - PartyIndex: p.PartyIndex, - DeviceInfo: DeviceInfoRequest{ - DeviceType: string(p.DeviceInfo.DeviceType), - DeviceID: p.DeviceInfo.DeviceID, - Platform: p.DeviceInfo.Platform, - AppVersion: p.DeviceInfo.AppVersion, - }, - } - } - - c.JSON(http.StatusOK, JoinSessionResponse{ - Success: output.Success, - SessionInfo: SessionInfoDTO{ - SessionID: output.SessionInfo.SessionID.String(), - SessionType: output.SessionInfo.SessionType, - ThresholdN: output.SessionInfo.ThresholdN, - ThresholdT: output.SessionInfo.ThresholdT, - MessageHash: string(output.SessionInfo.MessageHash), - Status: output.SessionInfo.Status, - }, - OtherParties: otherParties, - }) -} - -// SessionStatusResponse is the HTTP response for session status -type SessionStatusResponse struct { - SessionID string `json:"sessionId"` - Status string `json:"status"` - ThresholdT int `json:"thresholdT"` - ThresholdN int `json:"thresholdN"` - Participants []ParticipantStatusDTO `json:"participants"` - PublicKey string `json:"publicKey,omitempty"` -} - -// GetSessionStatus handles GET /sessions/:id/status -func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID) - if err != nil { - logger.Error("failed to get session status", zap.Error(err)) - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - - // Convert participants to DTO - participants := make([]ParticipantStatusDTO, len(output.Participants)) - for i, p := range output.Participants { - participants[i] = ParticipantStatusDTO{ - PartyID: p.PartyID, - PartyIndex: p.PartyIndex, - Status: p.Status, - } - } - - c.JSON(http.StatusOK, SessionStatusResponse{ - SessionID: output.SessionID.String(), - Status: output.Status, - ThresholdT: output.ThresholdT, - ThresholdN: output.ThresholdN, - Participants: participants, - PublicKey: string(output.PublicKey), - }) -} - -// ReportCompletionRequest is the HTTP request for reporting completion -type ReportCompletionRequest struct { - PartyID string `json:"party_id" binding:"required"` - PublicKey string `json:"public_key,omitempty"` - Signature string `json:"signature,omitempty"` -} - -// ReportCompletion handles POST /sessions/:id/complete -func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - var req ReportCompletionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - inputData := input.ReportCompletionInput{ - SessionID: sessionID, - PartyID: req.PartyID, - PublicKey: []byte(req.PublicKey), - Signature: []byte(req.Signature), - } - - output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData) - if err != nil { - logger.Error("failed to report completion", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": output.Success, - "all_completed": output.AllCompleted, - }) -} - -// CloseSession handles DELETE /sessions/:id -func (h *SessionHTTPHandler) CloseSession(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - err = h.closeSessionUC.Execute(c.Request.Context(), sessionID) - if err != nil { - logger.Error("failed to close session", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"success": true}) -} - -// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready -func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - partyID := c.Param("partyId") - if partyID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"}) - return - } - - logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID)) - - // Load session - session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) - if err != nil { - if err == entities.ErrSessionNotFound { - c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) - return - } - logger.Error("failed to load session", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) - return - } - - // Create party ID value object - partyIDVO, err := value_objects.NewPartyID(partyID) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"}) - return - } - - // Update participant status to ready - if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil { - logger.Error("failed to mark party as ready", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Save session - if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { - logger.Error("failed to save session", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) - return - } - - c.JSON(http.StatusOK, gin.H{"success": true}) -} - -// StartSession handles POST /sessions/:id/start -func (h *SessionHTTPHandler) StartSession(c *gin.Context) { - sessionID, err := uuid.Parse(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) - return - } - - logger.Info("starting session", zap.String("session_id", sessionID.String())) - - // Load session - session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) - if err != nil { - if err == entities.ErrSessionNotFound { - c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) - return - } - logger.Error("failed to load session", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) - return - } - - // Start the session - if err := session.Start(); err != nil { - logger.Error("failed to start session", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // Save session - if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { - logger.Error("failed to save session", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) - return - } - - c.JSON(http.StatusOK, gin.H{"success": true}) -} - -// HealthCheck handles GET /health -func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "status": "healthy", - "service": "session-coordinator", - "time": time.Now().UTC().Format(time.RFC3339), - }) -} +package http + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/rwadurian/mpc-system/pkg/logger" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" + "go.uber.org/zap" +) + +// SessionHTTPHandler handles HTTP requests for session management +type SessionHTTPHandler struct { + createSessionUC *use_cases.CreateSessionUseCase + joinSessionUC *use_cases.JoinSessionUseCase + getSessionStatusUC *use_cases.GetSessionStatusUseCase + reportCompletionUC *use_cases.ReportCompletionUseCase + closeSessionUC *use_cases.CloseSessionUseCase + sessionRepo repositories.SessionRepository +} + +// NewSessionHTTPHandler creates a new HTTP handler +func NewSessionHTTPHandler( + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, + sessionRepo repositories.SessionRepository, +) *SessionHTTPHandler { + return &SessionHTTPHandler{ + createSessionUC: createSessionUC, + joinSessionUC: joinSessionUC, + getSessionStatusUC: getSessionStatusUC, + reportCompletionUC: reportCompletionUC, + closeSessionUC: closeSessionUC, + sessionRepo: sessionRepo, + } +} + +// RegisterRoutes registers HTTP routes +func (h *SessionHTTPHandler) RegisterRoutes(r *gin.RouterGroup) { + sessions := r.Group("/sessions") + { + sessions.POST("", h.CreateSession) + sessions.POST("/join", h.JoinSessionByToken) + sessions.POST("/:id/join", h.JoinSession) + sessions.GET("/:id", h.GetSessionStatus) + sessions.GET("/:id/status", h.GetSessionStatus) + sessions.PUT("/:id/parties/:partyId/ready", h.MarkPartyReady) + sessions.POST("/:id/start", h.StartSession) + sessions.POST("/:id/complete", h.ReportCompletion) + sessions.DELETE("/:id", h.CloseSession) + } +} + +// CreateSessionRequest is the HTTP request body for creating a session +type CreateSessionRequest struct { + SessionType string `json:"sessionType" binding:"required,oneof=keygen sign"` + ThresholdN int `json:"thresholdN" binding:"required,min=2,max=10"` + ThresholdT int `json:"thresholdT" binding:"required,min=1"` + CreatedBy string `json:"createdBy" binding:"required"` + Participants []ParticipantInfoRequest `json:"participants,omitempty"` + MessageHash string `json:"messageHash,omitempty"` + ExpiresIn int64 `json:"expiresIn,omitempty"` +} + +// ParticipantInfoRequest represents a participant in the request +type ParticipantInfoRequest struct { + PartyID string `json:"party_id" binding:"required"` + DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` +} + +// DeviceInfoRequest represents device info in the request +type DeviceInfoRequest struct { + DeviceType string `json:"device_type" binding:"required"` + DeviceID string `json:"device_id,omitempty"` + Platform string `json:"platform,omitempty"` + AppVersion string `json:"app_version,omitempty"` +} + +// CreateSessionResponse is the HTTP response for creating a session +type CreateSessionResponse struct { + SessionID string `json:"sessionId"` + JoinToken string `json:"joinToken"` + Status string `json:"status"` + ExpiresAt int64 `json:"expiresAt,omitempty"` +} + +// CreateSession handles POST /sessions +func (h *SessionHTTPHandler) CreateSession(c *gin.Context) { + var req CreateSessionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Validate threshold + if req.ThresholdT > req.ThresholdN { + c.JSON(http.StatusBadRequest, gin.H{"error": "threshold_t cannot exceed threshold_n"}) + return + } + + // Convert request to input + participants := make([]input.ParticipantInfo, len(req.Participants)) + for i, p := range req.Participants { + participants[i] = input.ParticipantInfo{ + PartyID: p.PartyID, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + var messageHash []byte + if req.MessageHash != "" { + messageHash = []byte(req.MessageHash) + } + + expiresIn := time.Duration(req.ExpiresIn) * time.Second + if expiresIn == 0 { + expiresIn = 24 * time.Hour // Default: 24-hour session validity + } + + inputData := input.CreateSessionInput{ + InitiatorID: req.CreatedBy, + SessionType: req.SessionType, + ThresholdN: req.ThresholdN, + ThresholdT: req.ThresholdT, + Participants: participants, + MessageHash: messageHash, + ExpiresIn: expiresIn, + } + + output, err := h.createSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to create session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Extract a single join token (for E2E compatibility) + // In a real scenario with pre-registered participants, we'd return all tokens + // For now, generate a universal join token + joinToken := output.SessionID.String() // Use session ID as join token for simplicity + if len(output.JoinTokens) > 0 { + // If there are participant-specific tokens, use the first one + for _, token := range output.JoinTokens { + joinToken = token + break + } + } + + c.JSON(http.StatusCreated, CreateSessionResponse{ + SessionID: output.SessionID.String(), + JoinToken: joinToken, + Status: "created", + ExpiresAt: output.ExpiresAt.UnixMilli(), + }) +} + +// JoinSessionRequest is the HTTP request body for joining a session +type JoinSessionRequest struct { + PartyID string `json:"party_id" binding:"required"` + JoinToken string `json:"join_token" binding:"required"` + DeviceInfo DeviceInfoRequest `json:"device_info" binding:"required"` +} + +// JoinSessionResponse is the HTTP response for joining a session +type JoinSessionResponse struct { + Success bool `json:"success"` + SessionInfo SessionInfoDTO `json:"session_info"` + OtherParties []PartyInfoDTO `json:"other_parties"` +} + +// SessionInfoDTO represents session info in responses +type SessionInfoDTO struct { + SessionID string `json:"session_id"` + SessionType string `json:"session_type"` + ThresholdN int `json:"threshold_n"` + ThresholdT int `json:"threshold_t"` + MessageHash string `json:"message_hash,omitempty"` + Status string `json:"status"` +} + +// PartyInfoDTO represents party info in responses +type PartyInfoDTO struct { + PartyID string `json:"party_id"` + PartyIndex int `json:"party_index"` + DeviceInfo DeviceInfoRequest `json:"device_info"` +} + +// JoinSessionByTokenRequest is the HTTP request body for joining by token +type JoinSessionByTokenRequest struct { + JoinToken string `json:"joinToken" binding:"required"` + PartyID string `json:"partyId" binding:"required"` + DeviceType string `json:"deviceType" binding:"required"` + DeviceID string `json:"deviceId,omitempty"` +} + +// JoinSessionByTokenResponse is the HTTP response for joining by token +type JoinSessionByTokenResponse struct { + SessionID string `json:"sessionId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` + Participants []ParticipantStatusDTO `json:"participants"` +} + +// ParticipantStatusDTO represents participant status in responses +type ParticipantStatusDTO struct { + PartyID string `json:"partyId"` + PartyIndex int `json:"partyIndex"` + Status string `json:"status"` +} + +// JoinSessionByToken handles POST /sessions/join (join by token without session ID) +func (h *SessionHTTPHandler) JoinSessionByToken(c *gin.Context) { + var req JoinSessionByTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + logger.Error("failed to bind join session request", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Pass empty UUID - the use case will extract session ID from the JWT + inputData := input.JoinSessionInput{ + SessionID: uuid.Nil, + PartyID: req.PartyID, + JoinToken: req.JoinToken, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceType), + DeviceID: req.DeviceID, + }, + } + + output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to join session", zap.Error(err)) + // Return 401 for authentication/token errors + if err.Error() == "invalid token" || err.Error() == "token expired" { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Build participant status list + myPartyIndex := output.PartyIndex + + participants := make([]ParticipantStatusDTO, 0) + participants = append(participants, ParticipantStatusDTO{ + PartyID: req.PartyID, + Status: "joined", + }) + for _, p := range output.OtherParties { + participants = append(participants, ParticipantStatusDTO{ + PartyID: p.PartyID, + Status: "joined", + }) + } + + c.JSON(http.StatusOK, JoinSessionByTokenResponse{ + SessionID: output.SessionInfo.SessionID.String(), + PartyIndex: myPartyIndex, + Status: output.SessionInfo.Status, + Participants: participants, + }) +} + +// JoinSession handles POST /sessions/:id/join +func (h *SessionHTTPHandler) JoinSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + var req JoinSessionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + inputData := input.JoinSessionInput{ + SessionID: sessionID, + PartyID: req.PartyID, + JoinToken: req.JoinToken, + DeviceInfo: entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), + DeviceID: req.DeviceInfo.DeviceID, + Platform: req.DeviceInfo.Platform, + AppVersion: req.DeviceInfo.AppVersion, + }, + } + + output, err := h.joinSessionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to join session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + otherParties := make([]PartyInfoDTO, len(output.OtherParties)) + for i, p := range output.OtherParties { + otherParties[i] = PartyInfoDTO{ + PartyID: p.PartyID, + PartyIndex: p.PartyIndex, + DeviceInfo: DeviceInfoRequest{ + DeviceType: string(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + c.JSON(http.StatusOK, JoinSessionResponse{ + Success: output.Success, + SessionInfo: SessionInfoDTO{ + SessionID: output.SessionInfo.SessionID.String(), + SessionType: output.SessionInfo.SessionType, + ThresholdN: output.SessionInfo.ThresholdN, + ThresholdT: output.SessionInfo.ThresholdT, + MessageHash: string(output.SessionInfo.MessageHash), + Status: output.SessionInfo.Status, + }, + OtherParties: otherParties, + }) +} + +// SessionStatusResponse is the HTTP response for session status +type SessionStatusResponse struct { + SessionID string `json:"sessionId"` + Status string `json:"status"` + ThresholdT int `json:"thresholdT"` + ThresholdN int `json:"thresholdN"` + Participants []ParticipantStatusDTO `json:"participants"` + PublicKey string `json:"publicKey,omitempty"` + Signature string `json:"signature,omitempty"` +} + +// GetSessionStatus handles GET /sessions/:id/status +func (h *SessionHTTPHandler) GetSessionStatus(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + output, err := h.getSessionStatusUC.Execute(c.Request.Context(), sessionID) + if err != nil { + logger.Error("failed to get session status", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + // Convert participants to DTO + participants := make([]ParticipantStatusDTO, len(output.Participants)) + for i, p := range output.Participants { + participants[i] = ParticipantStatusDTO{ + PartyID: p.PartyID, + PartyIndex: p.PartyIndex, + Status: p.Status, + } + } + + c.JSON(http.StatusOK, SessionStatusResponse{ + SessionID: output.SessionID.String(), + Status: output.Status, + ThresholdT: output.ThresholdT, + ThresholdN: output.ThresholdN, + Participants: participants, + PublicKey: string(output.PublicKey), + Signature: string(output.Signature), + }) +} + +// ReportCompletionRequest is the HTTP request for reporting completion +type ReportCompletionRequest struct { + PartyID string `json:"party_id" binding:"required"` + PublicKey string `json:"public_key,omitempty"` + Signature string `json:"signature,omitempty"` +} + +// ReportCompletion handles POST /sessions/:id/complete +func (h *SessionHTTPHandler) ReportCompletion(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + var req ReportCompletionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + inputData := input.ReportCompletionInput{ + SessionID: sessionID, + PartyID: req.PartyID, + PublicKey: []byte(req.PublicKey), + Signature: []byte(req.Signature), + } + + output, err := h.reportCompletionUC.Execute(c.Request.Context(), inputData) + if err != nil { + logger.Error("failed to report completion", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": output.Success, + "all_completed": output.AllCompleted, + }) +} + +// CloseSession handles DELETE /sessions/:id +func (h *SessionHTTPHandler) CloseSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + err = h.closeSessionUC.Execute(c.Request.Context(), sessionID) + if err != nil { + logger.Error("failed to close session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// MarkPartyReady handles PUT /sessions/:id/parties/:partyId/ready +func (h *SessionHTTPHandler) MarkPartyReady(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + partyID := c.Param("partyId") + if partyID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "party ID is required"}) + return + } + + logger.Info("marking party as ready", zap.String("session_id", sessionID.String()), zap.String("party_id", partyID)) + + // Load session + session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) + if err != nil { + if err == entities.ErrSessionNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + logger.Error("failed to load session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) + return + } + + // Create party ID value object + partyIDVO, err := value_objects.NewPartyID(partyID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid party ID"}) + return + } + + // Update participant status to ready + if err := session.UpdateParticipantStatus(partyIDVO, value_objects.ParticipantStatusReady); err != nil { + logger.Error("failed to mark party as ready", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Save session + if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { + logger.Error("failed to save session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// StartSession handles POST /sessions/:id/start +func (h *SessionHTTPHandler) StartSession(c *gin.Context) { + sessionID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session ID"}) + return + } + + logger.Info("starting session", zap.String("session_id", sessionID.String())) + + // Load session + session, err := h.sessionRepo.FindByUUID(c.Request.Context(), sessionID) + if err != nil { + if err == entities.ErrSessionNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + logger.Error("failed to load session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load session"}) + return + } + + // Start the session + if err := session.Start(); err != nil { + logger.Error("failed to start session", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Save session + if err := h.sessionRepo.Update(c.Request.Context(), session); err != nil { + logger.Error("failed to save session", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save session"}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// HealthCheck handles GET /health +func (h *SessionHTTPHandler) HealthCheck(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + "service": "session-coordinator", + "time": time.Now().UTC().Format(time.RFC3339), + }) +} diff --git a/backend/mpc-system/services/session-coordinator/adapters/output/memory/signature_cache.go b/backend/mpc-system/services/session-coordinator/adapters/output/memory/signature_cache.go new file mode 100644 index 00000000..facce248 --- /dev/null +++ b/backend/mpc-system/services/session-coordinator/adapters/output/memory/signature_cache.go @@ -0,0 +1,59 @@ +package memory + +import ( + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + // SignatureCacheTTL is the time-to-live for cached signatures (24 hours) + SignatureCacheTTL = 24 * time.Hour +) + +// signatureCacheEntry represents a cached signature with expiration +type signatureCacheEntry struct { + Signature []byte + ExpiresAt time.Time +} + +// SignatureCache provides in-memory caching for signatures +type SignatureCache struct { + cache sync.Map +} + +// Global signature cache instance +var globalSignatureCache = &SignatureCache{} + +// GetSignatureCache returns the global signature cache instance +func GetSignatureCache() *SignatureCache { + return globalSignatureCache +} + +// Set stores a signature in the cache with 24h TTL +func (c *SignatureCache) Set(sessionID uuid.UUID, signature []byte) { + entry := signatureCacheEntry{ + Signature: signature, + ExpiresAt: time.Now().Add(SignatureCacheTTL), + } + c.cache.Store(sessionID.String(), entry) +} + +// Get retrieves a signature from the cache +func (c *SignatureCache) Get(sessionID uuid.UUID) ([]byte, bool) { + value, ok := c.cache.Load(sessionID.String()) + if !ok { + return nil, false + } + + entry := value.(signatureCacheEntry) + + // Check if expired + if time.Now().After(entry.ExpiresAt) { + c.cache.Delete(sessionID.String()) + return nil, false + } + + return entry.Signature, true +} diff --git a/backend/mpc-system/services/session-coordinator/application/use_cases/get_session_status.go b/backend/mpc-system/services/session-coordinator/application/use_cases/get_session_status.go index c093136b..01c9b4e1 100644 --- a/backend/mpc-system/services/session-coordinator/application/use_cases/get_session_status.go +++ b/backend/mpc-system/services/session-coordinator/application/use_cases/get_session_status.go @@ -4,7 +4,9 @@ import ( "context" "github.com/google/uuid" + "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/memory" "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" ) @@ -45,7 +47,13 @@ func (uc *GetSessionStatusUseCase) Execute( } } - // 3. Build response + // 3. For completed sign sessions, get signature from cache + var signature []byte + if session.SessionType == entities.SessionTypeSign && session.Status.String() == "completed" { + signature, _ = memory.GetSignatureCache().Get(sessionID) + } + + // 4. Build response // has_delegate is only meaningful for keygen sessions hasDelegate := session.DelegatePartyID != "" && string(session.SessionType) == "keygen" return &input.SessionStatusOutput{ @@ -56,6 +64,7 @@ func (uc *GetSessionStatusUseCase) Execute( ThresholdN: session.Threshold.N(), Participants: participants, PublicKey: session.PublicKey, + Signature: signature, HasDelegate: hasDelegate, DelegatePartyID: session.DelegatePartyID, }, nil diff --git a/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go b/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go index 356cfd2b..c01f53b3 100644 --- a/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go +++ b/backend/mpc-system/services/session-coordinator/application/use_cases/report_completion.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rwadurian/mpc-system/pkg/logger" + "github.com/rwadurian/mpc-system/services/session-coordinator/adapters/output/memory" "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/output" "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" @@ -169,6 +170,15 @@ func (uc *ReportCompletionUseCase) executeWithRetry( return nil, err } + // For sign sessions, cache the signature for HTTP API retrieval + // Note: server-party sends signature in PublicKey field (same field used for keygen public key) + if session.SessionType == entities.SessionTypeSign && len(inputData.PublicKey) > 0 { + memory.GetSignatureCache().Set(session.ID.UUID(), inputData.PublicKey) + logger.Info("cached signature for sign session", + zap.String("session_id", session.ID.String()), + zap.Int("signature_len", len(inputData.PublicKey))) + } + // Publish session completed event completedEvent := output.SessionCompletedEvent{ SessionID: session.ID.String(), diff --git a/backend/services/blockchain-service/src/domain/services/erc20-transfer.service.ts b/backend/services/blockchain-service/src/domain/services/erc20-transfer.service.ts index 3578aa32..51cd37ba 100644 --- a/backend/services/blockchain-service/src/domain/services/erc20-transfer.service.ts +++ b/backend/services/blockchain-service/src/domain/services/erc20-transfer.service.ts @@ -7,6 +7,7 @@ import { formatUnits, Transaction, Signature, + recoverAddress, } from 'ethers'; import { ChainConfigService } from './chain-config.service'; import { ChainType } from '@/domain/value-objects'; @@ -191,16 +192,42 @@ export class Erc20TransferService { data: transferData, }); - const tx = Transaction.from({ - type: 2, // EIP-1559 - chainId: config.chainId, - nonce, - to: config.usdtContract, - data: transferData, - gasLimit: gasEstimate * BigInt(120) / BigInt(100), // 增加 20% buffer - maxFeePerGas: feeData.maxFeePerGas, - maxPriorityFeePerGas: feeData.maxPriorityFeePerGas, - }); + const gasLimit = gasEstimate * BigInt(120) / BigInt(100); // 增加 20% buffer + + // 检测链是否支持 EIP-1559 + // 如果 maxFeePerGas 为 null 或 0,则使用 legacy 交易 + const supportsEip1559 = feeData.maxFeePerGas && feeData.maxFeePerGas > BigInt(0); + this.logger.log(`[TRANSFER] Chain supports EIP-1559: ${supportsEip1559}`); + this.logger.log(`[TRANSFER] Fee data: gasPrice=${feeData.gasPrice}, maxFeePerGas=${feeData.maxFeePerGas}`); + + let tx: Transaction; + if (supportsEip1559) { + // EIP-1559 交易 (type 2) + tx = Transaction.from({ + type: 2, + chainId: config.chainId, + nonce, + to: config.usdtContract, + data: transferData, + gasLimit, + maxFeePerGas: feeData.maxFeePerGas, + maxPriorityFeePerGas: feeData.maxPriorityFeePerGas, + }); + this.logger.log(`[TRANSFER] Built EIP-1559 transaction`); + } else { + // Legacy 交易 (type 0) + const gasPrice = feeData.gasPrice || BigInt(1000000000); // 默认 1 gwei + tx = Transaction.from({ + type: 0, + chainId: config.chainId, + nonce, + to: config.usdtContract, + data: transferData, + gasLimit, + gasPrice, + }); + this.logger.log(`[TRANSFER] Built legacy transaction with gasPrice=${gasPrice}`); + } this.logger.log(`[TRANSFER] Transaction built: nonce=${nonce}, gasLimit=${tx.gasLimit}`); @@ -213,8 +240,43 @@ export class Erc20TransferService { const signatureHex = await this.mpcSigningClient.signMessage(unsignedTxHash); this.logger.log(`[TRANSFER] MPC signature obtained: ${signatureHex.slice(0, 20)}...`); - // 解析签名 - const signature = Signature.from(signatureHex); + // 解析签名 - MPC 返回 64 字节 (r+s),需要转换为 ethers.js 格式 + // 确保有 0x 前缀 + const normalizedSig = signatureHex.startsWith('0x') ? signatureHex : `0x${signatureHex}`; + this.logger.log(`[TRANSFER] Normalized signature: ${normalizedSig.slice(0, 22)}...`); + + // MPC 签名是 64 字节 (r: 32 bytes + s: 32 bytes),需要添加 v (recovery id) + // 对于 EIP-1559 交易,v = 0 或 1 (yParity) + // 我们需要尝试两个值来恢复正确的地址 + const sigBytes = normalizedSig.slice(2); // 去掉 0x + const r = `0x${sigBytes.slice(0, 64)}`; + const s = `0x${sigBytes.slice(64, 128)}`; + + this.logger.log(`[TRANSFER] Signature r: ${r.slice(0, 20)}...`); + this.logger.log(`[TRANSFER] Signature s: ${s.slice(0, 20)}...`); + + // 尝试 yParity 0 和 1 来找到正确的 recovery id + let signature: Signature | null = null; + for (const yParity of [0, 1] as const) { + try { + const testSig = Signature.from({ r, s, yParity }); + // 使用 recoverAddress 验证签名 + const recoveredAddress = recoverAddress(unsignedTxHash, testSig); + this.logger.log(`[TRANSFER] Recovered address with yParity=${yParity}: ${recoveredAddress}`); + + if (recoveredAddress.toLowerCase() === this.hotWalletAddress.toLowerCase()) { + this.logger.log(`[TRANSFER] Found correct yParity: ${yParity}`); + signature = testSig; + break; + } + } catch (e) { + this.logger.debug(`[TRANSFER] yParity=${yParity} failed: ${e}`); + } + } + + if (!signature) { + throw new Error('Failed to recover correct signature - address mismatch'); + } // 创建已签名交易 const signedTx = tx.clone(); diff --git a/backend/services/wallet-service/prisma/schema.prisma b/backend/services/wallet-service/prisma/schema.prisma index f3dc49ae..5fa83ae5 100644 --- a/backend/services/wallet-service/prisma/schema.prisma +++ b/backend/services/wallet-service/prisma/schema.prisma @@ -58,6 +58,9 @@ model WalletAccount { // 状态 status String @default("ACTIVE") @map("status") @db.VarChar(20) + // 乐观锁版本号 + version Int @default(0) @map("version") + createdAt DateTime @default(now()) @map("created_at") updatedAt DateTime @updatedAt @map("updated_at") diff --git a/backend/services/wallet-service/src/application/event-handlers/withdrawal-status.handler.ts b/backend/services/wallet-service/src/application/event-handlers/withdrawal-status.handler.ts index 10622bbc..1e4c4439 100644 --- a/backend/services/wallet-service/src/application/event-handlers/withdrawal-status.handler.ts +++ b/backend/services/wallet-service/src/application/event-handlers/withdrawal-status.handler.ts @@ -10,23 +10,36 @@ import { IWalletAccountRepository, WALLET_ACCOUNT_REPOSITORY, } from '@/domain/repositories'; +import { PrismaService } from '@/infrastructure/persistence/prisma/prisma.service'; +import { WithdrawalOrder, WalletAccount } from '@/domain/aggregates'; +import { WithdrawalStatus, Money, UserId } from '@/domain/value-objects'; +import { OptimisticLockError } from '@/shared/exceptions/domain.exception'; +import Decimal from 'decimal.js'; /** * Withdrawal Status Handler * * Handles withdrawal status events from blockchain-service. * Updates withdrawal order status and handles fund refunds on failure. + * + * IMPORTANT: + * - All operations use database transactions for atomicity. + * - Wallet balance updates use optimistic locking to prevent concurrent modification issues. */ @Injectable() export class WithdrawalStatusHandler implements OnModuleInit { private readonly logger = new Logger(WithdrawalStatusHandler.name); + // Max retry count for optimistic lock conflicts + private readonly MAX_RETRIES = 3; + constructor( private readonly withdrawalEventConsumer: WithdrawalEventConsumerService, @Inject(WITHDRAWAL_ORDER_REPOSITORY) private readonly withdrawalRepo: IWithdrawalOrderRepository, @Inject(WALLET_ACCOUNT_REPOSITORY) private readonly walletRepo: IWalletAccountRepository, + private readonly prisma: PrismaService, ) {} onModuleInit() { @@ -41,7 +54,9 @@ export class WithdrawalStatusHandler implements OnModuleInit { /** * Handle withdrawal confirmed event - * Update order status to CONFIRMED and store txHash + * Update order status to CONFIRMED, store txHash, and deduct frozen balance + * + * Uses database transaction + optimistic locking to ensure atomicity and prevent race conditions. */ private async handleWithdrawalConfirmed( payload: WithdrawalConfirmedPayload, @@ -50,26 +65,129 @@ export class WithdrawalStatusHandler implements OnModuleInit { this.logger.log(`[CONFIRMED] orderNo: ${payload.orderNo}`); this.logger.log(`[CONFIRMED] txHash: ${payload.txHash}`); + let retries = 0; + while (retries < this.MAX_RETRIES) { + try { + await this.executeWithdrawalConfirmed(payload); + return; // Success, exit + } catch (error) { + if (this.isOptimisticLockError(error)) { + retries++; + this.logger.warn(`[CONFIRMED] Optimistic lock conflict for ${payload.orderNo}, retry ${retries}/${this.MAX_RETRIES}`); + if (retries >= this.MAX_RETRIES) { + this.logger.error(`[CONFIRMED] Max retries exceeded for ${payload.orderNo}`); + throw error; + } + // Brief delay before retry + await this.sleep(50 * retries); + } else { + throw error; + } + } + } + } + + /** + * Execute the withdrawal confirmed logic within a transaction + */ + private async executeWithdrawalConfirmed( + payload: WithdrawalConfirmedPayload, + ): Promise { try { - // Find the withdrawal order - const order = await this.withdrawalRepo.findByOrderNo(payload.orderNo); - if (!order) { - this.logger.error(`[CONFIRMED] Order not found: ${payload.orderNo}`); - return; - } + // Use transaction to ensure atomicity + await this.prisma.$transaction(async (tx) => { + // Find the withdrawal order + const orderRecord = await tx.withdrawalOrder.findUnique({ + where: { orderNo: payload.orderNo }, + }); - // Update order status: FROZEN -> BROADCASTED -> CONFIRMED - // If still FROZEN, first mark as broadcasted with txHash - if (order.isFrozen) { - order.markAsBroadcasted(payload.txHash); - } + if (!orderRecord) { + this.logger.error(`[CONFIRMED] Order not found: ${payload.orderNo}`); + return; + } - // Then mark as confirmed - if (order.isBroadcasted) { - order.markAsConfirmed(); - } + // Check if already confirmed (idempotency) + if (orderRecord.status === WithdrawalStatus.CONFIRMED) { + this.logger.log(`[CONFIRMED] Order ${payload.orderNo} already confirmed, skipping`); + return; + } - await this.withdrawalRepo.save(order); + // Determine new status based on current status + let newStatus = orderRecord.status; + let txHash = orderRecord.txHash; + let broadcastedAt = orderRecord.broadcastedAt; + let confirmedAt = orderRecord.confirmedAt; + + // FROZEN -> BROADCASTED -> CONFIRMED + if (orderRecord.status === WithdrawalStatus.FROZEN) { + newStatus = WithdrawalStatus.BROADCASTED; + txHash = payload.txHash; + broadcastedAt = new Date(); + } + + if (newStatus === WithdrawalStatus.BROADCASTED || orderRecord.status === WithdrawalStatus.BROADCASTED) { + newStatus = WithdrawalStatus.CONFIRMED; + confirmedAt = new Date(); + } + + // Update order status + await tx.withdrawalOrder.update({ + where: { id: orderRecord.id }, + data: { + status: newStatus, + txHash, + broadcastedAt, + confirmedAt, + }, + }); + + // Find wallet and deduct frozen balance with optimistic lock + let walletRecord = await tx.walletAccount.findUnique({ + where: { accountSequence: orderRecord.accountSequence }, + }); + + if (!walletRecord) { + walletRecord = await tx.walletAccount.findUnique({ + where: { userId: orderRecord.userId }, + }); + } + + if (walletRecord) { + // Deduct the total frozen amount (amount + fee) + const totalAmount = new Decimal(orderRecord.amount.toString()).add(new Decimal(orderRecord.fee.toString())); + const currentFrozen = new Decimal(walletRecord.usdtFrozen.toString()); + + if (currentFrozen.lessThan(totalAmount)) { + this.logger.error(`[CONFIRMED] Insufficient frozen balance: have ${currentFrozen}, need ${totalAmount}`); + throw new Error(`Insufficient frozen balance for withdrawal ${payload.orderNo}`); + } + + const newFrozen = currentFrozen.minus(totalAmount); + const currentVersion = walletRecord.version; + + // Optimistic lock: update only if version matches + const updateResult = await tx.walletAccount.updateMany({ + where: { + id: walletRecord.id, + version: currentVersion, // Optimistic lock condition + }, + data: { + usdtFrozen: newFrozen, + version: currentVersion + 1, // Increment version + updatedAt: new Date(), + }, + }); + + if (updateResult.count === 0) { + // Version mismatch - another transaction modified the record + throw new OptimisticLockError(`Optimistic lock conflict for wallet ${walletRecord.id}`); + } + + this.logger.log(`[CONFIRMED] Deducted ${totalAmount.toString()} USDT from frozen balance for ${orderRecord.accountSequence} (version: ${currentVersion} -> ${currentVersion + 1})`); + } else { + this.logger.error(`[CONFIRMED] Wallet not found for accountSequence: ${orderRecord.accountSequence}, userId: ${orderRecord.userId}`); + } + }); this.logger.log(`[CONFIRMED] Order ${payload.orderNo} confirmed successfully`); } catch (error) { @@ -80,7 +198,9 @@ export class WithdrawalStatusHandler implements OnModuleInit { /** * Handle withdrawal failed event - * Update order status to FAILED and refund frozen funds + * Update order status to FAILED and refund frozen funds (amount + fee) + * + * Uses database transaction + optimistic locking to ensure atomicity and prevent race conditions. */ private async handleWithdrawalFailed( payload: WithdrawalFailedPayload, @@ -89,35 +209,126 @@ export class WithdrawalStatusHandler implements OnModuleInit { this.logger.log(`[FAILED] orderNo: ${payload.orderNo}`); this.logger.log(`[FAILED] error: ${payload.error}`); - try { - // Find the withdrawal order - const order = await this.withdrawalRepo.findByOrderNo(payload.orderNo); - if (!order) { - this.logger.error(`[FAILED] Order not found: ${payload.orderNo}`); - return; - } - - // Mark order as failed - order.markAsFailed(payload.error); - await this.withdrawalRepo.save(order); - - // Refund frozen funds back to available balance if needed - if (order.needsUnfreeze()) { - // 优先使用 accountSequence 查找钱包(更可靠,避免 userId 变化导致扣错人) - let wallet = await this.walletRepo.findByAccountSequence(order.accountSequence); - if (!wallet) { - // 兜底:使用 userId 查找 - wallet = await this.walletRepo.findByUserId(order.userId.value); - } - if (wallet) { - // Unfreeze the amount (add back to available balance) - wallet.unfreeze(order.amount); - await this.walletRepo.save(wallet); - this.logger.log(`[FAILED] Refunded ${order.amount.value} USDT to account ${order.accountSequence}`); + let retries = 0; + while (retries < this.MAX_RETRIES) { + try { + await this.executeWithdrawalFailed(payload); + return; // Success, exit + } catch (error) { + if (this.isOptimisticLockError(error)) { + retries++; + this.logger.warn(`[FAILED] Optimistic lock conflict for ${payload.orderNo}, retry ${retries}/${this.MAX_RETRIES}`); + if (retries >= this.MAX_RETRIES) { + this.logger.error(`[FAILED] Max retries exceeded for ${payload.orderNo}`); + throw error; + } + // Brief delay before retry + await this.sleep(50 * retries); } else { - this.logger.error(`[FAILED] Wallet not found for accountSequence: ${order.accountSequence}, userId: ${order.userId}`); + throw error; } } + } + } + + /** + * Execute the withdrawal failed logic within a transaction + */ + private async executeWithdrawalFailed( + payload: WithdrawalFailedPayload, + ): Promise { + try { + // Use transaction to ensure atomicity + await this.prisma.$transaction(async (tx) => { + // Find the withdrawal order + const orderRecord = await tx.withdrawalOrder.findUnique({ + where: { orderNo: payload.orderNo }, + }); + + if (!orderRecord) { + this.logger.error(`[FAILED] Order not found: ${payload.orderNo}`); + return; + } + + // Check if already in terminal state (idempotency) + if (orderRecord.status === WithdrawalStatus.CONFIRMED || + orderRecord.status === WithdrawalStatus.FAILED || + orderRecord.status === WithdrawalStatus.CANCELLED) { + this.logger.log(`[FAILED] Order ${payload.orderNo} already in terminal state: ${orderRecord.status}, skipping`); + return; + } + + // Check if needs unfreeze (was frozen) + const needsUnfreeze = orderRecord.frozenAt !== null; + + // Update order status to FAILED + await tx.withdrawalOrder.update({ + where: { id: orderRecord.id }, + data: { + status: WithdrawalStatus.FAILED, + errorMessage: payload.error, + }, + }); + + // Refund frozen funds back to available balance if needed + if (needsUnfreeze) { + let walletRecord = await tx.walletAccount.findUnique({ + where: { accountSequence: orderRecord.accountSequence }, + }); + + if (!walletRecord) { + walletRecord = await tx.walletAccount.findUnique({ + where: { userId: orderRecord.userId }, + }); + } + + if (walletRecord) { + // Unfreeze the total amount (amount + fee) + const totalAmount = new Decimal(orderRecord.amount.toString()).add(new Decimal(orderRecord.fee.toString())); + const currentFrozen = new Decimal(walletRecord.usdtFrozen.toString()); + const currentAvailable = new Decimal(walletRecord.usdtAvailable.toString()); + const currentVersion = walletRecord.version; + + // Validate frozen balance + let newFrozen: Decimal; + let newAvailable: Decimal; + + if (currentFrozen.lessThan(totalAmount)) { + this.logger.warn(`[FAILED] Frozen balance (${currentFrozen}) less than refund amount (${totalAmount}), refunding what's available`); + // Refund whatever is frozen (shouldn't happen in normal flow) + const refundAmount = Decimal.min(currentFrozen, totalAmount); + newFrozen = currentFrozen.minus(refundAmount); + newAvailable = currentAvailable.add(refundAmount); + } else { + newFrozen = currentFrozen.minus(totalAmount); + newAvailable = currentAvailable.add(totalAmount); + } + + // Optimistic lock: update only if version matches + const updateResult = await tx.walletAccount.updateMany({ + where: { + id: walletRecord.id, + version: currentVersion, // Optimistic lock condition + }, + data: { + usdtFrozen: newFrozen, + usdtAvailable: newAvailable, + version: currentVersion + 1, // Increment version + updatedAt: new Date(), + }, + }); + + if (updateResult.count === 0) { + // Version mismatch - another transaction modified the record + throw new OptimisticLockError(`Optimistic lock conflict for wallet ${walletRecord.id}`); + } + + this.logger.log(`[FAILED] Refunded ${totalAmount.toString()} USDT (amount + fee) to account ${orderRecord.accountSequence} (version: ${currentVersion} -> ${currentVersion + 1})`); + } else { + this.logger.error(`[FAILED] Wallet not found for accountSequence: ${orderRecord.accountSequence}, userId: ${orderRecord.userId}`); + } + } + }); this.logger.log(`[FAILED] Order ${payload.orderNo} marked as failed`); } catch (error) { @@ -125,4 +336,18 @@ export class WithdrawalStatusHandler implements OnModuleInit { throw error; } } + + /** + * Check if error is an optimistic lock error + */ + private isOptimisticLockError(error: unknown): boolean { + return error instanceof OptimisticLockError; + } + + /** + * Sleep for specified milliseconds + */ + private sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); + } } diff --git a/backend/services/wallet-service/src/application/services/wallet-application.service.ts b/backend/services/wallet-service/src/application/services/wallet-application.service.ts index 9a6d4b1d..4ecd879d 100644 --- a/backend/services/wallet-service/src/application/services/wallet-application.service.ts +++ b/backend/services/wallet-service/src/application/services/wallet-application.service.ts @@ -19,7 +19,7 @@ import { FreezeForPlantingCommand, ConfirmPlantingDeductionCommand, UnfreezeForPlantingCommand, } from '@/application/commands'; import { GetMyWalletQuery, GetMyLedgerQuery } from '@/application/queries'; -import { DuplicateTransactionError, WalletNotFoundError } from '@/shared/exceptions/domain.exception'; +import { DuplicateTransactionError, WalletNotFoundError, OptimisticLockError } from '@/shared/exceptions/domain.exception'; import { WalletCacheService } from '@/infrastructure/redis'; import { EventPublisherService } from '@/infrastructure/kafka'; import { WithdrawalRequestedEvent } from '@/domain/events'; @@ -93,53 +93,165 @@ export class WalletApplicationService { // =============== Commands =============== + /** + * Handle deposit with transaction protection and optimistic locking + * + * Uses database transaction to ensure atomicity of: + * 1. Deposit order creation + * 2. Wallet balance update (with optimistic lock) + * 3. Ledger entry creation + */ async handleDeposit(command: HandleDepositCommand): Promise { - // Check for duplicate transaction - const exists = await this.depositRepo.existsByTxHash(command.txHash); - if (exists) { - throw new DuplicateTransactionError(command.txHash); + const MAX_RETRIES = 3; + let retries = 0; + + while (retries < MAX_RETRIES) { + try { + await this.executeHandleDeposit(command); + return; // Success, exit + } catch (error) { + if (this.isOptimisticLockError(error)) { + retries++; + this.logger.warn(`[DEPOSIT] Optimistic lock conflict for ${command.txHash}, retry ${retries}/${MAX_RETRIES}`); + if (retries >= MAX_RETRIES) { + this.logger.error(`[DEPOSIT] Max retries exceeded for ${command.txHash}`); + throw error; + } + // Brief delay before retry + await this.sleep(50 * retries); + } else { + throw error; + } + } } + } + /** + * Execute deposit logic within a transaction + */ + private async executeHandleDeposit(command: HandleDepositCommand): Promise { const accountSequence = command.accountSequence; - const userId = BigInt(command.userId); - const amount = Money.USDT(command.amount); + const userIdBigint = BigInt(command.userId); + const amountDecimal = new (await import('decimal.js')).default(command.amount); - // Get or create wallet by accountSequence (跨服务关联标识) - const wallet = await this.walletRepo.getOrCreate(accountSequence, userId); + await this.prisma.$transaction(async (tx) => { + // Check for duplicate transaction within transaction + const existingDeposit = await tx.depositOrder.findUnique({ + where: { txHash: command.txHash }, + }); + if (existingDeposit) { + throw new DuplicateTransactionError(command.txHash); + } - // Create deposit order - const depositOrder = DepositOrder.create({ - accountSequence, - userId: UserId.create(userId), - chainType: command.chainType, - amount, - txHash: command.txHash, + // Get or create wallet + let walletRecord = await tx.walletAccount.findUnique({ + where: { accountSequence }, + }); + + if (!walletRecord) { + walletRecord = await tx.walletAccount.create({ + data: { + accountSequence, + userId: userIdBigint, + usdtAvailable: 0, + usdtFrozen: 0, + dstAvailable: 0, + dstFrozen: 0, + bnbAvailable: 0, + bnbFrozen: 0, + ogAvailable: 0, + ogFrozen: 0, + rwadAvailable: 0, + rwadFrozen: 0, + hashpower: 0, + pendingUsdt: 0, + pendingHashpower: 0, + settleableUsdt: 0, + settleableHashpower: 0, + settledTotalUsdt: 0, + settledTotalHashpower: 0, + expiredTotalUsdt: 0, + expiredTotalHashpower: 0, + status: 'ACTIVE', + version: 0, + }, + }); + } + + // Create deposit order + await tx.depositOrder.create({ + data: { + accountSequence, + userId: userIdBigint, + chainType: command.chainType, + amount: amountDecimal, + txHash: command.txHash, + status: 'CONFIRMED', + confirmedAt: new Date(), + }, + }); + + // Update wallet balance with optimistic lock + const Decimal = (await import('decimal.js')).default; + const currentAvailable = new Decimal(walletRecord.usdtAvailable.toString()); + const newAvailable = currentAvailable.add(amountDecimal); + const currentVersion = walletRecord.version; + + const updateResult = await tx.walletAccount.updateMany({ + where: { + id: walletRecord.id, + version: currentVersion, // Optimistic lock condition + }, + data: { + usdtAvailable: newAvailable, + version: currentVersion + 1, // Increment version + updatedAt: new Date(), + }, + }); + + if (updateResult.count === 0) { + // Version mismatch - another transaction modified the record + throw new OptimisticLockError(`Optimistic lock conflict for wallet ${walletRecord.id}`); + } + + // Record ledger entry + const entryType = command.chainType === ChainType.KAVA + ? LedgerEntryType.DEPOSIT_KAVA + : LedgerEntryType.DEPOSIT_BSC; + + await tx.ledgerEntry.create({ + data: { + accountSequence, + userId: userIdBigint, + entryType, + amount: amountDecimal, + assetType: 'USDT', + balanceAfter: newAvailable, + refTxHash: command.txHash, + memo: `Deposit from ${command.chainType}`, + }, + }); + + this.logger.log(`[DEPOSIT] Credited ${amountDecimal.toString()} USDT to ${accountSequence} (version: ${currentVersion} -> ${currentVersion + 1})`); }); - depositOrder.confirm(); - await this.depositRepo.save(depositOrder); - // Credit wallet - wallet.deposit(amount, command.chainType, command.txHash); - await this.walletRepo.save(wallet); + // Invalidate wallet cache after deposit (outside transaction) + await this.walletCacheService.invalidateWallet(userIdBigint); + } - // Record ledger entry (append-only, 可审计) - const entryType = command.chainType === ChainType.KAVA - ? LedgerEntryType.DEPOSIT_KAVA - : LedgerEntryType.DEPOSIT_BSC; + /** + * Check if error is an optimistic lock error + */ + private isOptimisticLockError(error: unknown): boolean { + return error instanceof OptimisticLockError || + (error instanceof Error && error.message.includes('Optimistic lock conflict')); + } - const ledgerEntry = LedgerEntry.create({ - accountSequence, - userId: UserId.create(userId), - entryType, - amount, - balanceAfter: wallet.balances.usdt.available, - refTxHash: command.txHash, - memo: `Deposit from ${command.chainType}`, - }); - await this.ledgerRepo.save(ledgerEntry); - - // Invalidate wallet cache after deposit - await this.walletCacheService.invalidateWallet(userId); + /** + * Sleep for specified milliseconds + */ + private sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); } async deductForPlanting(command: DeductForPlantingCommand): Promise { diff --git a/backend/services/wallet-service/src/infrastructure/kafka/deposit-event-consumer.service.ts b/backend/services/wallet-service/src/infrastructure/kafka/deposit-event-consumer.service.ts index b623ef7d..a010d5da 100644 --- a/backend/services/wallet-service/src/infrastructure/kafka/deposit-event-consumer.service.ts +++ b/backend/services/wallet-service/src/infrastructure/kafka/deposit-event-consumer.service.ts @@ -143,6 +143,9 @@ export class DepositEventConsumerService implements OnModuleInit, OnModuleDestro } } catch (error) { this.logger.error(`[ERROR] Error processing deposit event from ${topic}`, error); + // Re-throw to trigger Kafka retry mechanism + // This ensures messages are not marked as consumed until successfully processed + throw error; } }, }); diff --git a/backend/services/wallet-service/src/infrastructure/kafka/withdrawal-event-consumer.service.ts b/backend/services/wallet-service/src/infrastructure/kafka/withdrawal-event-consumer.service.ts index 4d4a4c59..367660bf 100644 --- a/backend/services/wallet-service/src/infrastructure/kafka/withdrawal-event-consumer.service.ts +++ b/backend/services/wallet-service/src/infrastructure/kafka/withdrawal-event-consumer.service.ts @@ -171,6 +171,9 @@ export class WithdrawalEventConsumerService implements OnModuleInit, OnModuleDes } } catch (error) { this.logger.error(`[ERROR] Error processing withdrawal event from ${topic}`, error); + // Re-throw to trigger Kafka retry mechanism + // This ensures messages are not marked as consumed until successfully processed + throw error; } }, }); diff --git a/backend/services/wallet-service/src/shared/exceptions/domain.exception.ts b/backend/services/wallet-service/src/shared/exceptions/domain.exception.ts index 5d2e6209..9e815844 100644 --- a/backend/services/wallet-service/src/shared/exceptions/domain.exception.ts +++ b/backend/services/wallet-service/src/shared/exceptions/domain.exception.ts @@ -39,3 +39,10 @@ export class InvalidOperationError extends DomainError { this.name = 'InvalidOperationError'; } } + +export class OptimisticLockError extends DomainError { + constructor(message: string) { + super(message); + this.name = 'OptimisticLockError'; + } +}