512 lines
16 KiB
Go
512 lines
16 KiB
Go
package use_cases
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"github.com/rwadurian/mpc-system/pkg/tss"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Sentinel errors returned by the keygen participation use case.
var (
	// ErrKeygenFailed signals a generic keygen failure.
	ErrKeygenFailed = errors.New("keygen failed")

	// ErrKeygenTimeout is returned when the wait for all participants of a
	// co_managed_keygen session exceeds its deadline.
	ErrKeygenTimeout = errors.New("keygen timeout")

	// ErrInvalidSession is returned when a session cannot be used for keygen,
	// for example because its type is neither "keygen" nor "co_managed_keygen".
	ErrInvalidSession = errors.New("invalid session")

	// ErrShareSaveFailed is returned when persisting the key share to the
	// repository fails.
	ErrShareSaveFailed = errors.New("failed to save share")
)
|
|
|
|
// ParticipateKeygenInput contains input for keygen participation
|
|
type ParticipateKeygenInput struct {
|
|
SessionID uuid.UUID
|
|
PartyID string
|
|
JoinToken string
|
|
}
|
|
|
|
// ParticipateKeygenOutput contains output from keygen participation
|
|
type ParticipateKeygenOutput struct {
|
|
Success bool
|
|
KeyShare *entities.PartyKeyShare
|
|
PublicKey []byte
|
|
ShareForUser []byte // For delegate parties: encrypted share to return to user (not saved to DB)
|
|
}
|
|
|
|
// SessionCoordinatorClient defines the interface for session coordinator communication
|
|
type SessionCoordinatorClient interface {
|
|
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
|
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
|
GetSessionStatusFull(ctx context.Context, sessionID uuid.UUID) (*SessionStatusInfo, error)
|
|
}
|
|
|
|
// SessionStatusInfo contains full session status information
|
|
type SessionStatusInfo struct {
|
|
Status string
|
|
ThresholdN int
|
|
ThresholdT int
|
|
Participants []ParticipantInfo
|
|
}
|
|
|
|
// MessageRouterClient defines the interface for message router communication
|
|
type MessageRouterClient interface {
|
|
RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
|
|
SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
|
|
Heartbeat(ctx context.Context, partyID string) (int32, error)
|
|
}
|
|
|
|
// SessionInfo contains session information from coordinator
|
|
type SessionInfo struct {
|
|
SessionID uuid.UUID
|
|
SessionType string
|
|
ThresholdN int
|
|
ThresholdT int
|
|
MessageHash []byte
|
|
KeygenSessionID uuid.UUID // For signing sessions: which keygen session's share to use
|
|
Participants []ParticipantInfo
|
|
}
|
|
|
|
// ParticipantInfo identifies one party in a session.
type ParticipantInfo struct {
	PartyID    string // party identifier
	PartyIndex int    // index used for this party in the TSS protocol
}
|
|
|
|
// MPCMessage is an MPC protocol message delivered by the message router.
type MPCMessage struct {
	MessageID   string // unique message ID, used for acknowledgment
	SessionID   string // session ID, used for acknowledgment
	FromParty   string // sender's party ID
	IsBroadcast bool   // whether the message was broadcast to all parties
	RoundNumber int    // protocol round the message belongs to
	Payload     []byte // raw protocol message bytes
}
|
|
|
|
// ParticipateKeygenUseCase handles keygen participation
|
|
type ParticipateKeygenUseCase struct {
|
|
keyShareRepo repositories.KeyShareRepository
|
|
sessionClient SessionCoordinatorClient
|
|
messageRouter MessageRouterClient
|
|
cryptoService *crypto.CryptoService
|
|
}
|
|
|
|
// NewParticipateKeygenUseCase creates a new participate keygen use case
|
|
func NewParticipateKeygenUseCase(
|
|
keyShareRepo repositories.KeyShareRepository,
|
|
sessionClient SessionCoordinatorClient,
|
|
messageRouter MessageRouterClient,
|
|
cryptoService *crypto.CryptoService,
|
|
) *ParticipateKeygenUseCase {
|
|
return &ParticipateKeygenUseCase{
|
|
keyShareRepo: keyShareRepo,
|
|
sessionClient: sessionClient,
|
|
messageRouter: messageRouter,
|
|
cryptoService: cryptoService,
|
|
}
|
|
}
|
|
|
|
// Execute participates in a keygen session using real TSS protocol
|
|
func (uc *ParticipateKeygenUseCase) Execute(
|
|
ctx context.Context,
|
|
input ParticipateKeygenInput,
|
|
) (*ParticipateKeygenOutput, error) {
|
|
// 1. Join session via coordinator
|
|
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Accept both "keygen" and "co_managed_keygen" session types
|
|
if sessionInfo.SessionType != "keygen" && sessionInfo.SessionType != "co_managed_keygen" {
|
|
return nil, ErrInvalidSession
|
|
}
|
|
|
|
// For co_managed_keygen: wait for all N participants to join before proceeding
|
|
// This is necessary because server parties join immediately but external party joins later
|
|
if sessionInfo.SessionType == "co_managed_keygen" {
|
|
sessionInfo, err = uc.waitForAllParticipants(ctx, input.SessionID, sessionInfo, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Delegate to the common execution logic
|
|
return uc.executeWithSessionInfo(ctx, input.SessionID, input.PartyID, sessionInfo)
|
|
}
|
|
|
|
// ExecuteWithSessionInfo participates in a keygen session with pre-obtained SessionInfo
|
|
// This is used by server-party-co-managed which has already called JoinSession in session_created phase
|
|
// and receives session_started event when all participants have joined
|
|
func (uc *ParticipateKeygenUseCase) ExecuteWithSessionInfo(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
sessionInfo *SessionInfo,
|
|
) (*ParticipateKeygenOutput, error) {
|
|
// Validate session type
|
|
if sessionInfo.SessionType != "keygen" && sessionInfo.SessionType != "co_managed_keygen" {
|
|
return nil, ErrInvalidSession
|
|
}
|
|
|
|
logger.Info("ExecuteWithSessionInfo: starting keygen with pre-obtained session info",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.String("session_type", sessionInfo.SessionType),
|
|
zap.Int("participants", len(sessionInfo.Participants)))
|
|
|
|
// Delegate to the common execution logic
|
|
return uc.executeWithSessionInfo(ctx, sessionID, partyID, sessionInfo)
|
|
}
|
|
|
|
// executeWithSessionInfo is the common execution logic shared by Execute and ExecuteWithSessionInfo
|
|
func (uc *ParticipateKeygenUseCase) executeWithSessionInfo(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
sessionInfo *SessionInfo,
|
|
) (*ParticipateKeygenOutput, error) {
|
|
// 1. Find self in participants and build party index map
|
|
var selfIndex int
|
|
partyIndexMap := make(map[string]int)
|
|
for _, p := range sessionInfo.Participants {
|
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
|
if p.PartyID == partyID {
|
|
selfIndex = p.PartyIndex
|
|
}
|
|
logger.Debug("Added participant to index map",
|
|
zap.String("party_id", p.PartyID),
|
|
zap.Int("party_index", p.PartyIndex))
|
|
}
|
|
logger.Info("Built party index map",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("self_party_id", partyID),
|
|
zap.Int("self_index", selfIndex),
|
|
zap.Int("total_participants", len(sessionInfo.Participants)))
|
|
|
|
// 3. Subscribe to messages
|
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, sessionID, partyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 4. Run TSS Keygen protocol
|
|
saveData, publicKey, err := uc.runKeygenProtocol(
|
|
ctx,
|
|
sessionID,
|
|
partyID,
|
|
selfIndex,
|
|
sessionInfo.Participants,
|
|
sessionInfo.ThresholdN,
|
|
sessionInfo.ThresholdT,
|
|
msgChan,
|
|
partyIndexMap,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 5. Encrypt the share
|
|
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, partyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keyShare := entities.NewPartyKeyShare(
|
|
partyID,
|
|
selfIndex,
|
|
sessionID,
|
|
sessionInfo.ThresholdN,
|
|
sessionInfo.ThresholdT,
|
|
encryptedShare,
|
|
publicKey,
|
|
)
|
|
|
|
// 6. Handle share based on party role
|
|
partyRole := uc.getPartyRole()
|
|
var shareForUser []byte
|
|
|
|
switch partyRole {
|
|
case "persistent":
|
|
// Persistent Party: save share to database
|
|
if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
|
|
return nil, ErrShareSaveFailed
|
|
}
|
|
logger.Info("Share saved to database (persistent party)",
|
|
zap.String("party_id", partyID),
|
|
zap.String("session_id", sessionID.String()))
|
|
|
|
case "delegate":
|
|
// Delegate Party: do NOT save to database, return to user
|
|
shareForUser = encryptedShare
|
|
logger.Info("Share NOT saved, will be returned to user (delegate party)",
|
|
zap.String("party_id", partyID),
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Int("share_size", len(shareForUser)))
|
|
|
|
case "temporary":
|
|
// Temporary Party: optionally save to temp storage (not implemented yet)
|
|
logger.Info("Temporary party - share not saved",
|
|
zap.String("party_id", partyID))
|
|
|
|
default:
|
|
// Default to persistent for safety
|
|
if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
|
|
return nil, ErrShareSaveFailed
|
|
}
|
|
logger.Warn("Unknown party role, defaulting to persistent",
|
|
zap.String("party_id", partyID),
|
|
zap.String("role", partyRole))
|
|
}
|
|
|
|
// 7. Report completion to coordinator
|
|
if err := uc.sessionClient.ReportCompletion(ctx, sessionID, partyID, publicKey); err != nil {
|
|
logger.Error("failed to report completion", zap.Error(err))
|
|
// Don't fail - share is handled
|
|
}
|
|
|
|
return &ParticipateKeygenOutput{
|
|
Success: true,
|
|
KeyShare: keyShare,
|
|
PublicKey: publicKey,
|
|
ShareForUser: shareForUser, // Only populated for delegate parties
|
|
}, nil
|
|
}
|
|
|
|
// runKeygenProtocol runs the TSS keygen protocol using tss-lib
|
|
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
selfIndex int,
|
|
participants []ParticipantInfo,
|
|
n, t int,
|
|
msgChan <-chan *MPCMessage,
|
|
partyIndexMap map[string]int,
|
|
) ([]byte, []byte, error) {
|
|
logger.Info("Running keygen protocol",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.Int("self_index", selfIndex),
|
|
zap.Int("n", n),
|
|
zap.Int("t", t))
|
|
|
|
// Create message handler adapter
|
|
msgHandler := &keygenMessageHandler{
|
|
sessionID: sessionID,
|
|
partyID: partyID,
|
|
messageRouter: uc.messageRouter,
|
|
msgChan: make(chan *tss.ReceivedMessage, 100),
|
|
partyIndexMap: partyIndexMap,
|
|
}
|
|
|
|
// Start message conversion goroutine
|
|
go msgHandler.convertMessages(ctx, msgChan)
|
|
|
|
// Create keygen config
|
|
config := tss.KeygenConfig{
|
|
Threshold: t,
|
|
TotalParties: n,
|
|
Timeout: 10 * time.Minute,
|
|
}
|
|
|
|
// Create party list
|
|
allParties := make([]tss.KeygenParty, len(participants))
|
|
for i, p := range participants {
|
|
allParties[i] = tss.KeygenParty{
|
|
PartyID: p.PartyID,
|
|
PartyIndex: p.PartyIndex,
|
|
}
|
|
}
|
|
|
|
selfParty := tss.KeygenParty{
|
|
PartyID: partyID,
|
|
PartyIndex: selfIndex,
|
|
}
|
|
|
|
// Create keygen session
|
|
session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Run keygen
|
|
result, err := session.Start(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
logger.Info("Keygen completed successfully",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID))
|
|
|
|
return result.LocalPartySaveData, result.PublicKeyBytes, nil
|
|
}
|
|
|
|
// keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
|
type keygenMessageHandler struct {
|
|
sessionID uuid.UUID
|
|
partyID string
|
|
messageRouter MessageRouterClient
|
|
msgChan chan *tss.ReceivedMessage
|
|
partyIndexMap map[string]int
|
|
}
|
|
|
|
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
|
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
|
}
|
|
|
|
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
|
return h.msgChan
|
|
}
|
|
|
|
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Debug("convertMessages context cancelled", zap.String("session_id", h.sessionID.String()))
|
|
close(h.msgChan)
|
|
return
|
|
case msg, ok := <-inChan:
|
|
if !ok {
|
|
logger.Debug("convertMessages inChan closed", zap.String("session_id", h.sessionID.String()))
|
|
close(h.msgChan)
|
|
return
|
|
}
|
|
|
|
logger.Debug("Received MPC message for conversion",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Bool("is_broadcast", msg.IsBroadcast),
|
|
zap.Int("payload_size", len(msg.Payload)))
|
|
|
|
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
|
if !exists {
|
|
logger.Warn("Message from unknown party - dropping",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Any("known_parties", h.partyIndexMap))
|
|
continue
|
|
}
|
|
|
|
tssMsg := &tss.ReceivedMessage{
|
|
FromPartyIndex: fromIndex,
|
|
IsBroadcast: msg.IsBroadcast,
|
|
MsgBytes: msg.Payload,
|
|
}
|
|
|
|
logger.Debug("Converted message, sending to TSS",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Int("from_index", fromIndex))
|
|
|
|
select {
|
|
case h.msgChan <- tssMsg:
|
|
logger.Debug("Message sent to TSS successfully",
|
|
zap.String("session_id", h.sessionID.String()))
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// getPartyRole gets the party role from environment variable
|
|
// Returns "persistent" (default), "delegate", or "temporary"
|
|
func (uc *ParticipateKeygenUseCase) getPartyRole() string {
|
|
role := os.Getenv("PARTY_ROLE")
|
|
if role == "" {
|
|
return "persistent" // Default to persistent for safety
|
|
}
|
|
return role
|
|
}
|
|
|
|
// waitForAllParticipants waits for all N participants to join the session
|
|
// This is only used for co_managed_keygen sessions where server parties join first
|
|
// and need to wait for the external party to join via invite code
|
|
func (uc *ParticipateKeygenUseCase) waitForAllParticipants(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
initialSessionInfo *SessionInfo,
|
|
partyID string,
|
|
) (*SessionInfo, error) {
|
|
logger.Info("Waiting for all participants to join co_managed_keygen session",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Int("expected_n", initialSessionInfo.ThresholdN),
|
|
zap.Int("current_participants", len(initialSessionInfo.Participants)))
|
|
|
|
// If already have all participants, return immediately
|
|
if len(initialSessionInfo.Participants) >= initialSessionInfo.ThresholdN {
|
|
logger.Info("All participants already joined",
|
|
zap.String("session_id", sessionID.String()))
|
|
return initialSessionInfo, nil
|
|
}
|
|
|
|
// Poll for session status until all participants join or timeout
|
|
pollInterval := 2 * time.Second
|
|
maxWaitTime := 5 * time.Minute
|
|
deadline := time.Now().Add(maxWaitTime)
|
|
|
|
for time.Now().Before(deadline) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(pollInterval):
|
|
// Send heartbeat to keep the party alive during wait
|
|
// This prevents the session-coordinator from timing out this party
|
|
_, heartbeatErr := uc.messageRouter.Heartbeat(ctx, partyID)
|
|
if heartbeatErr != nil {
|
|
logger.Warn("Failed to send heartbeat during wait",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.Error(heartbeatErr))
|
|
// Continue anyway - heartbeat failure is not fatal
|
|
}
|
|
|
|
// Get full session status including participants
|
|
statusInfo, err := uc.sessionClient.GetSessionStatusFull(ctx, sessionID)
|
|
if err != nil {
|
|
logger.Warn("Failed to get session status, will retry",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
logger.Debug("Polled session status",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("status", statusInfo.Status),
|
|
zap.Int("participants", len(statusInfo.Participants)),
|
|
zap.Int("expected_n", initialSessionInfo.ThresholdN))
|
|
|
|
// Check if session is in_progress (all parties joined and ready)
|
|
if statusInfo.Status == "in_progress" && len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
|
|
logger.Info("All participants joined, session is in_progress",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Int("participants", len(statusInfo.Participants)))
|
|
|
|
// Update session info with full participants list
|
|
initialSessionInfo.Participants = statusInfo.Participants
|
|
return initialSessionInfo, nil
|
|
}
|
|
|
|
// Also accept if we have all N participants even if status hasn't changed
|
|
if len(statusInfo.Participants) >= initialSessionInfo.ThresholdN {
|
|
logger.Info("All participants joined",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Int("participants", len(statusInfo.Participants)))
|
|
|
|
initialSessionInfo.Participants = statusInfo.Participants
|
|
return initialSessionInfo, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
logger.Error("Timeout waiting for all participants",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.Int("expected_n", initialSessionInfo.ThresholdN))
|
|
return nil, ErrKeygenTimeout
|
|
}
|