package main import ( "context" "database/sql" "encoding/hex" "flag" "fmt" "net/http" "os" "os/signal" "sync" "syscall" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" _ "github.com/lib/pq" router "github.com/rwadurian/mpc-system/api/grpc/router/v1" "github.com/rwadurian/mpc-system/pkg/config" "github.com/rwadurian/mpc-system/pkg/crypto" "github.com/rwadurian/mpc-system/pkg/logger" grpcclient "github.com/rwadurian/mpc-system/services/server-party/adapters/output/grpc" "github.com/rwadurian/mpc-system/services/server-party/adapters/output/postgres" "github.com/rwadurian/mpc-system/services/server-party/application/use_cases" "go.uber.org/zap" ) // PendingSession stores session info between session_created and session_started events type PendingSession struct { SessionID uuid.UUID JoinToken string MessageHash []byte ThresholdN int ThresholdT int SelectedParties []string CreatedAt time.Time } // PendingSessionCache stores pending sessions waiting for session_started type PendingSessionCache struct { mu sync.RWMutex sessions map[string]*PendingSession // sessionID -> PendingSession } // Global pending session cache var pendingSessionCache = &PendingSessionCache{ sessions: make(map[string]*PendingSession), } // Store stores a pending session func (c *PendingSessionCache) Store(sessionID string, session *PendingSession) { c.mu.Lock() defer c.mu.Unlock() c.sessions[sessionID] = session logger.Info("Pending session stored", zap.String("session_id", sessionID)) } // Get retrieves and deletes a pending session func (c *PendingSessionCache) Get(sessionID string) (*PendingSession, bool) { c.mu.Lock() defer c.mu.Unlock() session, exists := c.sessions[sessionID] if exists { delete(c.sessions, sessionID) logger.Info("Pending session retrieved and deleted", zap.String("session_id", sessionID)) } return session, exists } // Delete removes a pending session without returning it func (c *PendingSessionCache) Delete(sessionID string) { c.mu.Lock() defer c.mu.Unlock() delete(c.sessions, sessionID) } func main() { // Parse flags configPath := flag.String("config", "", "Path to config file") flag.Parse() // Load configuration cfg, err := config.Load(*configPath) if err != nil { fmt.Printf("Failed to load config: %v\n", err) os.Exit(1) } // Initialize logger if err := logger.Init(&logger.Config{ Level: cfg.Logger.Level, Encoding: cfg.Logger.Encoding, }); err != nil { fmt.Printf("Failed to initialize logger: %v\n", err) os.Exit(1) } defer logger.Sync() logger.Info("Starting Server Party Co-Managed Service", zap.String("environment", cfg.Server.Environment), zap.Int("http_port", cfg.Server.HTTPPort)) // Initialize database connection db, err := initDatabase(cfg.Database) if err != nil { logger.Fatal("Failed to connect to database", zap.Error(err)) } defer db.Close() // Initialize crypto service with master key from environment masterKeyHex := os.Getenv("MPC_CRYPTO_MASTER_KEY") if masterKeyHex == "" { masterKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" // 64 hex chars = 32 bytes } masterKey, err := hex.DecodeString(masterKeyHex) if err != nil { logger.Fatal("Invalid master key format", zap.Error(err)) } cryptoService, err := crypto.NewCryptoService(masterKey) if err != nil { logger.Fatal("Failed to create crypto service", zap.Error(err)) } // Get Message Router address from environment routerAddr := os.Getenv("MESSAGE_ROUTER_ADDR") if routerAddr == "" { routerAddr = "localhost:9092" } // Initialize Message Router client messageRouter, err := grpcclient.NewMessageRouterClient(routerAddr) if err != nil { logger.Fatal("Failed to connect to message router", zap.Error(err)) } defer messageRouter.Close() // Initialize repositories keyShareRepo := postgres.NewKeySharePostgresRepo(db) // Initialize use cases participateKeygenUC := use_cases.NewParticipateKeygenUseCase( keyShareRepo, messageRouter, messageRouter, cryptoService, ) // Create shutdown context ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Get party ID from environment partyID := os.Getenv("PARTY_ID") if partyID == "" { partyID, _ = os.Hostname() if partyID == "" { partyID = "co-managed-party-" + uuid.New().String()[:8] } } // Party role is co_managed_persistent - different from normal persistent // This ensures co_managed_keygen sessions only select these parties partyRole := "co_managed_persistent" // Register this party with Message Router logger.Info("Registering co-managed party with Message Router", zap.String("party_id", partyID), zap.String("role", partyRole)) if err := messageRouter.RegisterPartyWithNotification(ctx, partyID, partyRole, "1.0.0", nil); err != nil { logger.Fatal("Failed to register party", zap.Error(err)) } // Start heartbeat heartbeatCancel := messageRouter.StartHeartbeat(ctx, partyID, 30*time.Second, func(pendingCount int32) { if pendingCount > 0 { logger.Info("Pending messages detected via heartbeat", zap.String("party_id", partyID), zap.Int32("pending_count", pendingCount)) } }) defer heartbeatCancel() logger.Info("Heartbeat started", zap.String("party_id", partyID), zap.Duration("interval", 30*time.Second)) // Subscribe to session events with two-phase handling for co_managed_keygen logger.Info("Subscribing to session events (co_managed_keygen only)", zap.String("party_id", partyID)) eventHandler := createCoManagedSessionEventHandler( ctx, partyID, messageRouter, participateKeygenUC, ) if err := messageRouter.SubscribeSessionEvents(ctx, partyID, eventHandler); err != nil { logger.Fatal("Failed to subscribe to session events", zap.Error(err)) } logger.Info("Co-managed party initialized successfully", zap.String("party_id", partyID), zap.String("role", partyRole)) // Start HTTP server errChan := make(chan error, 1) go func() { if err := startHTTPServer(cfg); err != nil { errChan <- fmt.Errorf("HTTP server error: %w", err) } }() // Wait for shutdown signal sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) select { case sig := <-sigChan: logger.Info("Received shutdown signal", zap.String("signal", sig.String())) case err := <-errChan: logger.Error("Server error", zap.Error(err)) } // Graceful shutdown logger.Info("Shutting down...") cancel() time.Sleep(5 * time.Second) logger.Info("Shutdown complete") } func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) { const maxRetries = 10 const retryDelay = 2 * time.Second var db *sql.DB var err error for i := 0; i < maxRetries; i++ { db, err = sql.Open("postgres", cfg.DSN()) if err != nil { logger.Warn("Failed to open database connection, retrying...", zap.Int("attempt", i+1), zap.Int("max_retries", maxRetries), zap.Error(err)) time.Sleep(retryDelay * time.Duration(i+1)) continue } db.SetMaxOpenConns(cfg.MaxOpenConns) db.SetMaxIdleConns(cfg.MaxIdleConns) db.SetConnMaxLifetime(cfg.ConnMaxLife) if err = db.Ping(); err != nil { logger.Warn("Failed to ping database, retrying...", zap.Int("attempt", i+1), zap.Int("max_retries", maxRetries), zap.Error(err)) db.Close() time.Sleep(retryDelay * time.Duration(i+1)) continue } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) var result int err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result) cancel() if err != nil { logger.Warn("Database ping succeeded but query failed, retrying...", zap.Int("attempt", i+1), zap.Int("max_retries", maxRetries), zap.Error(err)) db.Close() time.Sleep(retryDelay * time.Duration(i+1)) continue } logger.Info("Connected to PostgreSQL and verified connectivity", zap.Int("attempt", i+1)) return db, nil } return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err) } func startHTTPServer(cfg *config.Config) error { if cfg.Server.Environment == "production" { gin.SetMode(gin.ReleaseMode) } r := gin.New() r.Use(gin.Recovery()) r.Use(gin.Logger()) // Health check r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "status": "healthy", "service": "server-party-co-managed", }) }) logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort)) return r.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort)) } // createCoManagedSessionEventHandler creates a handler specifically for co_managed_keygen sessions // Two-phase event handling: // Phase 1 (session_created): JoinSession immediately + store session info // Phase 2 (session_started): Execute TSS protocol (same timing as user clients receiving all_joined) func createCoManagedSessionEventHandler( ctx context.Context, partyID string, messageRouter *grpcclient.MessageRouterClient, participateKeygenUC *use_cases.ParticipateKeygenUseCase, ) func(*router.SessionEvent) { return func(event *router.SessionEvent) { // Check if this party is selected for the session isSelected := false for _, selectedParty := range event.SelectedParties { if selectedParty == partyID { isSelected = true break } } if !isSelected { logger.Debug("Party not selected for this session", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) return } logger.Info("Received session event", zap.String("session_id", event.SessionId), zap.String("party_id", partyID), zap.String("event_type", event.EventType)) // Parse session ID sessionID, err := uuid.Parse(event.SessionId) if err != nil { logger.Error("Invalid session ID", zap.Error(err)) return } // Handle different event types switch event.EventType { case "session_created": // Only handle keygen sessions (no message_hash) if len(event.MessageHash) > 0 { logger.Debug("Ignoring sign session (co-managed only handles keygen)", zap.String("session_id", event.SessionId)) return } // Phase 1: Get join token joinToken, exists := event.JoinTokens[partyID] if !exists { logger.Error("No join token found for party in session_created", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) return } // Immediately call JoinSession (this is required to trigger session_started) joinCtx, joinCancel := context.WithTimeout(ctx, 30*time.Second) _, err := messageRouter.JoinSession(joinCtx, sessionID, partyID, joinToken) joinCancel() if err != nil { logger.Error("Failed to join session", zap.String("session_id", event.SessionId), zap.String("party_id", partyID), zap.Error(err)) return } logger.Info("Successfully joined session, waiting for session_started", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) // Store pending session for later use when session_started arrives pendingSessionCache.Store(event.SessionId, &PendingSession{ SessionID: sessionID, JoinToken: joinToken, MessageHash: event.MessageHash, ThresholdN: int(event.ThresholdN), ThresholdT: int(event.ThresholdT), SelectedParties: event.SelectedParties, CreatedAt: time.Now(), }) case "session_started": // Phase 2: All participants have joined, now execute TSS protocol pendingSession, exists := pendingSessionCache.Get(event.SessionId) if !exists { logger.Warn("No pending session found for session_started event", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) return } logger.Info("Session started event received, beginning TSS keygen protocol", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) // Execute TSS keygen protocol in goroutine // Timeout starts NOW (when session_started is received), not at session_created go func() { // 10 minute timeout for TSS protocol execution participateCtx, cancel := context.WithTimeout(ctx, 10*time.Minute) defer cancel() logger.Info("Auto-participating in co_managed_keygen session", zap.String("session_id", event.SessionId), zap.String("party_id", partyID)) // Build SessionInfo from session_started event (NOT from pendingSession cache) // session_started event contains ALL participants who have joined, // including external parties that joined dynamically after session_created // Note: We already called JoinSession in session_created phase, // so we use ExecuteWithSessionInfo to skip the duplicate JoinSession call participants := make([]use_cases.ParticipantInfo, len(event.SelectedParties)) for i, p := range event.SelectedParties { participants[i] = use_cases.ParticipantInfo{ PartyID: p, PartyIndex: i, } } sessionInfo := &use_cases.SessionInfo{ SessionID: pendingSession.SessionID, SessionType: "co_managed_keygen", ThresholdN: int(event.ThresholdN), ThresholdT: int(event.ThresholdT), MessageHash: pendingSession.MessageHash, Participants: participants, } result, err := participateKeygenUC.ExecuteWithSessionInfo( participateCtx, pendingSession.SessionID, partyID, sessionInfo, ) if err != nil { logger.Error("Co-managed keygen participation failed", zap.Error(err), zap.String("session_id", event.SessionId)) } else { logger.Info("Co-managed keygen participation completed", zap.String("session_id", event.SessionId), zap.String("public_key", hex.EncodeToString(result.PublicKey))) } }() default: logger.Debug("Ignoring unhandled event type", zap.String("session_id", event.SessionId), zap.String("event_type", event.EventType)) } } }