171 lines
4.0 KiB
Go
171 lines
4.0 KiB
Go
package cache
|
|
|
|
import (
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ShareCacheEntry represents a cached share entry
|
|
type ShareCacheEntry struct {
|
|
SessionID uuid.UUID
|
|
PartyID string
|
|
Share []byte
|
|
PublicKey []byte
|
|
ExpiresAt time.Time
|
|
RetrievedOnce bool // Track if share has been retrieved
|
|
}
|
|
|
|
// ShareCache provides in-memory caching for delegate party shares
|
|
// Shares are stored temporarily and deleted after retrieval
|
|
type ShareCache struct {
|
|
entries map[string]*ShareCacheEntry // sessionID -> entry
|
|
mu sync.RWMutex
|
|
ttl time.Duration
|
|
}
|
|
|
|
// NewShareCache creates a new share cache
|
|
func NewShareCache(ttl time.Duration) *ShareCache {
|
|
cache := &ShareCache{
|
|
entries: make(map[string]*ShareCacheEntry),
|
|
ttl: ttl,
|
|
}
|
|
|
|
// Start background cleanup goroutine
|
|
go cache.cleanupExpired()
|
|
|
|
return cache
|
|
}
|
|
|
|
// Store stores a share in the cache
|
|
func (c *ShareCache) Store(sessionID uuid.UUID, partyID string, share, publicKey []byte) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
entry := &ShareCacheEntry{
|
|
SessionID: sessionID,
|
|
PartyID: partyID,
|
|
Share: share,
|
|
PublicKey: publicKey,
|
|
ExpiresAt: time.Now().Add(c.ttl),
|
|
RetrievedOnce: false,
|
|
}
|
|
|
|
c.entries[sessionID.String()] = entry
|
|
|
|
logger.Info("Share stored in cache",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.Int("share_size", len(share)),
|
|
zap.Time("expires_at", entry.ExpiresAt))
|
|
}
|
|
|
|
// Get retrieves a share from the cache
|
|
// The share is marked as retrieved but not deleted yet
|
|
// Use Delete() to remove it after successful delivery to user
|
|
func (c *ShareCache) Get(sessionID uuid.UUID) (*ShareCacheEntry, bool) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
entry, exists := c.entries[sessionID.String()]
|
|
if !exists {
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(entry.ExpiresAt) {
|
|
delete(c.entries, sessionID.String())
|
|
logger.Warn("Share expired",
|
|
zap.String("session_id", sessionID.String()))
|
|
return nil, false
|
|
}
|
|
|
|
// Mark as retrieved
|
|
entry.RetrievedOnce = true
|
|
|
|
logger.Info("Share retrieved from cache",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", entry.PartyID))
|
|
|
|
return entry, true
|
|
}
|
|
|
|
// Delete deletes a share from the cache
|
|
func (c *ShareCache) Delete(sessionID uuid.UUID) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
delete(c.entries, sessionID.String())
|
|
|
|
logger.Info("Share deleted from cache",
|
|
zap.String("session_id", sessionID.String()))
|
|
}
|
|
|
|
// GetAndDelete retrieves and immediately deletes a share (atomic operation)
|
|
// This ensures the share can only be retrieved once
|
|
func (c *ShareCache) GetAndDelete(sessionID uuid.UUID) (*ShareCacheEntry, bool) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
entry, exists := c.entries[sessionID.String()]
|
|
if !exists {
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(entry.ExpiresAt) {
|
|
delete(c.entries, sessionID.String())
|
|
logger.Warn("Share expired",
|
|
zap.String("session_id", sessionID.String()))
|
|
return nil, false
|
|
}
|
|
|
|
// Delete immediately
|
|
delete(c.entries, sessionID.String())
|
|
|
|
logger.Info("Share retrieved and deleted from cache (one-time retrieval)",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", entry.PartyID))
|
|
|
|
return entry, true
|
|
}
|
|
|
|
// Size returns the number of entries in the cache
|
|
func (c *ShareCache) Size() int {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return len(c.entries)
|
|
}
|
|
|
|
// cleanupExpired removes expired entries periodically
|
|
func (c *ShareCache) cleanupExpired() {
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
c.mu.Lock()
|
|
now := time.Now()
|
|
var expired []string
|
|
|
|
for sessionID, entry := range c.entries {
|
|
if now.After(entry.ExpiresAt) {
|
|
expired = append(expired, sessionID)
|
|
}
|
|
}
|
|
|
|
for _, sessionID := range expired {
|
|
delete(c.entries, sessionID)
|
|
}
|
|
|
|
if len(expired) > 0 {
|
|
logger.Info("Cleaned up expired shares",
|
|
zap.Int("count", len(expired)))
|
|
}
|
|
|
|
c.mu.Unlock()
|
|
}
|
|
}
|