diff --git a/backend/mpc-system/services/account/adapters/input/http/account_handler.go b/backend/mpc-system/services/account/adapters/input/http/account_handler.go index c1245bdc..5f6118db 100644 --- a/backend/mpc-system/services/account/adapters/input/http/account_handler.go +++ b/backend/mpc-system/services/account/adapters/input/http/account_handler.go @@ -8,10 +8,12 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/rwadurian/mpc-system/pkg/logger" "github.com/rwadurian/mpc-system/services/account/adapters/output/grpc" "github.com/rwadurian/mpc-system/services/account/application/ports" "github.com/rwadurian/mpc-system/services/account/application/use_cases" "github.com/rwadurian/mpc-system/services/account/domain/value_objects" + "go.uber.org/zap" ) // AccountHTTPHandler handles HTTP requests for accounts @@ -579,6 +581,11 @@ func (h *AccountHTTPHandler) CreateKeygenSession(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + logger.Info("Calling CreateKeygenSession via gRPC", + zap.Int("threshold_n", req.ThresholdN), + zap.Int("threshold_t", req.ThresholdT), + zap.Int("num_participants", len(participants))) + resp, err := h.sessionCoordinatorClient.CreateKeygenSession( ctx, int32(req.ThresholdN), @@ -588,10 +595,15 @@ func (h *AccountHTTPHandler) CreateKeygenSession(c *gin.Context) { ) if err != nil { + logger.Error("gRPC CreateKeygenSession failed", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + logger.Info("gRPC CreateKeygenSession succeeded", + zap.String("session_id", resp.SessionID), + zap.Int("num_join_tokens", len(resp.JoinTokens))) + c.JSON(http.StatusCreated, gin.H{ "session_id": resp.SessionID, "session_type": "keygen", diff --git a/backend/mpc-system/services/account/adapters/output/grpc/session_coordinator_client.go b/backend/mpc-system/services/account/adapters/output/grpc/session_coordinator_client.go index ce186375..e3d3292a 100644 --- a/backend/mpc-system/services/account/adapters/output/grpc/session_coordinator_client.go +++ b/backend/mpc-system/services/account/adapters/output/grpc/session_coordinator_client.go @@ -1,202 +1,212 @@ -package grpc - -import ( - "context" - "fmt" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - - coordinatorpb "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1" - "github.com/rwadurian/mpc-system/pkg/logger" - "go.uber.org/zap" -) - -// SessionCoordinatorClient wraps the gRPC client for session coordinator -type SessionCoordinatorClient struct { - client coordinatorpb.SessionCoordinatorClient - conn *grpc.ClientConn -} - -// NewSessionCoordinatorClient creates a new session coordinator gRPC client -func NewSessionCoordinatorClient(address string) (*SessionCoordinatorClient, error) { - var conn *grpc.ClientConn - var err error - - maxRetries := 5 - for i := 0; i < maxRetries; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - conn, err = grpc.DialContext( - ctx, - address, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - ) - cancel() - - if err == nil { - break - } - - if i < maxRetries-1 { - logger.Warn("Failed to connect to session coordinator, retrying...", - zap.Int("attempt", i+1), - zap.Int("max_retries", maxRetries), - zap.Error(err)) - time.Sleep(time.Duration(i+1) * 2 * time.Second) - } - } - - if err != nil { - return nil, fmt.Errorf("failed to connect to session coordinator after %d retries: %w", maxRetries, err) - } - - logger.Info("Connected to session coordinator", zap.String("address", address)) - - client := coordinatorpb.NewSessionCoordinatorClient(conn) - - return &SessionCoordinatorClient{ - client: client, - conn: conn, - }, nil -} - -// CreateKeygenSession creates a new keygen session -func (c *SessionCoordinatorClient) CreateKeygenSession( - ctx context.Context, - thresholdN int32, - thresholdT int32, - participants []ParticipantInfo, - expiresInSeconds int64, -) (*CreateSessionResponse, error) { - pbParticipants := make([]*coordinatorpb.ParticipantInfo, len(participants)) - for i, p := range participants { - pbParticipants[i] = &coordinatorpb.ParticipantInfo{ - PartyId: p.PartyID, - DeviceInfo: &coordinatorpb.DeviceInfo{ - DeviceType: p.DeviceType, - DeviceId: p.DeviceID, - Platform: p.Platform, - AppVersion: p.AppVersion, - }, - } - } - - req := &coordinatorpb.CreateSessionRequest{ - SessionType: "keygen", - ThresholdN: thresholdN, - ThresholdT: thresholdT, - Participants: pbParticipants, - ExpiresInSeconds: expiresInSeconds, - } - - resp, err := c.client.CreateSession(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to create keygen session: %w", err) - } - - return &CreateSessionResponse{ - SessionID: resp.SessionId, - JoinTokens: resp.JoinTokens, - ExpiresAt: resp.ExpiresAt, - }, nil -} - -// CreateSigningSession creates a new signing session -func (c *SessionCoordinatorClient) CreateSigningSession( - ctx context.Context, - thresholdT int32, - participants []ParticipantInfo, - messageHash []byte, - expiresInSeconds int64, -) (*CreateSessionResponse, error) { - pbParticipants := make([]*coordinatorpb.ParticipantInfo, len(participants)) - for i, p := range participants { - pbParticipants[i] = &coordinatorpb.ParticipantInfo{ - PartyId: p.PartyID, - DeviceInfo: &coordinatorpb.DeviceInfo{ - DeviceType: p.DeviceType, - DeviceId: p.DeviceID, - Platform: p.Platform, - AppVersion: p.AppVersion, - }, - } - } - - req := &coordinatorpb.CreateSessionRequest{ - SessionType: "sign", - ThresholdN: int32(len(participants)), - ThresholdT: thresholdT, - Participants: pbParticipants, - MessageHash: messageHash, - ExpiresInSeconds: expiresInSeconds, - } - - resp, err := c.client.CreateSession(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to create signing session: %w", err) - } - - return &CreateSessionResponse{ - SessionID: resp.SessionId, - JoinTokens: resp.JoinTokens, - ExpiresAt: resp.ExpiresAt, - }, nil -} - -// GetSessionStatus gets the status of a session -func (c *SessionCoordinatorClient) GetSessionStatus( - ctx context.Context, - sessionID string, -) (*SessionStatusResponse, error) { - req := &coordinatorpb.GetSessionStatusRequest{ - SessionId: sessionID, - } - - resp, err := c.client.GetSessionStatus(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to get session status: %w", err) - } - - return &SessionStatusResponse{ - Status: resp.Status, - CompletedParties: resp.CompletedParties, - TotalParties: resp.TotalParties, - PublicKey: resp.PublicKey, - Signature: resp.Signature, - }, nil -} - -// Close closes the gRPC connection -func (c *SessionCoordinatorClient) Close() error { - if c.conn != nil { - return c.conn.Close() - } - return nil -} - -// ParticipantInfo contains participant information -type ParticipantInfo struct { - PartyID string - DeviceType string - DeviceID string - Platform string - AppVersion string -} - -// CreateSessionResponse contains the created session information -type CreateSessionResponse struct { - SessionID string - JoinTokens map[string]string - ExpiresAt int64 -} - -// SessionStatusResponse contains session status information -type SessionStatusResponse struct { - Status string - CompletedParties int32 - TotalParties int32 - PublicKey []byte - Signature []byte -} +package grpc + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + coordinatorpb "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1" + "github.com/rwadurian/mpc-system/pkg/logger" + "go.uber.org/zap" +) + +// SessionCoordinatorClient wraps the gRPC client for session coordinator +type SessionCoordinatorClient struct { + client coordinatorpb.SessionCoordinatorClient + conn *grpc.ClientConn +} + +// NewSessionCoordinatorClient creates a new session coordinator gRPC client +func NewSessionCoordinatorClient(address string) (*SessionCoordinatorClient, error) { + var conn *grpc.ClientConn + var err error + + maxRetries := 5 + for i := 0; i < maxRetries; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + conn, err = grpc.DialContext( + ctx, + address, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + cancel() + + if err == nil { + break + } + + if i < maxRetries-1 { + logger.Warn("Failed to connect to session coordinator, retrying...", + zap.Int("attempt", i+1), + zap.Int("max_retries", maxRetries), + zap.Error(err)) + time.Sleep(time.Duration(i+1) * 2 * time.Second) + } + } + + if err != nil { + return nil, fmt.Errorf("failed to connect to session coordinator after %d retries: %w", maxRetries, err) + } + + logger.Info("Connected to session coordinator", zap.String("address", address)) + + client := coordinatorpb.NewSessionCoordinatorClient(conn) + + return &SessionCoordinatorClient{ + client: client, + conn: conn, + }, nil +} + +// CreateKeygenSession creates a new keygen session +func (c *SessionCoordinatorClient) CreateKeygenSession( + ctx context.Context, + thresholdN int32, + thresholdT int32, + participants []ParticipantInfo, + expiresInSeconds int64, +) (*CreateSessionResponse, error) { + pbParticipants := make([]*coordinatorpb.ParticipantInfo, len(participants)) + for i, p := range participants { + pbParticipants[i] = &coordinatorpb.ParticipantInfo{ + PartyId: p.PartyID, + DeviceInfo: &coordinatorpb.DeviceInfo{ + DeviceType: p.DeviceType, + DeviceId: p.DeviceID, + Platform: p.Platform, + AppVersion: p.AppVersion, + }, + } + } + + req := &coordinatorpb.CreateSessionRequest{ + SessionType: "keygen", + ThresholdN: thresholdN, + ThresholdT: thresholdT, + Participants: pbParticipants, + ExpiresInSeconds: expiresInSeconds, + } + + logger.Info("Sending CreateSession gRPC request", + zap.String("session_type", "keygen"), + zap.Int32("threshold_n", thresholdN), + zap.Int32("threshold_t", thresholdT)) + + resp, err := c.client.CreateSession(ctx, req) + if err != nil { + logger.Error("CreateSession gRPC call failed", zap.Error(err)) + return nil, fmt.Errorf("failed to create keygen session: %w", err) + } + + logger.Info("CreateSession gRPC call succeeded", + zap.String("session_id", resp.SessionId), + zap.Int("num_join_tokens", len(resp.JoinTokens))) + + return &CreateSessionResponse{ + SessionID: resp.SessionId, + JoinTokens: resp.JoinTokens, + ExpiresAt: resp.ExpiresAt, + }, nil +} + +// CreateSigningSession creates a new signing session +func (c *SessionCoordinatorClient) CreateSigningSession( + ctx context.Context, + thresholdT int32, + participants []ParticipantInfo, + messageHash []byte, + expiresInSeconds int64, +) (*CreateSessionResponse, error) { + pbParticipants := make([]*coordinatorpb.ParticipantInfo, len(participants)) + for i, p := range participants { + pbParticipants[i] = &coordinatorpb.ParticipantInfo{ + PartyId: p.PartyID, + DeviceInfo: &coordinatorpb.DeviceInfo{ + DeviceType: p.DeviceType, + DeviceId: p.DeviceID, + Platform: p.Platform, + AppVersion: p.AppVersion, + }, + } + } + + req := &coordinatorpb.CreateSessionRequest{ + SessionType: "sign", + ThresholdN: int32(len(participants)), + ThresholdT: thresholdT, + Participants: pbParticipants, + MessageHash: messageHash, + ExpiresInSeconds: expiresInSeconds, + } + + resp, err := c.client.CreateSession(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to create signing session: %w", err) + } + + return &CreateSessionResponse{ + SessionID: resp.SessionId, + JoinTokens: resp.JoinTokens, + ExpiresAt: resp.ExpiresAt, + }, nil +} + +// GetSessionStatus gets the status of a session +func (c *SessionCoordinatorClient) GetSessionStatus( + ctx context.Context, + sessionID string, +) (*SessionStatusResponse, error) { + req := &coordinatorpb.GetSessionStatusRequest{ + SessionId: sessionID, + } + + resp, err := c.client.GetSessionStatus(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to get session status: %w", err) + } + + return &SessionStatusResponse{ + Status: resp.Status, + CompletedParties: resp.CompletedParties, + TotalParties: resp.TotalParties, + PublicKey: resp.PublicKey, + Signature: resp.Signature, + }, nil +} + +// Close closes the gRPC connection +func (c *SessionCoordinatorClient) Close() error { + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +// ParticipantInfo contains participant information +type ParticipantInfo struct { + PartyID string + DeviceType string + DeviceID string + Platform string + AppVersion string +} + +// CreateSessionResponse contains the created session information +type CreateSessionResponse struct { + SessionID string + JoinTokens map[string]string + ExpiresAt int64 +} + +// SessionStatusResponse contains session status information +type SessionStatusResponse struct { + Status string + CompletedParties int32 + TotalParties int32 + PublicKey []byte + Signature []byte +} diff --git a/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go b/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go index 6876569e..9f07d7bf 100644 --- a/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go +++ b/backend/mpc-system/services/session-coordinator/adapters/input/grpc/session_grpc_handler.go @@ -1,324 +1,337 @@ -package grpc - -import ( - "context" - "time" - - "github.com/google/uuid" - pb "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1" - "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" - "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" - "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// SessionCoordinatorServer implements the gRPC SessionCoordinator service -type SessionCoordinatorServer struct { - pb.UnimplementedSessionCoordinatorServer - createSessionUC *use_cases.CreateSessionUseCase - joinSessionUC *use_cases.JoinSessionUseCase - getSessionStatusUC *use_cases.GetSessionStatusUseCase - reportCompletionUC *use_cases.ReportCompletionUseCase - closeSessionUC *use_cases.CloseSessionUseCase - sessionRepo repositories.SessionRepository -} - -// NewSessionCoordinatorServer creates a new gRPC server -func NewSessionCoordinatorServer( - createSessionUC *use_cases.CreateSessionUseCase, - joinSessionUC *use_cases.JoinSessionUseCase, - getSessionStatusUC *use_cases.GetSessionStatusUseCase, - reportCompletionUC *use_cases.ReportCompletionUseCase, - closeSessionUC *use_cases.CloseSessionUseCase, - sessionRepo repositories.SessionRepository, -) *SessionCoordinatorServer { - return &SessionCoordinatorServer{ - createSessionUC: createSessionUC, - joinSessionUC: joinSessionUC, - getSessionStatusUC: getSessionStatusUC, - reportCompletionUC: reportCompletionUC, - closeSessionUC: closeSessionUC, - sessionRepo: sessionRepo, - } -} - -// CreateSession creates a new MPC session -func (s *SessionCoordinatorServer) CreateSession( - ctx context.Context, - req *pb.CreateSessionRequest, -) (*pb.CreateSessionResponse, error) { - // Convert request to input - participants := make([]input.ParticipantInfo, len(req.Participants)) - for i, p := range req.Participants { - var deviceInfo entities.DeviceInfo - if p.DeviceInfo != nil { - deviceInfo = entities.DeviceInfo{ - DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), - DeviceID: p.DeviceInfo.DeviceId, - Platform: p.DeviceInfo.Platform, - AppVersion: p.DeviceInfo.AppVersion, - } - } - participants[i] = input.ParticipantInfo{ - PartyID: p.PartyId, - DeviceInfo: deviceInfo, - } - } - - inputData := input.CreateSessionInput{ - InitiatorID: "", - SessionType: req.SessionType, - ThresholdN: int(req.ThresholdN), - ThresholdT: int(req.ThresholdT), - Participants: participants, - MessageHash: req.MessageHash, - ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second, - } - - // Execute use case - output, err := s.createSessionUC.Execute(ctx, inputData) - if err != nil { - return nil, toGRPCError(err) - } - - // Convert output to response - return &pb.CreateSessionResponse{ - SessionId: output.SessionID.String(), - JoinTokens: output.JoinTokens, - ExpiresAt: output.ExpiresAt.UnixMilli(), - }, nil -} - -// JoinSession allows a participant to join a session -func (s *SessionCoordinatorServer) JoinSession( - ctx context.Context, - req *pb.JoinSessionRequest, -) (*pb.JoinSessionResponse, error) { - sessionID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - - var deviceInfo entities.DeviceInfo - if req.DeviceInfo != nil { - deviceInfo = entities.DeviceInfo{ - DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), - DeviceID: req.DeviceInfo.DeviceId, - Platform: req.DeviceInfo.Platform, - AppVersion: req.DeviceInfo.AppVersion, - } - } - - inputData := input.JoinSessionInput{ - SessionID: sessionID, - PartyID: req.PartyId, - JoinToken: req.JoinToken, - DeviceInfo: deviceInfo, - } - - output, err := s.joinSessionUC.Execute(ctx, inputData) - if err != nil { - return nil, toGRPCError(err) - } - - // Convert other parties to response format - otherParties := make([]*pb.PartyInfo, len(output.OtherParties)) - for i, p := range output.OtherParties { - otherParties[i] = &pb.PartyInfo{ - PartyId: p.PartyID, - PartyIndex: int32(p.PartyIndex), - DeviceInfo: &pb.DeviceInfo{ - DeviceType: string(p.DeviceInfo.DeviceType), - DeviceId: p.DeviceInfo.DeviceID, - Platform: p.DeviceInfo.Platform, - AppVersion: p.DeviceInfo.AppVersion, - }, - } - } - - return &pb.JoinSessionResponse{ - Success: output.Success, - SessionInfo: &pb.SessionInfo{ - SessionId: output.SessionInfo.SessionID.String(), - SessionType: output.SessionInfo.SessionType, - ThresholdN: int32(output.SessionInfo.ThresholdN), - ThresholdT: int32(output.SessionInfo.ThresholdT), - MessageHash: output.SessionInfo.MessageHash, - Status: output.SessionInfo.Status, - }, - OtherParties: otherParties, - }, nil -} - -// GetSessionStatus retrieves the status of a session -func (s *SessionCoordinatorServer) GetSessionStatus( - ctx context.Context, - req *pb.GetSessionStatusRequest, -) (*pb.GetSessionStatusResponse, error) { - sessionID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - - output, err := s.getSessionStatusUC.Execute(ctx, sessionID) - if err != nil { - return nil, toGRPCError(err) - } - - // Calculate completed parties from participants - completedParties := 0 - for _, p := range output.Participants { - if p.Status == "completed" { - completedParties++ - } - } - - return &pb.GetSessionStatusResponse{ - Status: output.Status, - CompletedParties: int32(completedParties), - TotalParties: int32(len(output.Participants)), - PublicKey: output.PublicKey, - Signature: output.Signature, - }, nil -} - -// ReportCompletion reports that a participant has completed -func (s *SessionCoordinatorServer) ReportCompletion( - ctx context.Context, - req *pb.ReportCompletionRequest, -) (*pb.ReportCompletionResponse, error) { - sessionID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - - inputData := input.ReportCompletionInput{ - SessionID: sessionID, - PartyID: req.PartyId, - PublicKey: req.PublicKey, - Signature: req.Signature, - } - - output, err := s.reportCompletionUC.Execute(ctx, inputData) - if err != nil { - return nil, toGRPCError(err) - } - - return &pb.ReportCompletionResponse{ - Success: output.Success, - AllCompleted: output.AllCompleted, - }, nil -} - -// CloseSession closes a session -func (s *SessionCoordinatorServer) CloseSession( - ctx context.Context, - req *pb.CloseSessionRequest, -) (*pb.CloseSessionResponse, error) { - sessionID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - - err = s.closeSessionUC.Execute(ctx, sessionID) - if err != nil { - return nil, toGRPCError(err) - } - - return &pb.CloseSessionResponse{ - Success: true, - }, nil -} - -// MarkPartyReady marks a party as ready -func (s *SessionCoordinatorServer) MarkPartyReady( - ctx context.Context, - req *pb.MarkPartyReadyRequest, -) (*pb.MarkPartyReadyResponse, error) { - parsedID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - sessionID := value_objects.SessionIDFromUUID(parsedID) - - session, err := s.sessionRepo.FindByID(ctx, sessionID) - if err != nil { - return nil, toGRPCError(err) - } - if session == nil { - return nil, status.Error(codes.NotFound, "session not found") - } - - // Mark party as ready - if err := session.MarkPartyReady(req.PartyId); err != nil { - return nil, toGRPCError(err) - } - - // Save session - if err := s.sessionRepo.Update(ctx, session); err != nil { - return nil, toGRPCError(err) - } - - // Check if all parties are ready - allReady := session.AllPartiesReady() - - return &pb.MarkPartyReadyResponse{ - Success: true, - AllReady: allReady, - }, nil -} - -// StartSession starts a session -func (s *SessionCoordinatorServer) StartSession( - ctx context.Context, - req *pb.StartSessionRequest, -) (*pb.StartSessionResponse, error) { - parsedID, err := uuid.Parse(req.SessionId) - if err != nil { - return nil, status.Error(codes.InvalidArgument, "invalid session ID") - } - sessionID := value_objects.SessionIDFromUUID(parsedID) - - session, err := s.sessionRepo.FindByID(ctx, sessionID) - if err != nil { - return nil, toGRPCError(err) - } - if session == nil { - return nil, status.Error(codes.NotFound, "session not found") - } - - // Start the session - if err := session.Start(); err != nil { - return nil, toGRPCError(err) - } - - // Save session - if err := s.sessionRepo.Update(ctx, session); err != nil { - return nil, toGRPCError(err) - } - - return &pb.StartSessionResponse{ - Success: true, - }, nil -} - -// toGRPCError converts domain errors to gRPC errors -func toGRPCError(err error) error { - switch err { - case entities.ErrSessionExpired: - return status.Error(codes.DeadlineExceeded, err.Error()) - case entities.ErrSessionFull: - return status.Error(codes.ResourceExhausted, err.Error()) - case entities.ErrParticipantNotFound: - return status.Error(codes.NotFound, err.Error()) - case entities.ErrSessionNotInProgress: - return status.Error(codes.FailedPrecondition, err.Error()) - case entities.ErrInvalidSessionType: - return status.Error(codes.InvalidArgument, err.Error()) - default: - return status.Error(codes.Internal, err.Error()) - } -} +package grpc + +import ( + "context" + "time" + + "github.com/google/uuid" + pb "github.com/rwadurian/mpc-system/api/grpc/coordinator/v1" + "github.com/rwadurian/mpc-system/pkg/logger" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/ports/input" + "github.com/rwadurian/mpc-system/services/session-coordinator/application/use_cases" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/entities" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/repositories" + "github.com/rwadurian/mpc-system/services/session-coordinator/domain/value_objects" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// SessionCoordinatorServer implements the gRPC SessionCoordinator service +type SessionCoordinatorServer struct { + pb.UnimplementedSessionCoordinatorServer + createSessionUC *use_cases.CreateSessionUseCase + joinSessionUC *use_cases.JoinSessionUseCase + getSessionStatusUC *use_cases.GetSessionStatusUseCase + reportCompletionUC *use_cases.ReportCompletionUseCase + closeSessionUC *use_cases.CloseSessionUseCase + sessionRepo repositories.SessionRepository +} + +// NewSessionCoordinatorServer creates a new gRPC server +func NewSessionCoordinatorServer( + createSessionUC *use_cases.CreateSessionUseCase, + joinSessionUC *use_cases.JoinSessionUseCase, + getSessionStatusUC *use_cases.GetSessionStatusUseCase, + reportCompletionUC *use_cases.ReportCompletionUseCase, + closeSessionUC *use_cases.CloseSessionUseCase, + sessionRepo repositories.SessionRepository, +) *SessionCoordinatorServer { + return &SessionCoordinatorServer{ + createSessionUC: createSessionUC, + joinSessionUC: joinSessionUC, + getSessionStatusUC: getSessionStatusUC, + reportCompletionUC: reportCompletionUC, + closeSessionUC: closeSessionUC, + sessionRepo: sessionRepo, + } +} + +// CreateSession creates a new MPC session +func (s *SessionCoordinatorServer) CreateSession( + ctx context.Context, + req *pb.CreateSessionRequest, +) (*pb.CreateSessionResponse, error) { + logger.Info("gRPC CreateSession received", + zap.String("session_type", req.SessionType), + zap.Int32("threshold_n", req.ThresholdN), + zap.Int32("threshold_t", req.ThresholdT), + zap.Int("num_participants", len(req.Participants))) + + // Convert request to input + participants := make([]input.ParticipantInfo, len(req.Participants)) + for i, p := range req.Participants { + var deviceInfo entities.DeviceInfo + if p.DeviceInfo != nil { + deviceInfo = entities.DeviceInfo{ + DeviceType: entities.DeviceType(p.DeviceInfo.DeviceType), + DeviceID: p.DeviceInfo.DeviceId, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + } + } + participants[i] = input.ParticipantInfo{ + PartyID: p.PartyId, + DeviceInfo: deviceInfo, + } + } + + inputData := input.CreateSessionInput{ + InitiatorID: "", + SessionType: req.SessionType, + ThresholdN: int(req.ThresholdN), + ThresholdT: int(req.ThresholdT), + Participants: participants, + MessageHash: req.MessageHash, + ExpiresIn: time.Duration(req.ExpiresInSeconds) * time.Second, + } + + // Execute use case + output, err := s.createSessionUC.Execute(ctx, inputData) + if err != nil { + logger.Error("gRPC CreateSession use case failed", zap.Error(err)) + return nil, toGRPCError(err) + } + + logger.Info("gRPC CreateSession completed successfully", + zap.String("session_id", output.SessionID.String()), + zap.Int("num_join_tokens", len(output.JoinTokens))) + + // Convert output to response + return &pb.CreateSessionResponse{ + SessionId: output.SessionID.String(), + JoinTokens: output.JoinTokens, + ExpiresAt: output.ExpiresAt.UnixMilli(), + }, nil +} + +// JoinSession allows a participant to join a session +func (s *SessionCoordinatorServer) JoinSession( + ctx context.Context, + req *pb.JoinSessionRequest, +) (*pb.JoinSessionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + var deviceInfo entities.DeviceInfo + if req.DeviceInfo != nil { + deviceInfo = entities.DeviceInfo{ + DeviceType: entities.DeviceType(req.DeviceInfo.DeviceType), + DeviceID: req.DeviceInfo.DeviceId, + Platform: req.DeviceInfo.Platform, + AppVersion: req.DeviceInfo.AppVersion, + } + } + + inputData := input.JoinSessionInput{ + SessionID: sessionID, + PartyID: req.PartyId, + JoinToken: req.JoinToken, + DeviceInfo: deviceInfo, + } + + output, err := s.joinSessionUC.Execute(ctx, inputData) + if err != nil { + return nil, toGRPCError(err) + } + + // Convert other parties to response format + otherParties := make([]*pb.PartyInfo, len(output.OtherParties)) + for i, p := range output.OtherParties { + otherParties[i] = &pb.PartyInfo{ + PartyId: p.PartyID, + PartyIndex: int32(p.PartyIndex), + DeviceInfo: &pb.DeviceInfo{ + DeviceType: string(p.DeviceInfo.DeviceType), + DeviceId: p.DeviceInfo.DeviceID, + Platform: p.DeviceInfo.Platform, + AppVersion: p.DeviceInfo.AppVersion, + }, + } + } + + return &pb.JoinSessionResponse{ + Success: output.Success, + SessionInfo: &pb.SessionInfo{ + SessionId: output.SessionInfo.SessionID.String(), + SessionType: output.SessionInfo.SessionType, + ThresholdN: int32(output.SessionInfo.ThresholdN), + ThresholdT: int32(output.SessionInfo.ThresholdT), + MessageHash: output.SessionInfo.MessageHash, + Status: output.SessionInfo.Status, + }, + OtherParties: otherParties, + }, nil +} + +// GetSessionStatus retrieves the status of a session +func (s *SessionCoordinatorServer) GetSessionStatus( + ctx context.Context, + req *pb.GetSessionStatusRequest, +) (*pb.GetSessionStatusResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + output, err := s.getSessionStatusUC.Execute(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + + // Calculate completed parties from participants + completedParties := 0 + for _, p := range output.Participants { + if p.Status == "completed" { + completedParties++ + } + } + + return &pb.GetSessionStatusResponse{ + Status: output.Status, + CompletedParties: int32(completedParties), + TotalParties: int32(len(output.Participants)), + PublicKey: output.PublicKey, + Signature: output.Signature, + }, nil +} + +// ReportCompletion reports that a participant has completed +func (s *SessionCoordinatorServer) ReportCompletion( + ctx context.Context, + req *pb.ReportCompletionRequest, +) (*pb.ReportCompletionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + inputData := input.ReportCompletionInput{ + SessionID: sessionID, + PartyID: req.PartyId, + PublicKey: req.PublicKey, + Signature: req.Signature, + } + + output, err := s.reportCompletionUC.Execute(ctx, inputData) + if err != nil { + return nil, toGRPCError(err) + } + + return &pb.ReportCompletionResponse{ + Success: output.Success, + AllCompleted: output.AllCompleted, + }, nil +} + +// CloseSession closes a session +func (s *SessionCoordinatorServer) CloseSession( + ctx context.Context, + req *pb.CloseSessionRequest, +) (*pb.CloseSessionResponse, error) { + sessionID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + + err = s.closeSessionUC.Execute(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + + return &pb.CloseSessionResponse{ + Success: true, + }, nil +} + +// MarkPartyReady marks a party as ready +func (s *SessionCoordinatorServer) MarkPartyReady( + ctx context.Context, + req *pb.MarkPartyReadyRequest, +) (*pb.MarkPartyReadyResponse, error) { + parsedID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + sessionID := value_objects.SessionIDFromUUID(parsedID) + + session, err := s.sessionRepo.FindByID(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + if session == nil { + return nil, status.Error(codes.NotFound, "session not found") + } + + // Mark party as ready + if err := session.MarkPartyReady(req.PartyId); err != nil { + return nil, toGRPCError(err) + } + + // Save session + if err := s.sessionRepo.Update(ctx, session); err != nil { + return nil, toGRPCError(err) + } + + // Check if all parties are ready + allReady := session.AllPartiesReady() + + return &pb.MarkPartyReadyResponse{ + Success: true, + AllReady: allReady, + }, nil +} + +// StartSession starts a session +func (s *SessionCoordinatorServer) StartSession( + ctx context.Context, + req *pb.StartSessionRequest, +) (*pb.StartSessionResponse, error) { + parsedID, err := uuid.Parse(req.SessionId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid session ID") + } + sessionID := value_objects.SessionIDFromUUID(parsedID) + + session, err := s.sessionRepo.FindByID(ctx, sessionID) + if err != nil { + return nil, toGRPCError(err) + } + if session == nil { + return nil, status.Error(codes.NotFound, "session not found") + } + + // Start the session + if err := session.Start(); err != nil { + return nil, toGRPCError(err) + } + + // Save session + if err := s.sessionRepo.Update(ctx, session); err != nil { + return nil, toGRPCError(err) + } + + return &pb.StartSessionResponse{ + Success: true, + }, nil +} + +// toGRPCError converts domain errors to gRPC errors +func toGRPCError(err error) error { + switch err { + case entities.ErrSessionExpired: + return status.Error(codes.DeadlineExceeded, err.Error()) + case entities.ErrSessionFull: + return status.Error(codes.ResourceExhausted, err.Error()) + case entities.ErrParticipantNotFound: + return status.Error(codes.NotFound, err.Error()) + case entities.ErrSessionNotInProgress: + return status.Error(codes.FailedPrecondition, err.Error()) + case entities.ErrInvalidSessionType: + return status.Error(codes.InvalidArgument, err.Error()) + default: + return status.Error(codes.Internal, err.Error()) + } +}