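// Command message-router runs the MPC system's message-routing service: it
// exposes gRPC and HTTP APIs for routing messages between parties, persists
// messages in PostgreSQL, and fans them out through an in-memory broker.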
package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"net"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/gin-gonic/gin"
	_ "github.com/lib/pq"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/reflection"

	pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
	"github.com/rwadurian/mpc-system/pkg/config"
	"github.com/rwadurian/mpc-system/pkg/logger"
	grpcadapter "github.com/rwadurian/mpc-system/services/message-router/adapters/input/grpc"
	"github.com/rwadurian/mpc-system/services/message-router/adapters/output/memory"
	"github.com/rwadurian/mpc-system/services/message-router/adapters/output/postgres"
	"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
	"github.com/rwadurian/mpc-system/services/message-router/domain"
)

func main() {
	// Parse flags
	configPath := flag.String("config", "", "Path to config file")
	flag.Parse()

	// Load configuration
	cfg, err := config.Load(*configPath)
	if err != nil {
		fmt.Printf("Failed to load config: %v\n", err)
		os.Exit(1)
	}

	// Initialize logger
	if err := logger.Init(&logger.Config{
		Level:    cfg.Logger.Level,
		Encoding: cfg.Logger.Encoding,
	}); err != nil {
		fmt.Printf("Failed to initialize logger: %v\n", err)
		os.Exit(1)
	}
	defer logger.Sync()

	logger.Info("Starting Message Router Service",
		zap.String("environment", cfg.Server.Environment),
		zap.Int("grpc_port", cfg.Server.GRPCPort),
		zap.Int("http_port", cfg.Server.HTTPPort))

	// Initialize database connection
	db, err := initDatabase(cfg.Database)
	if err != nil {
		logger.Fatal("Failed to connect to database", zap.Error(err))
	}
	defer db.Close()

	// Initialize repositories and adapters
	messageRepo := postgres.NewMessagePostgresRepo(db)

	// Initialize the in-memory message broker (replaces RabbitMQ)
	messageBroker := memory.NewMessageBrokerAdapter()
	defer messageBroker.Close()
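
	// NOTE: the in-memory broker confines message delivery to this single
	// process. Running more than one router replica would require a shared
	// broker again (e.g. the RabbitMQ this adapter replaced) so that parties
	// connected to different replicas still receive each other's messages.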

	// Initialize party registry and event broadcaster for the party-driven architecture
	partyRegistry := domain.NewPartyRegistry()
	eventBroadcaster := domain.NewSessionEventBroadcaster()

	// Initialize use cases
	routeMessageUC := use_cases.NewRouteMessageUseCase(messageRepo, messageBroker)
	getPendingMessagesUC := use_cases.NewGetPendingMessagesUseCase(messageRepo)

	// Create a client connection to the Session Coordinator so session
	// operations can be proxied through this service. This allows
	// server-parties to connect only to the Message Router.
	coordinatorAddr := os.Getenv("SESSION_COORDINATOR_ADDR")
	if coordinatorAddr == "" {
		coordinatorAddr = "session-coordinator:50051" // Default in docker-compose
	}
	coordinatorConn, err := grpc.NewClient(coordinatorAddr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		// grpc.NewClient only fails on an invalid target or options; the
		// underlying connection is established lazily on the first RPC.
		logger.Warn("Failed to create Session Coordinator client (session operations will be unavailable)",
			zap.String("address", coordinatorAddr),
			zap.Error(err))
	} else {
		defer coordinatorConn.Close()
		logger.Info("Created Session Coordinator client for proxying session operations",
			zap.String("address", coordinatorAddr))
	}

	// Start message cleanup background job
	go runMessageCleanup(messageRepo)

	// Start stale party detection background job
	go runStalePartyDetection(partyRegistry)

	// Start servers; a fatal error from either one aborts the process below
	errChan := make(chan error, 2)

	// Start gRPC server
	go func() {
		if err := startGRPCServer(cfg, routeMessageUC, getPendingMessagesUC, messageBroker, partyRegistry, eventBroadcaster, messageRepo, coordinatorConn); err != nil {
			errChan <- fmt.Errorf("gRPC server error: %w", err)
		}
	}()

	// Start HTTP server
	go func() {
		if err := startHTTPServer(cfg, routeMessageUC, getPendingMessagesUC); err != nil {
			errChan <- fmt.Errorf("HTTP server error: %w", err)
		}
	}()

	// Wait for a shutdown signal or a fatal server error
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	select {
	case sig := <-sigChan:
		logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
	case err := <-errChan:
		logger.Error("Server error", zap.Error(err))
	}

	// Best-effort shutdown: give in-flight requests a short grace window
	// before the process exits.
	logger.Info("Shutting down...")
	time.Sleep(5 * time.Second)
	logger.Info("Shutdown complete")
}
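
// initDatabase opens a PostgreSQL connection pool and verifies it is usable,
// retrying with linearly increasing delays. It pings the database and runs a
// trivial query before returning, since sql.Open alone does not establish a
// connection.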
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
	const maxRetries = 10
	const retryDelay = 2 * time.Second

	var db *sql.DB
	var err error

	for i := 0; i < maxRetries; i++ {
		db, err = sql.Open("postgres", cfg.DSN())
		if err != nil {
			logger.Warn("Failed to open database connection, retrying...",
				zap.Int("attempt", i+1),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			time.Sleep(retryDelay * time.Duration(i+1))
			continue
		}

		db.SetMaxOpenConns(cfg.MaxOpenConns)
		db.SetMaxIdleConns(cfg.MaxIdleConns)
		db.SetConnMaxLifetime(cfg.ConnMaxLife)

		// Test connection with Ping
		if err = db.Ping(); err != nil {
			logger.Warn("Failed to ping database, retrying...",
				zap.Int("attempt", i+1),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			db.Close()
			time.Sleep(retryDelay * time.Duration(i+1))
			continue
		}

		// Verify the database is actually usable with a simple query
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		var result int
		err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
		cancel()
		if err != nil {
			logger.Warn("Database ping succeeded but query failed, retrying...",
				zap.Int("attempt", i+1),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			db.Close()
			time.Sleep(retryDelay * time.Duration(i+1))
			continue
		}

		logger.Info("Connected to PostgreSQL and verified connectivity",
			zap.Int("attempt", i+1))
		return db, nil
	}

	return nil, fmt.Errorf("failed to connect to database after %d retries: %w", maxRetries, err)
}
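
// startGRPCServer wires the message-router gRPC handler to its dependencies,
// registers it together with server reflection, and blocks serving on the
// configured port.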
func startGRPCServer(
	cfg *config.Config,
	routeMessageUC *use_cases.RouteMessageUseCase,
	getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
	messageBroker *memory.MessageBrokerAdapter,
	partyRegistry *domain.PartyRegistry,
	eventBroadcaster *domain.SessionEventBroadcaster,
	messageRepo *postgres.MessagePostgresRepo,
	coordinatorConn *grpc.ClientConn,
) error {
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Server.GRPCPort))
	if err != nil {
		return err
	}

	grpcServer := grpc.NewServer()

	// Create and register the message router gRPC handler with the party
	// registry and event broadcaster
	messageRouterServer := grpcadapter.NewMessageRouterServer(
		routeMessageUC,
		getPendingMessagesUC,
		messageBroker,
		partyRegistry,
		eventBroadcaster,
		messageRepo,
	)

	// Set the coordinator connection for proxying session operations.
	// This allows server-parties to connect only to the Message Router.
	if coordinatorConn != nil {
		messageRouterServer.SetCoordinatorConnection(coordinatorConn)
	}

	pb.RegisterMessageRouterServer(grpcServer, messageRouterServer)

	// Enable reflection for debugging
	reflection.Register(grpcServer)

	logger.Info("Starting gRPC server", zap.Int("port", cfg.Server.GRPCPort))
	return grpcServer.Serve(listener)
}
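
// startHTTPServer exposes a health check and a small REST facade over the
// message-routing use cases, then blocks serving on the configured port.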
func startHTTPServer(
	cfg *config.Config,
	routeMessageUC *use_cases.RouteMessageUseCase,
	getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
) error {
	if cfg.Server.Environment == "production" {
		gin.SetMode(gin.ReleaseMode)
	}

	router := gin.New()
	router.Use(gin.Recovery())
	router.Use(gin.Logger())

	// Health check
	router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"status":  "healthy",
			"service": "message-router",
		})
	})

	// API routes
	api := router.Group("/api/v1")
	{
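		// Example request (values are illustrative; "payload" must be
		// base64-encoded, since encoding/json marshals []byte that way):
		//
		//	POST /api/v1/messages/route
		//	{"session_id": "sess-1", "from_party": "party-1",
		//	 "to_parties": ["party-2"], "round_number": 1,
		//	 "message_type": "round1", "payload": "aGVsbG8="}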
		api.POST("/messages/route", func(c *gin.Context) {
			var req struct {
				SessionID   string   `json:"session_id" binding:"required"`
				FromParty   string   `json:"from_party" binding:"required"`
				ToParties   []string `json:"to_parties"`
				RoundNumber int      `json:"round_number"`
				MessageType string   `json:"message_type"`
				Payload     []byte   `json:"payload" binding:"required"`
			}

			if err := c.ShouldBindJSON(&req); err != nil {
				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
				return
			}

			input := use_cases.RouteMessageInput{
				SessionID:   req.SessionID,
				FromParty:   req.FromParty,
				ToParties:   req.ToParties,
				RoundNumber: req.RoundNumber,
				MessageType: req.MessageType,
				Payload:     req.Payload,
			}

			output, err := routeMessageUC.Execute(c.Request.Context(), input)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			c.JSON(http.StatusOK, gin.H{
				"success":    output.Success,
				"message_id": output.MessageID,
			})
		})
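
		// Example: GET /api/v1/messages/pending?session_id=sess-1&party_id=party-2
		// AfterTimestamp is pinned to 0 below, i.e. no lower time bound is applied.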
		api.GET("/messages/pending", func(c *gin.Context) {
			input := use_cases.GetPendingMessagesInput{
				SessionID:      c.Query("session_id"),
				PartyID:        c.Query("party_id"),
				AfterTimestamp: 0,
			}

			messages, err := getPendingMessagesUC.Execute(c.Request.Context(), input)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			c.JSON(http.StatusOK, gin.H{"messages": messages})
		})
	}

	logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
	return router.Run(fmt.Sprintf(":%d", cfg.Server.HTTPPort))
}
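
// runMessageCleanup deletes messages older than 24 hours once per hour to
// keep the message store from growing without bound.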
func runMessageCleanup(messageRepo *postgres.MessagePostgresRepo) {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	for range ticker.C {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)

		// Delete messages older than 24 hours
		cutoff := time.Now().Add(-24 * time.Hour)
		count, err := messageRepo.DeleteOlderThan(ctx, cutoff)
		cancel()

		if err != nil {
			logger.Error("Failed to cleanup old messages", zap.Error(err))
		} else if count > 0 {
			logger.Info("Cleaned up old messages", zap.Int64("count", count))
		}
	}
}

// runStalePartyDetection periodically checks for stale parties and marks them
// as offline. Parties that haven't sent a heartbeat within the timeout are
// considered offline.
func runStalePartyDetection(partyRegistry *domain.PartyRegistry) {
	// Check every 30 seconds for stale parties
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	// Parties are considered stale if no heartbeat arrives for 2 minutes
	staleTimeout := 2 * time.Minute

	logger.Info("Started stale party detection",
		zap.Duration("check_interval", 30*time.Second),
		zap.Duration("stale_timeout", staleTimeout))

	for range ticker.C {
		staleParties := partyRegistry.MarkStalePartiesOffline(staleTimeout)

		for _, party := range staleParties {
			logger.Warn("Party marked as offline (no heartbeat)",
				zap.String("party_id", party.PartyID),
				zap.String("role", party.Role),
				zap.Time("last_seen", party.LastSeen),
				zap.Bool("has_notification", party.IsOfflineMode()))
		}
	}
}