rwadurian/backend/mpc-system/services/server-party-co-managed/cmd/server/main.go

465 lines
14 KiB
Go

package main
import (
"context"
"database/sql"
"encoding/hex"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
_ "github.com/lib/pq"
router "github.com/rwadurian/mpc-system/api/grpc/router/v1"
"github.com/rwadurian/mpc-system/pkg/config"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
grpcclient "github.com/rwadurian/mpc-system/services/server-party/adapters/output/grpc"
"github.com/rwadurian/mpc-system/services/server-party/adapters/output/postgres"
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
"go.uber.org/zap"
)
// PendingSession captures what the session_created event told this party, so
// the information is still available when the matching session_started event
// arrives.
type PendingSession struct {
	// SessionID is the parsed session identifier.
	SessionID uuid.UUID
	// JoinToken is the token this party was issued for the session.
	JoinToken string
	// MessageHash is copied verbatim from the session_created event.
	MessageHash []byte
	// ThresholdN and ThresholdT are the thresholds announced at creation time.
	ThresholdN, ThresholdT int
	// SelectedParties are the party IDs listed in the session_created event.
	SelectedParties []string
	// CreatedAt is the local time at which this entry was built.
	CreatedAt time.Time
}
// PendingSessionCache is a mutex-guarded map of sessions this party has seen
// in session_created but not yet in session_started, keyed by session ID.
type PendingSessionCache struct {
	mu       sync.RWMutex
	sessions map[string]*PendingSession // keyed by the session ID string
}

