// File: rwadurian/backend/mpc-system/services/server-party/cmd/server/main.go
// (616 lines, 18 KiB, Go)

package main
import (
"context"
"database/sql"
"encoding/hex"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
_ "github.com/lib/pq"
"github.com/rwadurian/mpc-system/pkg/config"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
grpcclient "github.com/rwadurian/mpc-system/services/server-party/adapters/output/grpc"
"github.com/rwadurian/mpc-system/services/server-party/adapters/output/postgres"
"github.com/rwadurian/mpc-system/services/server-party/application/use_cases"
"github.com/rwadurian/mpc-system/services/server-party/infrastructure/cache"
"go.uber.org/zap"
)
// globalShareCache is the package-wide store for shares that this party,
// when acting as a delegate, produces on behalf of a user during keygen.
// main assigns it before the HTTP server starts; the keygen handler writes
// entries into it and the user-share endpoint drains them with GetAndDelete.
var (
	globalShareCache *cache.ShareCache
)
// devMasterKeyHex is a publicly known 32-byte placeholder master key (64 hex
// characters). loadMasterKey falls back to it only outside production; it must
// never be used to protect real key shares.
const devMasterKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

// main wires up the server-party service. It loads configuration and the
// logger, creates the delegate-party share cache, connects to PostgreSQL and
// (over gRPC) to the session coordinator and message router, and registers
// this party with the message router. It then serves the HTTP API until
// SIGINT/SIGTERM arrives or the HTTP server fails.
func main() {
	// Parse flags
	configPath := flag.String("config", "", "Path to config file")
	flag.Parse()

	// Load configuration
	cfg, err := config.Load(*configPath)
	if err != nil {
		fmt.Printf("Failed to load config: %v\n", err)
		os.Exit(1)
	}

	// Initialize logger
	if err := logger.Init(&logger.Config{
		Level:    cfg.Logger.Level,
		Encoding: cfg.Logger.Encoding,
	}); err != nil {
		fmt.Printf("Failed to initialize logger: %v\n", err)
		os.Exit(1)
	}
	defer logger.Sync()

	logger.Info("Starting Server Party Service",
		zap.String("environment", cfg.Server.Environment),
		zap.Int("http_port", cfg.Server.HTTPPort))

	// Delegate parties park user shares here until the user fetches them once.
	const shareCacheTTL = 15 * time.Minute
	globalShareCache = cache.NewShareCache(shareCacheTTL)
	logger.Info("Share cache initialized", zap.Duration("ttl", shareCacheTTL))

	// Initialize database connection (initDatabase retries internally).
	db, err := initDatabase(cfg.Database)
	if err != nil {
		logger.Fatal("Failed to connect to database", zap.Error(err))
	}
	defer db.Close()

	// Initialize crypto service with the master key from the environment.
	cryptoService, err := crypto.NewCryptoService(loadMasterKey(cfg.Server.Environment))
	if err != nil {
		logger.Fatal("Failed to create crypto service", zap.Error(err))
	}

	// Initialize gRPC clients; addresses come from the environment.
	sessionClient, err := grpcclient.NewSessionCoordinatorClient(
		envOrDefault("SESSION_COORDINATOR_ADDR", "localhost:9091"))
	if err != nil {
		logger.Fatal("Failed to connect to session coordinator", zap.Error(err))
	}
	defer sessionClient.Close()

	messageRouter, err := grpcclient.NewMessageRouterClient(
		envOrDefault("MESSAGE_ROUTER_ADDR", "localhost:9092"))
	if err != nil {
		logger.Fatal("Failed to connect to message router", zap.Error(err))
	}
	defer messageRouter.Close()

	// Initialize repositories and use cases backed by the real gRPC clients.
	keyShareRepo := postgres.NewKeySharePostgresRepo(db)
	participateKeygenUC := use_cases.NewParticipateKeygenUseCase(
		keyShareRepo,
		sessionClient,
		messageRouter,
		cryptoService,
	)
	participateSigningUC := use_cases.NewParticipateSigningUseCase(
		keyShareRepo,
		sessionClient,
		messageRouter,
		cryptoService,
	)

	// ctx is cancelled when shutdown begins.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	partyID := resolvePartyID()
	partyRole := envOrDefault("PARTY_ROLE", "persistent")

	// Register this party with Message Router
	logger.Info("Registering party with Message Router",
		zap.String("party_id", partyID),
		zap.String("role", partyRole))
	if err := messageRouter.RegisterParty(ctx, partyID, partyRole, "1.0.0"); err != nil {
		logger.Fatal("Failed to register party", zap.Error(err))
	}

	// Subscribe to session events and handle them automatically.
	// Note: This will work after protobuf regeneration; until then nothing is
	// actually subscribed despite the log line below.
	logger.Info("Subscribing to session events", zap.String("party_id", partyID))
	// TODO: Uncomment after protobuf regeneration
	/*
		eventHandler := createSessionEventHandler(
			ctx,
			partyID,
			participateKeygenUC,
			participateSigningUC,
			sessionClient,
		)
		if err := messageRouter.SubscribeSessionEvents(ctx, partyID, eventHandler); err != nil {
			logger.Fatal("Failed to subscribe to session events", zap.Error(err))
		}
	*/

	logger.Info("Party-driven architecture initialized successfully",
		zap.String("party_id", partyID),
		zap.String("role", partyRole))

	// Start HTTP server; it only returns on failure.
	errChan := make(chan error, 1)
	go func() {
		if err := startHTTPServer(cfg, participateKeygenUC, participateSigningUC, keyShareRepo); err != nil {
			errChan <- fmt.Errorf("HTTP server error: %w", err)
		}
	}()

	// Wait for shutdown signal or server failure.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	select {
	case sig := <-sigChan:
		logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
	case err := <-errChan:
		logger.Error("Server error", zap.Error(err))
	}

	// Graceful shutdown
	logger.Info("Shutting down...")
	cancel()
	// The HTTP server is not stopped explicitly. This pause presumably gives
	// in-flight work a moment before the deferred calls close the gRPC
	// clients and the database.
	time.Sleep(5 * time.Second)
	logger.Info("Shutdown complete")
}

