rwadurian/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go

512 lines
16 KiB
Go

package use_cases
import (
"context"
"errors"
"os"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/pkg/tss"
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
"go.uber.org/zap"
)
// Sentinel errors returned by the keygen participation use case.
var (
// ErrKeygenFailed signals a keygen failure. It is declared here but not
// referenced by the original code in the visible part of this file.
ErrKeygenFailed = errors.New("keygen failed")
// ErrKeygenTimeout is returned by waitForAllParticipants when the expected
// number of participants has not joined before its maximum wait time elapses.
ErrKeygenTimeout = errors.New("keygen timeout")
// ErrInvalidSession is returned when a session cannot be used for keygen,
// e.g. its type is neither "keygen" nor "co_managed_keygen".
ErrInvalidSession = errors.New("invalid session")
// ErrShareSaveFailed is returned when saving the key share through the
// key share repository fails.
ErrShareSaveFailed = errors.New("failed to save share")
)
// ParticipateKeygenInput contains input for keygen participation.
// All three fields are forwarded to SessionCoordinatorClient.JoinSession by Execute.
type ParticipateKeygenInput struct {
// SessionID identifies the keygen session to join.
SessionID uuid.UUID
// PartyID is the identifier of the local party; it is also matched against
// the session's participant list to find the local party index.
PartyID string
// JoinToken is the token presented to the session coordinator when joining.
JoinToken string
}
// ParticipateKeygenOutput contains output from keygen participation.
type ParticipateKeygenOutput struct {
// Success is set to true on every output returned without an error.
Success bool
// KeyShare is the key share entity built from the encrypted share, the
// public key, the thresholds and the local party index.
KeyShare *entities.PartyKeyShare
// PublicKey holds the public key bytes produced by the keygen protocol.
PublicKey []byte
ShareForUser []byte // For delegate parties: encrypted share to return to user (not saved to DB)
}
// SessionCoordinatorClient defines the interface for session coordinator communication.
type SessionCoordinatorClient interface {
// JoinSession joins sessionID as partyID using joinToken and returns the
// session information reported by the coordinator.
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
// ReportCompletion tells the coordinator that partyID has finished the
// session, passing along the resulting public key.
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
// GetSessionStatusFull returns the session status together with its
// participant list; it is polled while waiting for participants to join.
GetSessionStatusFull(ctx context.Context, sessionID uuid.UUID) (*SessionStatusInfo, error)
}
// SessionStatusInfo contains full session status information as returned by
// SessionCoordinatorClient.GetSessionStatusFull.
type SessionStatusInfo struct {
// Status is the session status string; this file compares it against "in_progress".
Status string
// ThresholdN and ThresholdT are the session thresholds. They are not read
// from this struct in the visible code (the values from SessionInfo are used).
ThresholdN int
ThresholdT int
// Participants is the list of parties currently known for the session.
Participants []ParticipantInfo
}
// MessageRouterClient defines the interface for message router communication.
type MessageRouterClient interface {
// RouteMessage sends payload from fromParty to toParties within sessionID,
// tagged with roundNumber.
RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
// SubscribeMessages returns a receive-only channel of MPC messages addressed
// to partyID in sessionID.
SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
// Heartbeat sends a keep-alive for partyID. The meaning of the returned
// int32 is not visible here; callers in this file discard it.
Heartbeat(ctx context.Context, partyID string) (int32, error)
}
// SessionInfo contains session information from coordinator.
type SessionInfo struct {
// SessionID is the session identifier.
SessionID uuid.UUID
// SessionType is the kind of session; keygen participation accepts
// "keygen" and "co_managed_keygen".
SessionType string
// ThresholdN is passed to the TSS keygen config as the total number of
// parties and is the participant count waited for in co-managed sessions.
ThresholdN int
// ThresholdT is passed to the TSS keygen config as the threshold.
ThresholdT int
// MessageHash is not used by keygen in this file.
// NOTE(review): presumably only relevant to signing sessions - confirm.
MessageHash []byte
KeygenSessionID uuid.UUID // For signing sessions: which keygen session's share to use
// Participants lists the parties of the session with their indices.
Participants []ParticipantInfo
}
// ParticipantInfo contains participant information.
type ParticipantInfo struct {
// PartyID is the participant's identifier; it is the key of the
// party-index map used to resolve the sender of incoming messages.
PartyID string
// PartyIndex is the participant's index handed to the TSS keygen session.
PartyIndex int
}
// MPCMessage represents an MPC message from the router.
type MPCMessage struct {
MessageID string // Unique message ID for acknowledgment
SessionID string // Session ID for acknowledgment
// FromParty is the sender's party ID; it is resolved to a party index
// before the message is handed to the TSS session.
FromParty string
// IsBroadcast is copied verbatim onto the converted TSS message.
IsBroadcast bool
// RoundNumber is carried on the message but not read by the keygen code in this file.
RoundNumber int
// Payload is the raw message body forwarded to the TSS session.
Payload []byte
}
// ParticipateKeygenUseCase handles keygen participation.
type ParticipateKeygenUseCase struct {
// keyShareRepo persists key shares for parties that store them in the database.
keyShareRepo repositories.KeyShareRepository
// sessionClient talks to the session coordinator (join, status, completion).
sessionClient SessionCoordinatorClient
// messageRouter sends and receives MPC messages and heartbeats.
messageRouter MessageRouterClient
// cryptoService encrypts the generated share before it is stored or returned.
cryptoService *crypto.CryptoService
}
// NewParticipateKeygenUseCase wires the given dependencies into a new
// ParticipateKeygenUseCase. The arguments are stored as-is; no validation is
// performed on them.
func NewParticipateKeygenUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateKeygenUseCase {
	useCase := new(ParticipateKeygenUseCase)
	useCase.keyShareRepo = keyShareRepo
	useCase.sessionClient = sessionClient
	useCase.messageRouter = messageRouter
	useCase.cryptoService = cryptoService
	return useCase
}
// Execute participates in a keygen session using real TSS protocol.
//
// It joins the session through the coordinator, validates the session type
// and then runs the shared keygen flow. For "co_managed_keygen" sessions it
// first waits until all N participants have joined, because server parties
// join immediately while the external party joins later.
//
// It returns ErrInvalidSession when the coordinator yields no session info or
// a session type other than "keygen" / "co_managed_keygen"; errors from
// joining, waiting and the keygen flow are returned unchanged.
func (uc *ParticipateKeygenUseCase) Execute(
	ctx context.Context,
	input ParticipateKeygenInput,
) (*ParticipateKeygenOutput, error) {
	// Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	// A client returning (nil, nil) would otherwise cause a nil dereference below.
	if sessionInfo == nil {
		return nil, ErrInvalidSession
	}

	switch sessionInfo.SessionType {
	case "keygen":
		// Nothing to wait for.
	case "co_managed_keygen":
		// Wait for all N participants to join before proceeding.
		sessionInfo, err = uc.waitForAllParticipants(ctx, input.SessionID, sessionInfo, input.PartyID)
		if err != nil {
			return nil, err
		}
	default:
		return nil, ErrInvalidSession
	}

	// Delegate to the common execution logic.
	return uc.executeWithSessionInfo(ctx, input.SessionID, input.PartyID, sessionInfo)
}
// ExecuteWithSessionInfo participates in a keygen session with pre-obtained SessionInfo.
// This is used by server-party-co-managed which has already called JoinSession in session_created phase
// and receives session_started event when all participants have joined.
//
// It returns ErrInvalidSession when sessionInfo is nil or its type is neither
// "keygen" nor "co_managed_keygen"; otherwise it returns the result of the
// shared keygen flow.
func (uc *ParticipateKeygenUseCase) ExecuteWithSessionInfo(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
) (*ParticipateKeygenOutput, error) {
	// A nil sessionInfo would otherwise panic on the field access below.
	if sessionInfo == nil {
		return nil, ErrInvalidSession
	}

	// Validate session type.
	switch sessionInfo.SessionType {
	case "keygen", "co_managed_keygen":
	default:
		return nil, ErrInvalidSession
	}

	logger.Info("ExecuteWithSessionInfo: starting keygen with pre-obtained session info",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.String("session_type", sessionInfo.SessionType),
		zap.Int("participants", len(sessionInfo.Participants)))

	// Delegate to the common execution logic.
	return uc.executeWithSessionInfo(ctx, sessionID, partyID, sessionInfo)
}
// executeWithSessionInfo is the common execution logic shared by Execute and ExecuteWithSessionInfo.
//
// Steps: resolve the local party index, subscribe to MPC messages, run the
// TSS keygen protocol, encrypt the resulting share, handle the share
// according to the party role (see getPartyRole) and report completion to
// the coordinator.
//
// It returns ErrInvalidSession when sessionInfo is nil or does not list
// partyID among its participants, and ErrShareSaveFailed when persisting the
// share fails (the underlying error is logged). A failure to report
// completion is logged but does not fail the call.
func (uc *ParticipateKeygenUseCase) executeWithSessionInfo(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	sessionInfo *SessionInfo,
) (*ParticipateKeygenOutput, error) {
	if sessionInfo == nil {
		return nil, ErrInvalidSession
	}

	// 1. Find self in participants and build party index map.
	selfIndex, selfFound := 0, false
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
		if p.PartyID == partyID {
			selfIndex, selfFound = p.PartyIndex, true
		}
		logger.Debug("Added participant to index map",
			zap.String("party_id", p.PartyID),
			zap.Int("party_index", p.PartyIndex))
	}
	// Without this check a missing party would silently run with index 0,
	// which may belong to a different participant.
	if !selfFound {
		logger.Error("Party not found in session participants",
			zap.String("session_id", sessionID.String()),
			zap.String("party_id", partyID),
			zap.Int("total_participants", len(sessionInfo.Participants)))
		return nil, ErrInvalidSession
	}
	logger.Info("Built party index map",
		zap.String("session_id", sessionID.String()),
		zap.String("self_party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("total_participants", len(sessionInfo.Participants)))

	// 2. Subscribe to messages.
	msgChan, err := uc.messageRouter.SubscribeMessages(ctx, sessionID, partyID)
	if err != nil {
		return nil, err
	}

	// 3. Run TSS Keygen protocol.
	saveData, publicKey, err := uc.runKeygenProtocol(
		ctx,
		sessionID,
		partyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 4. Encrypt the share.
	encryptedShare, err := uc.cryptoService.EncryptShare(saveData, partyID)
	if err != nil {
		return nil, err
	}
	keyShare := entities.NewPartyKeyShare(
		partyID,
		selfIndex,
		sessionID,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		encryptedShare,
		publicKey,
	)

	// persistShare stores the share and maps any failure to the sentinel
	// callers compare against, logging the underlying cause so it is not lost.
	persistShare := func() error {
		if saveErr := uc.keyShareRepo.Save(ctx, keyShare); saveErr != nil {
			logger.Error("failed to save key share",
				zap.String("party_id", partyID),
				zap.String("session_id", sessionID.String()),
				zap.Error(saveErr))
			return ErrShareSaveFailed
		}
		return nil
	}

	// 5. Handle share based on party role.
	partyRole := uc.getPartyRole()
	var shareForUser []byte
	switch partyRole {
	case "persistent":
		// Persistent Party: save share to database.
		if err := persistShare(); err != nil {
			return nil, err
		}
		logger.Info("Share saved to database (persistent party)",
			zap.String("party_id", partyID),
			zap.String("session_id", sessionID.String()))
	case "delegate":
		// Delegate Party: do NOT save to database, return to user.
		shareForUser = encryptedShare
		logger.Info("Share NOT saved, will be returned to user (delegate party)",
			zap.String("party_id", partyID),
			zap.String("session_id", sessionID.String()),
			zap.Int("share_size", len(shareForUser)))
	case "temporary":
		// Temporary Party: optionally save to temp storage (not implemented yet).
		logger.Info("Temporary party - share not saved",
			zap.String("party_id", partyID))
	default:
		// Default to persistent for safety.
		if err := persistShare(); err != nil {
			return nil, err
		}
		logger.Warn("Unknown party role, defaulting to persistent",
			zap.String("party_id", partyID),
			zap.String("role", partyRole))
	}

	// 6. Report completion to coordinator.
	if err := uc.sessionClient.ReportCompletion(ctx, sessionID, partyID, publicKey); err != nil {
		// Don't fail - the share has already been handled.
		logger.Error("failed to report completion", zap.Error(err))
	}

	return &ParticipateKeygenOutput{
		Success:      true,
		KeyShare:     keyShare,
		PublicKey:    publicKey,
		ShareForUser: shareForUser, // Only populated for delegate parties
	}, nil
}
// runKeygenProtocol runs the TSS keygen protocol using tss-lib.
//
// It adapts the router's message channel to the tss.MessageHandler interface,
// builds the keygen configuration (threshold t, n total parties, 10 minute
// timeout) and party list, and runs the keygen session to completion.
//
// On success it returns the local party save data and the public key bytes.
// Errors from creating or running the session are returned unchanged;
// ErrKeygenFailed is returned if the session reports neither a result nor an
// error.
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	n, t int,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, []byte, error) {
	logger.Info("Running keygen protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("n", n),
		zap.Int("t", t))

	// Create message handler adapter.
	msgHandler := &keygenMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine. It gets its own cancellable context
	// so it stops when this function returns (success or failure) instead of
	// lingering until the caller's context ends.
	convertCtx, stopConvert := context.WithCancel(ctx)
	defer stopConvert()
	go msgHandler.convertMessages(convertCtx, msgChan)

	// Create keygen config.
	config := tss.KeygenConfig{
		Threshold:    t,
		TotalParties: n,
		Timeout:      10 * time.Minute,
	}

	// Create party list.
	allParties := make([]tss.KeygenParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.KeygenParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}
	selfParty := tss.KeygenParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	// Create keygen session.
	session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
	if err != nil {
		return nil, nil, err
	}

	// Run keygen.
	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, err
	}
	// Guard against a (nil, nil) return, which would panic on the field access below.
	if result == nil {
		return nil, nil, ErrKeygenFailed
	}

	logger.Info("Keygen completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))
	return result.LocalPartySaveData, result.PublicKeyBytes, nil
}
// keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler.
type keygenMessageHandler struct {
// sessionID and partyID identify the session and the local party on outgoing messages.
sessionID uuid.UUID
partyID string
// messageRouter is used to send outgoing TSS messages.
messageRouter MessageRouterClient
// msgChan carries converted incoming messages to the TSS session; it is
// written to and closed by convertMessages.
msgChan chan *tss.ReceivedMessage
// partyIndexMap maps a party ID to its party index; messages from parties
// missing in this map are dropped.
partyIndexMap map[string]int
}
// SendMessage forwards an outgoing TSS message to toParties through the
// message router on behalf of this handler's session and party. The
// isBroadcast flag is not forwarded, and the round number is always sent as 0.
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	const roundNumber = 0
	routeErr := h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, roundNumber, msgBytes)
	return routeErr
}
// ReceiveMessages exposes the handler's channel of converted incoming
// messages as a receive-only channel.
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	var incoming <-chan *tss.ReceivedMessage = h.msgChan
	return incoming
}
// convertMessages reads router messages from inChan, converts each one to a
// tss.ReceivedMessage (resolving the sender's party index via partyIndexMap)
// and forwards it on h.msgChan.
//
// It runs until ctx is cancelled or inChan is closed, and closes h.msgChan on
// every exit path so the consumer can observe the end of the stream. Nil
// messages and messages from parties missing in partyIndexMap are dropped
// with a warning. It is meant to run in its own goroutine and must be started
// at most once per handler, since it closes h.msgChan.
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	// This goroutine is the only sender on msgChan, so it owns closing it.
	defer close(h.msgChan)

	sessionID := h.sessionID.String()
	for {
		select {
		case <-ctx.Done():
			logger.Debug("convertMessages context cancelled", zap.String("session_id", sessionID))
			return
		case msg, ok := <-inChan:
			if !ok {
				logger.Debug("convertMessages inChan closed", zap.String("session_id", sessionID))
				return
			}
			// Guard against a nil pointer on the channel, which would panic below.
			if msg == nil {
				logger.Warn("Received nil MPC message - dropping",
					zap.String("session_id", sessionID))
				continue
			}
			logger.Debug("Received MPC message for conversion",
				zap.String("session_id", sessionID),
				zap.String("from_party", msg.FromParty),
				zap.Bool("is_broadcast", msg.IsBroadcast),
				zap.Int("payload_size", len(msg.Payload)))

			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				logger.Warn("Message from unknown party - dropping",
					zap.String("session_id", sessionID),
					zap.String("from_party", msg.FromParty),
					zap.Any("known_parties", h.partyIndexMap))
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}
			logger.Debug("Converted message, sending to TSS",
				zap.String("session_id", sessionID),
				zap.String("from_party", msg.FromParty),
				zap.Int("from_index", fromIndex))

			// Do not block forever on a full channel once ctx is cancelled.
			select {
			case h.msgChan <- tssMsg:
				logger.Debug("Message sent to TSS successfully",
					zap.String("session_id", sessionID))
			case <-ctx.Done():
				return
			}
		}
	}
}
// getPartyRole reads the party role from the PARTY_ROLE environment variable.
// The value is returned verbatim when it is non-empty; an unset or empty
// variable yields "persistent", the safe default.
func (uc *ParticipateKeygenUseCase) getPartyRole() string {
	if configured := os.Getenv("PARTY_ROLE"); configured != "" {
		return configured
	}
	return "persistent"
}
// waitForAllParticipants waits for all N participants to join the session.
// This is only used for co_managed_keygen sessions where server parties join first
// and need to wait for the external party to join via invite code.
//
// If initialSessionInfo already lists at least ThresholdN participants it is
// returned immediately. Otherwise the session status is polled every 2
// seconds for up to 5 minutes; a heartbeat is sent before each poll so the
// party is kept alive while waiting. Once enough participants are reported,
// initialSessionInfo.Participants is replaced with the polled list and the
// same pointer is returned.
//
// It returns ctx.Err() when ctx is cancelled and ErrKeygenTimeout when the
// maximum wait time elapses. Heartbeat and status errors are logged and the
// poll is retried.
func (uc *ParticipateKeygenUseCase) waitForAllParticipants(
	ctx context.Context,
	sessionID uuid.UUID,
	initialSessionInfo *SessionInfo,
	partyID string,
) (*SessionInfo, error) {
	const (
		pollInterval = 2 * time.Second
		maxWaitTime  = 5 * time.Minute
	)
	expectedN := initialSessionInfo.ThresholdN

	logger.Info("Waiting for all participants to join co_managed_keygen session",
		zap.String("session_id", sessionID.String()),
		zap.Int("expected_n", expectedN),
		zap.Int("current_participants", len(initialSessionInfo.Participants)))

	// If already have all participants, return immediately.
	if len(initialSessionInfo.Participants) >= expectedN {
		logger.Info("All participants already joined",
			zap.String("session_id", sessionID.String()))
		return initialSessionInfo, nil
	}

	// One ticker and one deadline timer for the whole wait, instead of
	// allocating a fresh timer on every iteration; the deadline also fires
	// on time rather than only being checked between polls.
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	deadline := time.NewTimer(maxWaitTime)
	defer deadline.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()

		case <-deadline.C:
			logger.Error("Timeout waiting for all participants",
				zap.String("session_id", sessionID.String()),
				zap.Int("expected_n", expectedN))
			return nil, ErrKeygenTimeout

		case <-ticker.C:
			// Send heartbeat to keep the party alive during wait.
			// This prevents the session-coordinator from timing out this party.
			if _, heartbeatErr := uc.messageRouter.Heartbeat(ctx, partyID); heartbeatErr != nil {
				// Heartbeat failure is not fatal - continue with the poll.
				logger.Warn("Failed to send heartbeat during wait",
					zap.String("session_id", sessionID.String()),
					zap.String("party_id", partyID),
					zap.Error(heartbeatErr))
			}

			// Get full session status including participants.
			statusInfo, err := uc.sessionClient.GetSessionStatusFull(ctx, sessionID)
			if err != nil {
				logger.Warn("Failed to get session status, will retry",
					zap.String("session_id", sessionID.String()),
					zap.Error(err))
				continue
			}
			// Guard against a (nil, nil) response, which would panic below.
			if statusInfo == nil {
				logger.Warn("Received empty session status, will retry",
					zap.String("session_id", sessionID.String()))
				continue
			}
			logger.Debug("Polled session status",
				zap.String("session_id", sessionID.String()),
				zap.String("status", statusInfo.Status),
				zap.Int("participants", len(statusInfo.Participants)),
				zap.Int("expected_n", expectedN))

			if len(statusInfo.Participants) < expectedN {
				continue
			}

			// All N participants are present. The session is accepted whether
			// or not its status has already switched to in_progress; only the
			// log message differs.
			joinedMsg := "All participants joined"
			if statusInfo.Status == "in_progress" {
				joinedMsg = "All participants joined, session is in_progress"
			}
			logger.Info(joinedMsg,
				zap.String("session_id", sessionID.String()),
				zap.Int("participants", len(statusInfo.Participants)))

			// Update session info with full participants list.
			initialSessionInfo.Participants = statusInfo.Participants
			return initialSessionInfo, nil
		}
	}
}