rwadurian/backend/mpc-system/services/server-party/infrastructure/cache/share_cache.go

171 lines
4.0 KiB
Go

package cache
import (
"sync"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/logger"
"go.uber.org/zap"
)
// ShareCacheEntry is a single cached share. ShareCache keys it by the
// string form of SessionID.
type ShareCacheEntry struct {
	SessionID     uuid.UUID // session the share belongs to; its String() is the cache key
	PartyID       string    // party ID supplied to Store
	Share         []byte    // share bytes supplied to Store
	PublicKey     []byte    // public key bytes supplied to Store
	ExpiresAt     time.Time // set by Store to now+ttl; lookups treat the entry as gone after this instant
	RetrievedOnce bool      // Tracks whether the share has been retrieved; Store sets false, Get sets true
}
// ShareCache provides in-memory caching for delegate party shares.
// Shares are stored temporarily and are meant to be deleted after retrieval
// (via Delete or GetAndDelete). Every entry expires ttl after it is stored.
// Expired entries are dropped lazily on lookup and by a periodic background
// sweep.
//
// Construct with NewShareCache. The zero value has a nil map, so Store on it
// would panic.
type ShareCache struct {
	entries map[string]*ShareCacheEntry // sessionID.String() -> entry
	mu      sync.RWMutex                // guards entries
	ttl     time.Duration               // lifetime applied by Store to each new entry
}
// NewShareCache returns an empty ShareCache whose entries live for ttl after
// being stored.
//
// It also launches a background goroutine (cleanupExpired) that periodically
// removes expired entries. That goroutine has no stop mechanism and runs for
// the rest of the process lifetime. Create a cache once and reuse it rather
// than constructing one per request.
func NewShareCache(ttl time.Duration) *ShareCache {
	c := &ShareCache{ttl: ttl}
	c.entries = make(map[string]*ShareCacheEntry)

	// Periodic sweep of expired entries; see cleanupExpired.
	go c.cleanupExpired()

	return c
}
// Store caches share and publicKey for sessionID under the key
// sessionID.String(). It replaces any existing entry for that session,
// regardless of party. The entry expires c.ttl after this call.
//
// Both byte slices are copied before being cached. The cache therefore never
// aliases the caller's buffers, and later reuse or zeroing of those buffers
// cannot alter or expose the stored key material.
func (c *ShareCache) Store(sessionID uuid.UUID, partyID string, share, publicKey []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry := &ShareCacheEntry{
		SessionID:     sessionID,
		PartyID:       partyID,
		Share:         append([]byte(nil), share...),
		PublicKey:     append([]byte(nil), publicKey...),
		ExpiresAt:     time.Now().Add(c.ttl),
		RetrievedOnce: false,
	}
	c.entries[sessionID.String()] = entry

	logger.Info("Share stored in cache",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("share_size", len(share)),
		zap.Time("expires_at", entry.ExpiresAt))
}
// Get looks up the share cached for sessionID.
//
// On a hit, the entry is flagged RetrievedOnce but stays in the cache. Call
// Delete after the share has been delivered to the user, or use GetAndDelete
// for a one-shot read. An entry whose ExpiresAt has passed is removed and
// reported as a miss.
//
// The returned pointer refers to the cached entry itself, not a copy.
func (c *ShareCache) Get(sessionID uuid.UUID) (*ShareCacheEntry, bool) {
	key := sessionID.String()

	c.mu.Lock()
	defer c.mu.Unlock()

	entry, found := c.entries[key]
	switch {
	case !found:
		return nil, false
	case time.Now().After(entry.ExpiresAt):
		delete(c.entries, key)
		logger.Warn("Share expired",
			zap.String("session_id", key))
		return nil, false
	}

	entry.RetrievedOnce = true
	logger.Info("Share retrieved from cache",
		zap.String("session_id", key),
		zap.String("party_id", entry.PartyID))
	return entry, true
}
// Delete removes the share cached for sessionID, for example after it has been
// delivered to the user following Get.
//
// Deleting an unknown or already-removed session is a no-op and is not logged.
// The "Share deleted from cache" line therefore appears only when an entry
// was actually dropped.
func (c *ShareCache) Delete(sessionID uuid.UUID) {
	key := sessionID.String()

	c.mu.Lock()
	defer c.mu.Unlock()

	if _, ok := c.entries[key]; !ok {
		return
	}
	delete(c.entries, key)

	logger.Info("Share deleted from cache",
		zap.String("session_id", key))
}
// GetAndDelete looks up and removes the share cached for sessionID in a single
// critical section, so a given share can be handed out at most once.
//
// A missing entry returns (nil, false). An expired entry is removed and
// also reported as a miss.
func (c *ShareCache) GetAndDelete(sessionID uuid.UUID) (*ShareCacheEntry, bool) {
	key := sessionID.String()

	c.mu.Lock()
	defer c.mu.Unlock()

	entry, found := c.entries[key]
	if !found {
		return nil, false
	}

	// Expired or not, the entry leaves the cache; only the result and log differ.
	delete(c.entries, key)

	if time.Now().After(entry.ExpiresAt) {
		logger.Warn("Share expired",
			zap.String("session_id", key))
		return nil, false
	}

	logger.Info("Share retrieved and deleted from cache (one-time retrieval)",
		zap.String("session_id", key),
		zap.String("party_id", entry.PartyID))
	return entry, true
}
// Size reports how many entries the cache currently holds. The count includes
// expired entries that have not yet been removed by a lookup or by the
// periodic sweep.
func (c *ShareCache) Size() int {
	c.mu.RLock()
	count := len(c.entries)
	c.mu.RUnlock()
	return count
}
// cleanupExpired is the background sweeper started by NewShareCache.
//
// Once a minute it takes the write lock and drops every entry whose ExpiresAt
// is already in the past. When at least one entry was removed, it logs how
// many. The loop has no exit condition, so this goroutine and its ticker run
// until the process ends.
func (c *ShareCache) cleanupExpired() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		c.mu.Lock()

		now := time.Now()
		removed := 0
		for key, entry := range c.entries {
			// Deleting the current key while ranging over a map is permitted by the spec.
			if entry.ExpiresAt.Before(now) {
				delete(c.entries, key)
				removed++
			}
		}

		if removed > 0 {
			logger.Info("Cleaned up expired shares",
				zap.Int("count", removed))
		}

		c.mu.Unlock()
	}
}