rwadurian/backend/mpc-system/services/message-router/cmd/server/main.go

425 lines
12 KiB
Go

package main
import (
"context"
"database/sql"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
_ "github.com/lib/pq"
amqp "github.com/rabbitmq/amqp091-go"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
pb "github.com/rwadurian/mpc-system/api/grpc/router/v1"
"github.com/rwadurian/mpc-system/pkg/config"
"github.com/rwadurian/mpc-system/pkg/logger"
grpcadapter "github.com/rwadurian/mpc-system/services/message-router/adapters/input/grpc"
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/postgres"
"github.com/rwadurian/mpc-system/services/message-router/adapters/output/rabbitmq"
"github.com/rwadurian/mpc-system/services/message-router/application/use_cases"
"github.com/rwadurian/mpc-system/services/message-router/domain"
"go.uber.org/zap"
)
// main is the entry point of the Message Router service. It loads
// configuration (file path taken from the -config flag), initializes the
// logger, connects to PostgreSQL and RabbitMQ, wires the repositories and use
// cases, starts the background message-cleanup goroutine, and runs the gRPC
// and HTTP servers concurrently. It returns after SIGINT/SIGTERM is received
// or after the first error reported by either server.
func main() {
	// Parse flags; an empty -config value is passed to config.Load unchanged.
	configPath := flag.String("config", "", "Path to config file")
	flag.Parse()
	// Load configuration. The logger is not initialized yet, so failures are
	// reported on stdout via fmt.
	cfg, err := config.Load(*configPath)
	if err != nil {
		fmt.Printf("Failed to load config: %v\n", err)
		os.Exit(1)
	}
	// Initialize logger from the loaded configuration.
	if err := logger.Init(&logger.Config{
		Level:    cfg.Logger.Level,
		Encoding: cfg.Logger.Encoding,
	}); err != nil {
		fmt.Printf("Failed to initialize logger: %v\n", err)
		os.Exit(1)
	}
	// Presumably flushes buffered log output on return — confirm in pkg/logger.
	defer logger.Sync()
	logger.Info("Starting Message Router Service",
		zap.String("environment", cfg.Server.Environment),
		zap.Int("grpc_port", cfg.Server.GRPCPort),
		zap.Int("http_port", cfg.Server.HTTPPort))
	// Initialize database connection (initDatabase retries with backoff).
	// NOTE(review): if logger.Fatal terminates the process (as zap's Fatal
	// does via os.Exit), the deferred cleanups registered above it in main do
	// not run. Confirm this is acceptable.
	db, err := initDatabase(cfg.Database)
	if err != nil {
		logger.Fatal("Failed to connect to database", zap.Error(err))
	}
	defer db.Close()
	// Initialize RabbitMQ connection (initRabbitMQ retries with backoff).
	rabbitConn, err := initRabbitMQ(cfg.RabbitMQ)
	if err != nil {
		logger.Fatal("Failed to connect to RabbitMQ", zap.Error(err))
	}
	defer rabbitConn.Close()
	// Initialize repositories and adapters.
	messageRepo := postgres.NewMessagePostgresRepo(db)
	messageBroker, err := rabbitmq.NewMessageBrokerAdapter(rabbitConn)
	if err != nil {
		logger.Fatal("Failed to create message broker", zap.Error(err))
	}
	defer messageBroker.Close()
	// Initialize party registry and event broadcaster for the party-driven
	// architecture. Only the gRPC server receives them; the HTTP API does not.
	partyRegistry := domain.NewPartyRegistry()
	eventBroadcaster := domain.NewSessionEventBroadcaster()
	// Initialize use cases shared by the gRPC and HTTP servers.
	routeMessageUC := use_cases.NewRouteMessageUseCase(messageRepo, messageBroker)
	getPendingMessagesUC := use_cases.NewGetPendingMessagesUseCase(messageRepo)
	// Start the message-cleanup background job.
	// NOTE(review): it receives no stop signal and runs until the process exits.
	go runMessageCleanup(messageRepo)
	// Create shutdown context.
	// NOTE(review): ctx is never handed to anything (only `_ = ctx` below), so
	// cancel() has no effect on the servers or the cleanup goroutine.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Start servers. The channel has one slot per server so neither goroutine
	// blocks on send, even after main has stopped receiving.
	errChan := make(chan error, 2)
	// Start gRPC server; startGRPCServer blocks while serving.
	go func() {
		if err := startGRPCServer(cfg, routeMessageUC, getPendingMessagesUC, messageBroker, partyRegistry, eventBroadcaster); err != nil {
			errChan <- fmt.Errorf("gRPC server error: %w", err)
		}
	}()
	// Start HTTP server; startHTTPServer blocks while serving.
	go func() {
		if err := startHTTPServer(cfg, routeMessageUC, getPendingMessagesUC); err != nil {
			errChan <- fmt.Errorf("HTTP server error: %w", err)
		}
	}()
	// Wait for a shutdown signal or the first server failure, whichever comes first.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	select {
	case sig := <-sigChan:
		logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
	case err := <-errChan:
		logger.Error("Server error", zap.Error(err))
	}
	// Shutdown.
	// NOTE(review): this is not a graceful drain. Neither server is stopped
	// here, and both keep accepting requests during the 5s sleep. They are
	// torn down only when main returns: the deferred Close calls run, then the
	// process exits. Real draining would need the servers exposed so that
	// grpc.Server.GracefulStop / http.Server.Shutdown can be called.
	logger.Info("Shutting down...")
	cancel()
	time.Sleep(5 * time.Second)
	logger.Info("Shutdown complete")
	// Only use of ctx; keeps the otherwise-unused variable compiling.
	_ = ctx
}
// initDatabase opens a PostgreSQL connection pool for cfg, applies the pool
// limits from cfg (MaxOpenConns, MaxIdleConns, ConnMaxLife), and verifies the
// pool is usable with a ping followed by "SELECT 1".
//
// It makes up to maxRetries attempts with linear backoff between them
// (2s, 4s, 6s, ...). Each probe is bounded by probeTimeout, so an unreachable
// host cannot stall a single attempt indefinitely. No sleep follows the final
// failed attempt.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure the error from the last attempt is wrapped.
func initDatabase(cfg config.DatabaseConfig) (*sql.DB, error) {
	const (
		maxRetries   = 10
		retryDelay   = 2 * time.Second
		probeTimeout = 5 * time.Second
	)

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		// Back off only between attempts: nothing before the first attempt
		// and nothing after the last one.
		if attempt > 1 {
			time.Sleep(retryDelay * time.Duration(attempt-1))
		}

		db, err := sql.Open("postgres", cfg.DSN())
		if err != nil {
			lastErr = err
			logger.Warn("Failed to open database connection, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			continue
		}
		db.SetMaxOpenConns(cfg.MaxOpenConns)
		db.SetMaxIdleConns(cfg.MaxIdleConns)
		db.SetConnMaxLifetime(cfg.ConnMaxLife)

		// sql.Open does not connect; the ping establishes the first
		// connection. Bound it so a black-holed host cannot hang startup.
		pingCtx, cancelPing := context.WithTimeout(context.Background(), probeTimeout)
		err = db.PingContext(pingCtx)
		cancelPing()
		if err != nil {
			lastErr = err
			logger.Warn("Failed to ping database, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			_ = db.Close() // discarding a pool that never became usable
			continue
		}

		// Verify the database actually answers queries, not just connections.
		queryCtx, cancelQuery := context.WithTimeout(context.Background(), probeTimeout)
		var one int
		err = db.QueryRowContext(queryCtx, "SELECT 1").Scan(&one)
		cancelQuery()
		if err != nil {
			lastErr = err
			logger.Warn("Database ping succeeded but query failed, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			_ = db.Close() // discarding a pool that never became usable
			continue
		}

		logger.Info("Connected to PostgreSQL and verified connectivity",
			zap.Int("attempt", attempt))
		return db, nil
	}
	return nil, fmt.Errorf("failed to connect to database after %d attempts: %w", maxRetries, lastErr)
}
// initRabbitMQ dials RabbitMQ at cfg.URL() and verifies the connection is
// usable. To do so it opens a channel and declares, then deletes, a
// non-durable fanout exchange named "mpc.health.check".
//
// It makes up to maxRetries attempts with linear backoff between them
// (2s, 4s, 6s, ...). No sleep follows the final failed attempt.
//
// On success it registers a NotifyClose listener that logs unexpected
// connection loss, and returns the open connection; the caller owns it and
// must Close it. On failure it returns the error from the last attempt.
// Earlier versions lost that error when the channel or exchange probe failed,
// because of a shadowed err variable.
func initRabbitMQ(cfg config.RabbitMQConfig) (*amqp.Connection, error) {
	const (
		maxRetries = 10
		retryDelay = 2 * time.Second
	)

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		// Back off only between attempts: nothing before the first attempt
		// and nothing after the last one.
		if attempt > 1 {
			time.Sleep(retryDelay * time.Duration(attempt-1))
		}

		conn, err := amqp.Dial(cfg.URL())
		if err != nil {
			lastErr = err
			logger.Warn("Failed to dial RabbitMQ, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.String("url", maskPassword(cfg.URL())),
				zap.Error(err))
			continue
		}

		// A successful dial does not prove the broker accepts work; opening a
		// channel does.
		ch, err := conn.Channel()
		if err != nil {
			lastErr = err
			logger.Warn("RabbitMQ connection established but channel creation failed, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			_ = conn.Close() // discarding a connection that failed the probe
			continue
		}

		// Exercise the channel with a real operation.
		err = ch.ExchangeDeclare(
			"mpc.health.check", // name
			"fanout",           // type
			false,              // durable
			true,               // auto-deleted
			false,              // internal
			false,              // no-wait
			nil,                // arguments
		)
		if err != nil {
			lastErr = err
			logger.Warn("RabbitMQ channel created but exchange declaration failed, retrying...",
				zap.Int("attempt", attempt),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			_ = ch.Close()
			_ = conn.Close()
			continue
		}

		// Best-effort removal of the probe exchange. A failure here (e.g. a
		// concurrent instance already deleted it) does not mean the
		// connection is unusable.
		_ = ch.ExchangeDelete("mpc.health.check", false, false)
		_ = ch.Close()

		// Log unexpected connection loss. A nil value means the notification
		// channel was closed without an error, which is not reported.
		closeChan := make(chan *amqp.Error, 1)
		conn.NotifyClose(closeChan)
		go func() {
			if closeErr := <-closeChan; closeErr != nil {
				logger.Error("RabbitMQ connection closed unexpectedly", zap.Error(closeErr))
			}
		}()

		logger.Info("Connected to RabbitMQ and verified connectivity",
			zap.Int("attempt", attempt))
		return conn, nil
	}
	return nil, fmt.Errorf("failed to connect to RabbitMQ after %d attempts: %w", maxRetries, lastErr)
}
// maskPassword returns url with the password in its userinfo replaced by
// "****", so the URL can be logged safely. For example,
// "amqp://user:secret@host:5672/vhost" becomes
// "amqp://user:****@host:5672/vhost".
//
// Behavior:
//   - The scheme separator ("://") is skipped, so its colon is never mistaken
//     for the user/password separator.
//   - Userinfo ends at the LAST '@', so an unescaped '@' inside the password
//     is still fully masked rather than partially leaked.
//   - URLs with no '@', or with a user but no password, are returned unchanged.
//   - A URL with no userinfo but an '@' later in the path may be over-masked.
//     That is harmless for logging.
func maskPassword(url string) string {
	// Locate the start of the authority, just past "://" if present.
	authStart := 0
	for i := 0; i+2 < len(url); i++ {
		if url[i] == ':' && url[i+1] == '/' && url[i+2] == '/' {
			authStart = i + 3
			break
		}
	}

	// The userinfo, if any, ends at the last '@'.
	at := -1
	for i := len(url) - 1; i >= authStart; i-- {
		if url[i] == '@' {
			at = i
			break
		}
	}
	if at < 0 {
		return url
	}

	// The password starts right after the first ':' in the userinfo.
	for i := authStart; i < at; i++ {
		if url[i] == ':' {
			return url[:i+1] + "****" + url[at:]
		}
	}
	return url
}
// startGRPCServer listens on TCP port cfg.Server.GRPCPort and serves the
// MessageRouter gRPC API on it. It blocks until serving stops.
//
// The handler is built from the supplied use cases, message broker, party
// registry and session event broadcaster. Server reflection is also
// registered.
//
// It returns the listen error, or whatever grpc.Server.Serve returns.
//
// NOTE(review): reflection is registered regardless of
// cfg.Server.Environment. Confirm that exposing the schema in production is
// intended.
func startGRPCServer(
	cfg *config.Config,
	routeMessageUC *use_cases.RouteMessageUseCase,
	getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
	messageBroker *rabbitmq.MessageBrokerAdapter,
	partyRegistry *domain.PartyRegistry,
	eventBroadcaster *domain.SessionEventBroadcaster,
) error {
	port := cfg.Server.GRPCPort

	lis, listenErr := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if listenErr != nil {
		return listenErr
	}

	srv := grpc.NewServer()
	pb.RegisterMessageRouterServer(srv, grpcadapter.NewMessageRouterServer(
		routeMessageUC,
		getPendingMessagesUC,
		messageBroker,
		partyRegistry,
		eventBroadcaster,
	))

	// Reflection lets tools such as grpcurl discover the service schema.
	reflection.Register(srv)

	logger.Info("Starting gRPC server", zap.Int("port", port))
	return srv.Serve(lis)
}
// startHTTPServer serves the message router's REST API on
// cfg.Server.HTTPPort and blocks until the server stops. Routes:
//
//	GET  /health                  static liveness response
//	POST /api/v1/messages/route   routes a message via routeMessageUC
//	GET  /api/v1/messages/pending lists pending messages via getPendingMessagesUC
//
// gin runs in release mode when cfg.Server.Environment is "production".
//
// The server is built explicitly instead of using gin's Engine.Run, which
// calls http.ListenAndServe with no timeouts. Without timeouts, a slow or
// stalled client can hold a connection (and its goroutine) open indefinitely.
// It returns whatever http.Server.ListenAndServe returns.
func startHTTPServer(
	cfg *config.Config,
	routeMessageUC *use_cases.RouteMessageUseCase,
	getPendingMessagesUC *use_cases.GetPendingMessagesUseCase,
) error {
	const (
		readHeaderTimeout = 10 * time.Second
		readTimeout       = 30 * time.Second
		writeTimeout      = 30 * time.Second
		idleTimeout       = 120 * time.Second
	)

	if cfg.Server.Environment == "production" {
		gin.SetMode(gin.ReleaseMode)
	}
	router := gin.New()
	router.Use(gin.Recovery())
	router.Use(gin.Logger())

	// Health check: reports process liveness only; dependencies are not probed.
	router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"status":  "healthy",
			"service": "message-router",
		})
	})

	// API routes
	api := router.Group("/api/v1")
	{
		api.POST("/messages/route", func(c *gin.Context) {
			// Payload is []byte, so encoding/json expects it base64-encoded.
			var req struct {
				SessionID   string   `json:"session_id" binding:"required"`
				FromParty   string   `json:"from_party" binding:"required"`
				ToParties   []string `json:"to_parties"`
				RoundNumber int      `json:"round_number"`
				MessageType string   `json:"message_type"`
				Payload     []byte   `json:"payload" binding:"required"`
			}
			if err := c.ShouldBindJSON(&req); err != nil {
				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
				return
			}
			input := use_cases.RouteMessageInput{
				SessionID:   req.SessionID,
				FromParty:   req.FromParty,
				ToParties:   req.ToParties,
				RoundNumber: req.RoundNumber,
				MessageType: req.MessageType,
				Payload:     req.Payload,
			}
			output, err := routeMessageUC.Execute(c.Request.Context(), input)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			c.JSON(http.StatusOK, gin.H{
				"success":    output.Success,
				"message_id": output.MessageID,
			})
		})

		// NOTE(review): session_id and party_id are not validated here; empty
		// values are forwarded to the use case as-is. Confirm that is intended.
		api.GET("/messages/pending", func(c *gin.Context) {
			input := use_cases.GetPendingMessagesInput{
				SessionID:      c.Query("session_id"),
				PartyID:        c.Query("party_id"),
				AfterTimestamp: 0,
			}
			messages, err := getPendingMessagesUC.Execute(c.Request.Context(), input)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			c.JSON(http.StatusOK, gin.H{"messages": messages})
		})
	}

	srv := &http.Server{
		Addr:              fmt.Sprintf(":%d", cfg.Server.HTTPPort),
		Handler:           router,
		ReadHeaderTimeout: readHeaderTimeout,
		ReadTimeout:       readTimeout,
		WriteTimeout:      writeTimeout,
		IdleTimeout:       idleTimeout,
	}
	logger.Info("Starting HTTP server", zap.Int("port", cfg.Server.HTTPPort))
	return srv.ListenAndServe()
}
// runMessageCleanup deletes stored messages older than the retention window
// (24h) via messageRepo.DeleteOlderThan. It runs one pass immediately, then
// one pass every cleanupInterval (1h). Each pass is bounded by passTimeout.
//
// Failures are logged and the next pass proceeds normally. Non-zero deletion
// counts are logged at info level.
//
// It never returns, so callers run it in its own goroutine.
// NOTE(review): it has no cancellation hook; it stops only when the process
// exits.
func runMessageCleanup(messageRepo *postgres.MessagePostgresRepo) {
	const (
		cleanupInterval = 1 * time.Hour
		retention       = 24 * time.Hour
		passTimeout     = 5 * time.Minute
	)

	purge := func() {
		ctx, cancel := context.WithTimeout(context.Background(), passTimeout)
		defer cancel()

		cutoff := time.Now().Add(-retention)
		count, err := messageRepo.DeleteOlderThan(ctx, cutoff)
		switch {
		case err != nil:
			logger.Error("Failed to cleanup old messages", zap.Error(err))
		case count > 0:
			logger.Info("Cleaned up old messages", zap.Int64("count", count))
		}
	}

	// Purge once at startup. With a ticker alone, an instance restarted more
	// often than cleanupInterval would never purge anything.
	purge()

	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()
	for range ticker.C {
		purge()
	}
}