// loadMasterKey returns the crypto master key hex-decoded from the
// MPC_CRYPTO_MASTER_KEY environment variable.
//
// When the variable is unset, it exits via logger.Fatal if environment is
// "production", because silently running production with a public key would
// leave stored shares effectively unprotected. In any other environment it
// logs a warning and uses devMasterKeyHex. A value that is not valid hex is
// always fatal.
func loadMasterKey(environment string) []byte {
	keyHex := os.Getenv("MPC_CRYPTO_MASTER_KEY")
	if keyHex == "" {
		if environment == "production" {
			logger.Fatal("MPC_CRYPTO_MASTER_KEY must be set in production",
				zap.String("environment", environment))
		}
		logger.Warn("MPC_CRYPTO_MASTER_KEY not set; using insecure development master key",
			zap.String("environment", environment))
		keyHex = devMasterKeyHex
	}
	key, err := hex.DecodeString(keyHex)
	if err != nil {
		logger.Fatal("Invalid master key format", zap.Error(err))
	}
	return key
}

// envOrDefault returns the value of the environment variable key, or
// fallback when that variable is unset or empty.
func envOrDefault(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// resolvePartyID picks this party's identifier. It uses PARTY_ID if set,
// otherwise the host name (the pod name in Kubernetes), and otherwise a random
// "server-party-xxxxxxxx" value.
func resolvePartyID() string {
	if id := os.Getenv("PARTY_ID"); id != "" {
		return id
	}
	if host, err := os.Hostname(); err == nil && host != "" {
		return host
	}
	return "server-party-" + uuid.New().String()[:8]
}
// initDatabase opens a PostgreSQL connection pool described by cfg and
// verifies it is usable with a ping followed by "SELECT 1". It makes up to
// maxRetries attempts, waiting a linearly growing delay (2s, 4s, 6s, ...)
// between them. Pool limits from cfg are applied to each handle it opens.
//
// Each probe is bounded by probeTimeout, so an unresponsive server cannot
// stall an attempt indefinitely. No delay follows the final attempt.
//
// It returns the verified *sql.DB. If every attempt fails, it returns an
// error wrapping the last failure.
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
	const (
		maxRetries   = 10
		retryDelay   = 2 * time.Second
		probeTimeout = 5 * time.Second
	)

	var lastErr error
	// fail records a failed attempt and backs off before the next one.
	fail := func(attempt int, msg string, err error) {
		lastErr = err
		logger.Warn(msg,
			zap.Int("attempt", attempt),
			zap.Int("max_retries", maxRetries),
			zap.Error(err))
		if attempt < maxRetries {
			time.Sleep(retryDelay * time.Duration(attempt))
		}
	}

	for attempt := 1; attempt <= maxRetries; attempt++ {
		db, err := sql.Open("postgres", cfg.DSN())
		if err != nil {
			fail(attempt, "Failed to open database connection, retrying...", err)
			continue
		}
		db.SetMaxOpenConns(cfg.MaxOpenConns)
		db.SetMaxIdleConns(cfg.MaxIdleConns)
		db.SetConnMaxLifetime(cfg.ConnMaxLife)

		// sql.Open may not connect at all; Ping forces a real connection.
		ctx, cancel := context.WithTimeout(context.Background(), probeTimeout)
		err = db.PingContext(ctx)
		cancel()
		if err != nil {
			db.Close()
			fail(attempt, "Failed to ping database, retrying...", err)
			continue
		}

		// Verify the database actually answers queries, not just pings.
		var one int
		ctx, cancel = context.WithTimeout(context.Background(), probeTimeout)
		err = db.QueryRowContext(ctx, "SELECT 1").Scan(&one)
		cancel()
		if err != nil {
			db.Close()
			fail(attempt, "Database ping succeeded but query failed, retrying...", err)
			continue
		}

		logger.Info("Connected to PostgreSQL and verified connectivity",
			zap.Int("attempt", attempt))
		return db, nil
	}
	return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, lastErr)
}
// participationTimeout bounds each background keygen/signing run started by
// the HTTP API. The runs use a detached context because the request context
// is cancelled as soon as the handler returns its 202 response.
const participationTimeout = 10 * time.Minute