// pendingSessionCache is the process-wide cache shared by the session event
// handler's two phases.
var pendingSessionCache = &PendingSessionCache{sessions: map[string]*PendingSession{}}
// Store records session under sessionID, replacing any existing entry for the
// same ID.
//
// Entries are normally removed by Get when session_started arrives. A session
// that never starts (aborted, expired, or whose start event is lost) would
// otherwise stay in the map for the lifetime of the process, so Store also
// evicts entries whose CreatedAt is older than maxPendingAge. Entries with a
// zero CreatedAt are left alone because their age is unknown.
func (c *PendingSessionCache) Store(sessionID string, session *PendingSession) {
	// NOTE(review): deliberately generous; the real gap between
	// session_created and session_started is not visible here — tune if the
	// router enforces a shorter session expiry.
	const maxPendingAge = time.Hour

	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()
	for id, pending := range c.sessions {
		// Deleting from a map while ranging over it is safe in Go.
		if pending == nil || (!pending.CreatedAt.IsZero() && now.Sub(pending.CreatedAt) > maxPendingAge) {
			delete(c.sessions, id)
			logger.Warn("Evicted stale pending session",
				zap.String("session_id", id))
		}
	}

	c.sessions[sessionID] = session
	logger.Info("Pending session stored",
		zap.String("session_id", sessionID))
}
// Get looks up the pending session for sessionID and removes it from the cache
// in the same critical section, so each entry is handed out at most once. The
// boolean reports whether an entry was present.
func (c *PendingSessionCache) Get(sessionID string) (*PendingSession, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	found, ok := c.sessions[sessionID]
	if !ok {
		return found, false
	}

	delete(c.sessions, sessionID)
	logger.Info("Pending session retrieved and deleted",
		zap.String("session_id", sessionID))
	return found, true
}
// Delete discards the pending session for sessionID without returning it.
// Deleting an ID that is not present is a no-op.
func (c *PendingSessionCache) Delete(sessionID string) {
	c.mu.Lock()
	delete(c.sessions, sessionID)
	c.mu.Unlock()
}
// main wires up the co-managed server party. It performs these steps in order:
//
//   - load configuration and initialize logging;
//   - connect to PostgreSQL;
//   - build the crypto service;
//   - connect to and register with the Message Router under the
//     co_managed_persistent role;
//   - start heartbeats and subscribe to session events;
//   - serve a health endpoint over HTTP.
//
// It then blocks until SIGINT/SIGTERM or an HTTP server error.
func main() {
	// Parse flags
	configPath := flag.String("config", "", "Path to config file")
	flag.Parse()

	// Load configuration
	cfg, err := config.Load(*configPath)
	if err != nil {
		fmt.Printf("Failed to load config: %v\n", err)
		os.Exit(1)
	}

	// Initialize logger
	if err := logger.Init(&logger.Config{
		Level:    cfg.Logger.Level,
		Encoding: cfg.Logger.Encoding,
	}); err != nil {
		fmt.Printf("Failed to initialize logger: %v\n", err)
		os.Exit(1)
	}
	defer logger.Sync()

	logger.Info("Starting Server Party Co-Managed Service",
		zap.String("environment", cfg.Server.Environment),
		zap.Int("http_port", cfg.Server.HTTPPort))

	// Initialize database connection
	db, err := initDatabase(cfg.Database)
	if err != nil {
		logger.Fatal("Failed to connect to database", zap.Error(err))
	}
	defer db.Close()

	// Initialize crypto service with master key from environment.
	masterKeyHex := os.Getenv("MPC_CRYPTO_MASTER_KEY")
	if masterKeyHex == "" {
		// The fallback key below is published in source code, so anything
		// encrypted with it is effectively unprotected. Refuse it in
		// production and make its use loud everywhere else.
		if cfg.Server.Environment == "production" {
			logger.Fatal("MPC_CRYPTO_MASTER_KEY must be set in production",
				zap.String("environment", cfg.Server.Environment))
		}
		logger.Warn("MPC_CRYPTO_MASTER_KEY not set, using insecure built-in development key",
			zap.String("environment", cfg.Server.Environment))
		masterKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" // 64 hex chars = 32 bytes
	}
	masterKey, err := hex.DecodeString(masterKeyHex)
	if err != nil {
		logger.Fatal("Invalid master key format", zap.Error(err))
	}
	cryptoService, err := crypto.NewCryptoService(masterKey)
	if err != nil {
		logger.Fatal("Failed to create crypto service", zap.Error(err))
	}

	// Get Message Router address from environment
	routerAddr := os.Getenv("MESSAGE_ROUTER_ADDR")
	if routerAddr == "" {
		routerAddr = "localhost:9092"
	}

	// Initialize Message Router client
	messageRouter, err := grpcclient.NewMessageRouterClient(routerAddr)
	if err != nil {
		logger.Fatal("Failed to connect to message router", zap.Error(err))
	}
	defer messageRouter.Close()

	// Initialize repositories
	keyShareRepo := postgres.NewKeySharePostgresRepo(db)

	// Initialize use cases. The router client is passed twice because it
	// fills two dependency slots of the use case.
	participateKeygenUC := use_cases.NewParticipateKeygenUseCase(
		keyShareRepo,
		messageRouter,
		messageRouter,
		cryptoService,
	)

	// Create shutdown context
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Get party ID from environment, falling back to the hostname and then to
	// a random suffix.
	partyID := os.Getenv("PARTY_ID")
	if partyID == "" {
		partyID, _ = os.Hostname() // error is covered by the empty-string check below
		if partyID == "" {
			partyID = "co-managed-party-" + uuid.New().String()[:8]
		}
	}

	// Party role is co_managed_persistent - different from normal persistent
	// This ensures co_managed_keygen sessions only select these parties
	partyRole := "co_managed_persistent"

	// Register this party with Message Router
	logger.Info("Registering co-managed party with Message Router",
		zap.String("party_id", partyID),
		zap.String("role", partyRole))
	if err := messageRouter.RegisterPartyWithNotification(ctx, partyID, partyRole, "1.0.0", nil); err != nil {
		logger.Fatal("Failed to register party", zap.Error(err))
	}

	// Start heartbeat
	const heartbeatInterval = 30 * time.Second
	heartbeatCancel := messageRouter.StartHeartbeat(ctx, partyID, heartbeatInterval, func(pendingCount int32) {
		if pendingCount > 0 {
			logger.Info("Pending messages detected via heartbeat",
				zap.String("party_id", partyID),
				zap.Int32("pending_count", pendingCount))
		}
	})
	defer heartbeatCancel()
	logger.Info("Heartbeat started", zap.String("party_id", partyID), zap.Duration("interval", heartbeatInterval))

	// Subscribe to session events with two-phase handling for co_managed_keygen
	logger.Info("Subscribing to session events (co_managed_keygen only)", zap.String("party_id", partyID))
	eventHandler := createCoManagedSessionEventHandler(
		ctx,
		partyID,
		messageRouter,
		participateKeygenUC,
	)
	if err := messageRouter.SubscribeSessionEvents(ctx, partyID, eventHandler); err != nil {
		logger.Fatal("Failed to subscribe to session events", zap.Error(err))
	}

	logger.Info("Co-managed party initialized successfully",
		zap.String("party_id", partyID),
		zap.String("role", partyRole))

	// Start HTTP server
	errChan := make(chan error, 1)
	go func() {
		if err := startHTTPServer(cfg); err != nil {
			errChan <- fmt.Errorf("HTTP server error: %w", err)
		}
	}()

	// Wait for shutdown signal
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	select {
	case sig := <-sigChan:
		logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
	case err := <-errChan:
		logger.Error("Server error", zap.Error(err))
	}

	// Graceful shutdown: cancel ctx (stopping heartbeats and in-flight keygen
	// goroutines derived from it), then give them a fixed grace period.
	logger.Info("Shutting down...")
	cancel()
	time.Sleep(5 * time.Second)
	logger.Info("Shutdown complete")
}
// initDatabase opens a PostgreSQL pool from cfg and verifies it with a Ping
// and a "SELECT 1" query.
//
// A failed open, ping, or query is retried up to maxRetries times. Backoff
// between attempts grows linearly (retryDelay * attempt). No backoff is taken
// after the final attempt, so the error surfaces immediately. Each probe is
// bounded by probeTimeout, so an unreachable host cannot stall a single attempt
// forever.
//
// On success it returns the configured pool. Otherwise it returns an error
// wrapping the last failure.
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
	const (
		maxRetries   = 10
		retryDelay   = 2 * time.Second
		probeTimeout = 5 * time.Second
	)

	// backoff waits before the next attempt; after the last attempt there is
	// nothing left to wait for.
	backoff := func(attempt int) {
		if attempt < maxRetries {
			time.Sleep(retryDelay * time.Duration(attempt))
		}
	}

	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		var db *sql.DB
		db, err = sql.Open("postgres", cfg.DSN())
		if err != nil {
			logger.Warn("Failed to open database connection, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			backoff(attempt)
			continue
		}

		db.SetMaxOpenConns(cfg.MaxOpenConns)
		db.SetMaxIdleConns(cfg.MaxIdleConns)
		db.SetConnMaxLifetime(cfg.ConnMaxLife)

		pingCtx, pingCancel := context.WithTimeout(context.Background(), probeTimeout)
		err = db.PingContext(pingCtx)
		pingCancel()
		if err != nil {
			logger.Warn("Failed to ping database, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			db.Close()
			backoff(attempt)
			continue
		}

		// A successful ping does not guarantee queries work (e.g. the server
		// is still starting up), so run a trivial query as well.
		queryCtx, queryCancel := context.WithTimeout(context.Background(), probeTimeout)
		var result int
		err = db.QueryRowContext(queryCtx, "SELECT 1").Scan(&result)
		queryCancel()
		if err != nil {
			logger.Warn("Database ping succeeded but query failed, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			db.Close()
			backoff(attempt)
			continue
		}

		logger.Info("Connected to PostgreSQL and verified connectivity",
			zap.Int("attempt", attempt))
		return db, nil
	}

	return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
}
// startHTTPServer serves the /health endpoint on cfg.Server.HTTPPort and
// blocks until the listener stops. Gin runs in release mode when the
// environment is "production".
//
// An explicit http.Server is used instead of gin's Engine.Run because Run
// delegates to http.ListenAndServe, which sets no timeouts. That leaves the
// port open to clients that hold connections by trickling headers or bodies.
// ListenAndServe always returns a non-nil error.
func startHTTPServer(cfg *config.Config) error {
	if cfg.Server.Environment == "production" {
		gin.SetMode(gin.ReleaseMode)
	}

	r := gin.New()
	r.Use(gin.Recovery())
	r.Use(gin.Logger())

	// Health check
	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"status":  "healthy",
			"service": "server-party-co-managed",
		})
	})

	srv := &http.Server{
		Addr:              fmt.Sprintf(":%d", cfg.Server.HTTPPort),
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
	return srv.ListenAndServe()
}
// createCoManagedSessionEventHandler returns a session event callback that
// handles co_managed_keygen sessions for partyID in two phases:
//
//   - session_created: record the session and call JoinSession immediately
//     (this party's join is required for session_started to be emitted).
//   - session_started: run the TSS keygen protocol in a background goroutine,
//     at the same point in time at which user clients receive all_joined.
//
// The handler logs and ignores the following:
//   - events whose SelectedParties do not include partyID;
//   - events with an unparsable session ID;
//   - session_created events that carry a message hash (sign sessions);
//   - other event types.
//
// Keygen goroutines derive their context from ctx, so cancelling ctx aborts them.
func createCoManagedSessionEventHandler(
	ctx context.Context,
	partyID string,
	messageRouter *grpcclient.MessageRouterClient,
	participateKeygenUC *use_cases.ParticipateKeygenUseCase,
) func(*router.SessionEvent) {
	return func(event *router.SessionEvent) {
		// Check if this party is selected for the session
		isSelected := false
		for _, selectedParty := range event.SelectedParties {
			if selectedParty == partyID {
				isSelected = true
				break
			}
		}
		if !isSelected {
			logger.Debug("Party not selected for this session",
				zap.String("session_id", event.SessionId),
				zap.String("party_id", partyID))
			return
		}

		logger.Info("Received session event",
			zap.String("session_id", event.SessionId),
			zap.String("party_id", partyID),
			zap.String("event_type", event.EventType))

		sessionID, err := uuid.Parse(event.SessionId)
		if err != nil {
			logger.Error("Invalid session ID", zap.Error(err))
			return
		}

		switch event.EventType {
		case "session_created":
			handleCoManagedSessionCreated(ctx, partyID, messageRouter, event, sessionID)
		case "session_started":
			handleCoManagedSessionStarted(ctx, partyID, participateKeygenUC, event)
		default:
			logger.Debug("Ignoring unhandled event type",
				zap.String("session_id", event.SessionId),
				zap.String("event_type", event.EventType))
		}
	}
}