// startHTTPServer registers the server-party HTTP API on a gin router and
// serves it on cfg.Server.HTTPPort. It blocks for the lifetime of the
// listener and returns the error that stopped it. Gin runs in release mode
// when cfg.Server.Environment is "production".
//
// Routes:
//
//	GET  /health
//	POST /api/v1/keygen/participate
//	POST /api/v1/sign/participate
//	GET  /api/v1/shares/:party_id
//	GET  /api/v1/sessions/:session_id/user-share
//
// NOTE(review): no authentication middleware is installed on any route.
func startHTTPServer(
	cfg *config.Config,
	participateKeygenUC *use_cases.ParticipateKeygenUseCase,
	participateSigningUC *use_cases.ParticipateSigningUseCase,
	keyShareRepo *postgres.KeySharePostgresRepo,
) error {
	if cfg.Server.Environment == "production" {
		gin.SetMode(gin.ReleaseMode)
	}

	router := gin.New()
	router.Use(gin.Recovery())
	router.Use(gin.Logger())

	router.GET("/health", handleHealth)

	api := router.Group("/api/v1")
	api.POST("/keygen/participate", keygenParticipateHandler(participateKeygenUC))
	api.POST("/sign/participate", signParticipateHandler(participateSigningUC))
	api.GET("/shares/:party_id", listSharesHandler(keyShareRepo))
	api.GET("/sessions/:session_id/user-share", handleUserShare)

	// An explicit http.Server (instead of router.Run) lets us bound how long a
	// client may take to send request headers, so slow clients cannot hoard
	// connections.
	srv := &http.Server{
		Addr:              fmt.Sprintf(":%d", cfg.Server.HTTPPort),
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second,
	}

	logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
	return srv.ListenAndServe()
}