// handleCoManagedSessionCreated is phase 1: it records a keygen session and
// joins it on behalf of partyID.
func handleCoManagedSessionCreated(
	ctx context.Context,
	partyID string,
	messageRouter *grpcclient.MessageRouterClient,
	event *router.SessionEvent,
	sessionID uuid.UUID,
) {
	// Only handle keygen sessions (no message_hash)
	if len(event.MessageHash) > 0 {
		logger.Debug("Ignoring sign session (co-managed only handles keygen)",
			zap.String("session_id", event.SessionId))
		return
	}

	joinToken, exists := event.JoinTokens[partyID]
	if !exists {
		logger.Error("No join token found for party in session_created",
			zap.String("session_id", event.SessionId),
			zap.String("party_id", partyID))
		return
	}

	// Record the session BEFORE joining. This party's join may be the one that
	// completes the session, in which case session_started can be dispatched
	// before JoinSession returns (if events are delivered concurrently).
	// Storing first guarantees the phase-2 handler finds the entry.
	pendingSessionCache.Store(event.SessionId, &PendingSession{
		SessionID:       sessionID,
		JoinToken:       joinToken,
		MessageHash:     event.MessageHash,
		ThresholdN:      int(event.ThresholdN),
		ThresholdT:      int(event.ThresholdT),
		SelectedParties: event.SelectedParties,
		CreatedAt:       time.Now(),
	})

	joinCtx, joinCancel := context.WithTimeout(ctx, 30*time.Second)
	_, err := messageRouter.JoinSession(joinCtx, sessionID, partyID, joinToken)
	joinCancel()
	if err != nil {
		// Forget the session so a later session_started is not acted upon.
		pendingSessionCache.Delete(event.SessionId)
		logger.Error("Failed to join session",
			zap.String("session_id", event.SessionId),
			zap.String("party_id", partyID),
			zap.Error(err))
		return
	}

	logger.Info("Successfully joined session, waiting for session_started",
		zap.String("session_id", event.SessionId),
		zap.String("party_id", partyID))
}