// handleHealth answers liveness probes.
func handleHealth(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"status":  "healthy",
		"service": "server-party",
	})
}

// keygenParticipateHandler returns the POST /api/v1/keygen/participate handler.
// It validates session_id (must be a UUID), party_id and join_token, replies
// 202 Accepted immediately, and runs uc.Execute in a background goroutine
// bounded by participationTimeout. Failures are only logged. If the run yields
// a non-empty ShareForUser (delegate parties), that share and the public key
// are stored in globalShareCache for one-time retrieval by the user.
func keygenParticipateHandler(uc *use_cases.ParticipateKeygenUseCase) gin.HandlerFunc {
	return func(c *gin.Context) {
		var req struct {
			SessionID string `json:"session_id" binding:"required"`
			PartyID   string `json:"party_id" binding:"required"`
			JoinToken string `json:"join_token" binding:"required"`
		}
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}

		sessionID, err := uuid.Parse(req.SessionID)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
			return
		}

		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), participationTimeout)
			defer cancel()

			output, err := uc.Execute(ctx, use_cases.ParticipateKeygenInput{
				SessionID: sessionID,
				PartyID:   req.PartyID,
				JoinToken: req.JoinToken,
			})
			if err != nil {
				logger.Error("Keygen participation failed",
					zap.String("session_id", req.SessionID),
					zap.String("party_id", req.PartyID),
					zap.Error(err))
				return
			}

			logger.Info("Keygen participation completed",
				zap.String("session_id", req.SessionID),
				zap.String("party_id", req.PartyID),
				zap.Bool("success", output.Success))

			// Delegate parties hand a share back to the user; park it in the cache.
			if len(output.ShareForUser) > 0 {
				globalShareCache.Store(
					sessionID,
					req.PartyID,
					output.ShareForUser,
					output.PublicKey,
				)
				logger.Info("Share stored in cache for user retrieval (delegate party)",
					zap.String("session_id", req.SessionID),
					zap.String("party_id", req.PartyID),
					zap.Int("share_size", len(output.ShareForUser)))
			}
		}()

		c.JSON(http.StatusAccepted, gin.H{
			"message":    "keygen participation initiated",
			"session_id": req.SessionID,
			"party_id":   req.PartyID,
		})
	}
}

// signParticipateHandler returns the POST /api/v1/sign/participate handler.
// It validates session_id (must be a UUID), party_id and join_token. The
// optional message_hash is hex-decoded. The handler replies 202 Accepted
// immediately and runs uc.Execute in a background goroutine bounded by
// participationTimeout; the outcome is only logged.
func signParticipateHandler(uc *use_cases.ParticipateSigningUseCase) gin.HandlerFunc {
	return func(c *gin.Context) {
		var req struct {
			SessionID   string `json:"session_id" binding:"required"`
			PartyID     string `json:"party_id" binding:"required"`
			JoinToken   string `json:"join_token" binding:"required"`
			MessageHash string `json:"message_hash"`
		}
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}

		sessionID, err := uuid.Parse(req.SessionID)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid session_id format"})
			return
		}

		// Parse message hash if provided; an absent hash is passed as nil.
		var messageHash []byte
		if req.MessageHash != "" {
			messageHash, err = hex.DecodeString(req.MessageHash)
			if err != nil {
				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid message_hash format (expected hex)"})
				return
			}
		}

		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), participationTimeout)
			defer cancel()

			output, err := uc.Execute(ctx, use_cases.ParticipateSigningInput{
				SessionID:   sessionID,
				PartyID:     req.PartyID,
				JoinToken:   req.JoinToken,
				MessageHash: messageHash,
			})
			if err != nil {
				logger.Error("Signing participation failed",
					zap.String("session_id", req.SessionID),
					zap.String("party_id", req.PartyID),
					zap.Error(err))
				return
			}

			logger.Info("Signing participation completed",
				zap.String("session_id", req.SessionID),
				zap.String("party_id", req.PartyID),
				zap.Bool("success", output.Success),
				zap.Int("signature_len", len(output.Signature)))
		}()

		c.JSON(http.StatusAccepted, gin.H{
			"message":    "signing participation initiated",
			"session_id": req.SessionID,
			"party_id":   req.PartyID,
		})
	}
}