// handleCoManagedSessionStarted is phase 2: once all participants have joined,
// it claims the pending session and launches the keygen protocol.
func handleCoManagedSessionStarted(
	ctx context.Context,
	partyID string,
	participateKeygenUC *use_cases.ParticipateKeygenUseCase,
	event *router.SessionEvent,
) {
	pendingSession, exists := pendingSessionCache.Get(event.SessionId)
	if !exists {
		logger.Warn("No pending session found for session_started event",
			zap.String("session_id", event.SessionId),
			zap.String("party_id", partyID))
		return
	}

	logger.Info("Session started event received, beginning TSS keygen protocol",
		zap.String("session_id", event.SessionId),
		zap.String("party_id", partyID))

	// The keygen timeout starts now (at session_started), not at session_created.
	go runCoManagedKeygen(ctx, partyID, participateKeygenUC, event, pendingSession)
}

// runCoManagedKeygen executes the TSS keygen protocol for one session with a
// 10-minute timeout and logs the outcome.
func runCoManagedKeygen(
	ctx context.Context,
	partyID string,
	participateKeygenUC *use_cases.ParticipateKeygenUseCase,
	event *router.SessionEvent,
	pendingSession *PendingSession,
) {
	participateCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
	defer cancel()

	logger.Info("Auto-participating in co_managed_keygen session",
		zap.String("session_id", event.SessionId),
		zap.String("party_id", partyID))

	// Build SessionInfo from the session_started event rather than the cached
	// session_created data: per the router contract, session_started lists every
	// participant that joined, including external parties that joined after
	// session_created. JoinSession already ran in phase 1, so
	// ExecuteWithSessionInfo is used to skip a duplicate join.
	participants := make([]use_cases.ParticipantInfo, len(event.SelectedParties))
	for i, p := range event.SelectedParties {
		participants[i] = use_cases.ParticipantInfo{
			PartyID:    p,
			PartyIndex: i,
		}
	}

	sessionInfo := &use_cases.SessionInfo{
		SessionID:    pendingSession.SessionID,
		SessionType:  "co_managed_keygen",
		ThresholdN:   int(event.ThresholdN),
		ThresholdT:   int(event.ThresholdT),
		MessageHash:  pendingSession.MessageHash,
		Participants: participants,
	}

	result, err := participateKeygenUC.ExecuteWithSessionInfo(
		participateCtx,
		pendingSession.SessionID,
		partyID,
		sessionInfo,
	)
	if err != nil {
		logger.Error("Co-managed keygen participation failed",
			zap.Error(err),
			zap.String("session_id", event.SessionId))
		return
	}

	logger.Info("Co-managed keygen participation completed",
		zap.String("session_id", event.SessionId),
		zap.String("public_key", hex.EncodeToString(result.PublicKey)))
}