// listSharesHandler returns the GET /api/v1/shares/:party_id handler. It
// lists metadata for the party's key shares but not the encrypted share
// data. The repository lookup is bounded to 30 seconds.
func listSharesHandler(repo *postgres.KeySharePostgresRepo) gin.HandlerFunc {
	return func(c *gin.Context) {
		partyID := c.Param("party_id")

		ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
		defer cancel()

		shares, err := repo.ListByParty(ctx, partyID)
		if err != nil {
			// The client only gets a generic message; keep the cause in the logs.
			logger.Error("Failed to list key shares",
				zap.String("party_id", partyID),
				zap.Error(err))
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch shares"})
			return
		}

		// Return share metadata (not the actual encrypted data)
		shareInfos := make([]gin.H, len(shares))
		for i, share := range shares {
			shareInfos[i] = gin.H{
				"id":          share.ID.String(),
				"party_id":    share.PartyID,
				"party_index": share.PartyIndex,
				"public_key":  hex.EncodeToString(share.PublicKey),
				"created_at":  share.CreatedAt,
				"last_used":   share.LastUsedAt,
			}
		}

		c.JSON(http.StatusOK, gin.H{
			"party_id": partyID,
			"count":    len(shares),
			"shares":   shareInfos,
		})
	}
}

// handleUserShare serves GET /api/v1/sessions/:session_id/user-share. It
// responds 403 unless the PARTY_ROLE environment variable is "delegate". The
// session's share is removed from globalShareCache as it is read, so a repeat
// request for the same session gets 404.
func handleUserShare(c *gin.Context) {
	sessionIDStr := c.Param("session_id")

	partyRole := os.Getenv("PARTY_ROLE")
	if partyRole != "delegate" {
		c.JSON(http.StatusForbidden, gin.H{
			"error": "This endpoint is only available for delegate parties",
			"role":  partyRole,
		})
		return
	}

	sessionID, err := uuid.Parse(sessionIDStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": "invalid session_id format",
		})
		return
	}

	// Retrieve and delete share from cache (one-time retrieval)
	entry, exists := globalShareCache.GetAndDelete(sessionID)
	if !exists {
		c.JSON(http.StatusNotFound, gin.H{
			"error": "Share not found or already retrieved",
			"note":  "Shares can only be retrieved once and expire after 15 minutes",
		})
		return
	}

	logger.Info("User share retrieved successfully",
		zap.String("session_id", sessionIDStr),
		zap.String("party_id", entry.PartyID))

	c.JSON(http.StatusOK, gin.H{
		"session_id": sessionIDStr,
		"party_id":   entry.PartyID,
		"share":      hex.EncodeToString(entry.Share),
		"public_key": hex.EncodeToString(entry.PublicKey),
		"note":       "This share has been deleted from memory and cannot be retrieved again",
	})
}
// createSessionEventHandler builds the callback meant for
// messageRouter.SubscribeSessionEvents in the party-driven architecture, where
// a party automatically joins the sessions it has been selected for.
//
// For now the function itself is live but its result is a placeholder: the
// returned closure ignores its argument and does nothing, because the
// router.SessionEvent protobuf type it needs has not been generated yet.
// To enable it:
//  1. regenerate the protobufs;
//  2. add the import router "github.com/rwadurian/mpc-system/api/grpc/router/v1";
//  3. replace the closure body with the reference implementation below;
//  4. uncomment the subscription in main.
//
// Until then the ctx, partyID, use-case and sessionClient parameters are unused.
func createSessionEventHandler(
	ctx context.Context,
	partyID string,
	participateKeygenUC *use_cases.ParticipateKeygenUseCase,
	participateSigningUC *use_cases.ParticipateSigningUseCase,
	sessionClient *grpcclient.SessionCoordinatorClient,
) func(event interface{}) {
	handler := func(_ interface{}) {
		// Reference implementation, to be enabled after protobuf regeneration.
		//
		// NOTE(review): inside the goroutine, ctx := context.Background()
		// shadows the ctx parameter, so cancelling main's context would not
		// stop auto-participation. Consider deriving from the outer ctx.
		// NOTE(review): result.PublicKeyHex / result.SignatureHex are assumed
		// here, while the HTTP handlers read output.PublicKey / output.Signature.
		// Confirm these fields exist.
		// NOTE(review): unlike the HTTP keygen handler, this path does not
		// store output.ShareForUser in globalShareCache for delegate parties.
		/*
			event, ok := eventInterface.(*router.SessionEvent)
			if !ok {
				logger.Error("Invalid event type")
				return
			}
			// Check if this party is selected for the session
			isSelected := false
			for _, selectedParty := range event.SelectedParties {
				if selectedParty == partyID {
					isSelected = true
					break
				}
			}
			if !isSelected {
				logger.Debug("Party not selected for this session",
					zap.String("session_id", event.SessionId),
					zap.String("party_id", partyID))
				return
			}
			// Get join token for this party
			joinToken, exists := event.JoinTokens[partyID]
			if !exists {
				logger.Error("No join token found for party",
					zap.String("session_id", event.SessionId),
					zap.String("party_id", partyID))
				return
			}
			logger.Info("Party selected for session, auto-participating",
				zap.String("session_id", event.SessionId),
				zap.String("party_id", partyID),
				zap.String("event_type", event.EventType))
			// Parse session ID
			sessionID, err := uuid.Parse(event.SessionId)
			if err != nil {
				logger.Error("Invalid session ID", zap.Error(err))
				return
			}
			// Automatically participate based on session type
			go func() {
				ctx := context.Background()
				// Determine session type from event
				if event.EventType == "session_created" {
					// Check if it's keygen or sign based on message_hash
					if len(event.MessageHash) == 0 {
						// Keygen session
						logger.Info("Auto-participating in keygen session",
							zap.String("session_id", event.SessionId),
							zap.String("party_id", partyID))
						input := use_cases.ParticipateKeygenInput{
							SessionID: sessionID,
							PartyID:   partyID,
							JoinToken: joinToken,
						}
						result, err := participateKeygenUC.Execute(ctx, input)
						if err != nil {
							logger.Error("Keygen participation failed",
								zap.Error(err),
								zap.String("session_id", event.SessionId))
						} else {
							logger.Info("Keygen participation completed",
								zap.String("session_id", event.SessionId),
								zap.String("public_key", result.PublicKeyHex))
						}
					} else {
						// Sign session
						logger.Info("Auto-participating in sign session",
							zap.String("session_id", event.SessionId),
							zap.String("party_id", partyID))
						input := use_cases.ParticipateSigningInput{
							SessionID:   sessionID,
							PartyID:     partyID,
							JoinToken:   joinToken,
							MessageHash: event.MessageHash,
						}
						result, err := participateSigningUC.Execute(ctx, input)
						if err != nil {
							logger.Error("Signing participation failed",
								zap.Error(err),
								zap.String("session_id", event.SessionId))
						} else {
							logger.Info("Signing participation completed",
								zap.String("session_id", event.SessionId),
								zap.String("signature", result.SignatureHex))
						}
					}
				}
			}()
		*/
	}
	return handler
}