rwadurian/backend/mpc-system/services/server-party/application/use_cases/participate_keygen.go

341 lines
9.5 KiB
Go

package use_cases
import (
"context"
"errors"
"os"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/pkg/tss"
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
"go.uber.org/zap"
)
// Sentinel errors for the keygen participation use case. Compare against
// them with errors.Is.
var (
	// ErrKeygenFailed signals a failed keygen run.
	// NOTE(review): not returned by code in this file; presumably used
	// elsewhere in the package — confirm.
	ErrKeygenFailed = errors.New("keygen failed")

	// ErrKeygenTimeout signals that keygen did not finish in time.
	// NOTE(review): not returned by code in this file — confirm usage.
	ErrKeygenTimeout = errors.New("keygen timeout")

	// ErrInvalidSession is returned by Execute when the joined session is
	// not usable as a keygen session.
	ErrInvalidSession = errors.New("invalid session")

	// ErrShareSaveFailed is returned by Execute when persisting the
	// generated key share fails.
	ErrShareSaveFailed = errors.New("failed to save share")
)
// ParticipateKeygenInput carries the parameters Execute needs to join a
// keygen session on behalf of this server party.
type ParticipateKeygenInput struct {
	SessionID uuid.UUID // session to join
	PartyID   string    // this party's identifier within the session
	JoinToken string    // token forwarded to the coordinator's JoinSession
}
// ParticipateKeygenOutput is what Execute returns after a keygen session
// completes for this party.
type ParticipateKeygenOutput struct {
	// Success is set to true on every output Execute returns; failures are
	// reported through the error result instead.
	Success bool
	// KeyShare is the key share entity built for this party (holding the
	// encrypted share). It is populated for every party role, whether or
	// not the share was saved.
	KeyShare *entities.PartyKeyShare
	// PublicKey is the public key bytes produced by keygen.
	PublicKey []byte
	// ShareForUser holds the encrypted share for delegate parties, which do
	// not save it to the database; it is nil for every other role.
	ShareForUser []byte
}
// SessionCoordinatorClient is the subset of the session coordinator API
// this use case depends on.
type SessionCoordinatorClient interface {
	// JoinSession is called first by Execute; the returned SessionInfo
	// supplies the session type, thresholds and participant list.
	JoinSession(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID, joinToken string,
	) (*SessionInfo, error)

	// ReportCompletion is called by Execute once the share has been handled,
	// passing the public key from keygen. Execute treats a failure here as
	// non-fatal.
	ReportCompletion(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID string,
		publicKey []byte,
	) error
}
// MessageRouterClient is the transport used to exchange MPC protocol
// messages between the parties of a session.
type MessageRouterClient interface {
	// RouteMessage sends payload from fromParty to toParties within the
	// session, tagged with roundNumber.
	RouteMessage(
		ctx context.Context,
		sessionID uuid.UUID,
		fromParty string,
		toParties []string,
		roundNumber int,
		payload []byte,
	) error

	// SubscribeMessages returns a channel of incoming messages for partyID
	// in the session (presumably those addressed to it — confirm with the
	// router implementation).
	SubscribeMessages(
		ctx context.Context,
		sessionID uuid.UUID,
		partyID string,
	) (<-chan *MPCMessage, error)
}
// SessionInfo is the session description returned by the coordinator's
// JoinSession call.
type SessionInfo struct {
	SessionID   uuid.UUID
	SessionType string // Execute only proceeds for "keygen"
	// ThresholdN is used as the total party count and ThresholdT as the
	// threshold when configuring the TSS keygen session.
	ThresholdN, ThresholdT int
	// MessageHash is not read by keygen; presumably it is used by signing
	// sessions — confirm.
	MessageHash  []byte
	Participants []ParticipantInfo // every party taking part in the session
}
// ParticipantInfo pairs a session participant's ID with its protocol index.
type ParticipantInfo struct {
	PartyID    string // identifier used by the coordinator and router
	PartyIndex int    // index this party uses in the TSS protocol
}
// MPCMessage is one protocol message delivered by the message router.
type MPCMessage struct {
	FromParty   string // sender's party ID; mapped to a party index on receipt
	IsBroadcast bool   // whether the message was broadcast
	RoundNumber int    // round tag from the router (not read by keygen conversion)
	Payload     []byte // raw protocol message bytes handed to the TSS layer
}
// ParticipateKeygenUseCase drives this server party's participation in a
// keygen session. Build it with NewParticipateKeygenUseCase and run it with
// Execute.
type ParticipateKeygenUseCase struct {
	keyShareRepo  repositories.KeyShareRepository // saves shares for persistent roles
	sessionClient SessionCoordinatorClient        // join + completion reporting
	messageRouter MessageRouterClient             // protocol message transport
	cryptoService *crypto.CryptoService           // encrypts the local share
}
// NewParticipateKeygenUseCase wires the given dependencies into a
// ParticipateKeygenUseCase. The dependencies are stored as-is; none are
// validated here, so nil values are accepted.
func NewParticipateKeygenUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateKeygenUseCase {
	uc := new(ParticipateKeygenUseCase)
	uc.keyShareRepo = keyShareRepo
	uc.sessionClient = sessionClient
	uc.messageRouter = messageRouter
	uc.cryptoService = cryptoService
	return uc
}
// Execute participates in a keygen session using the real TSS protocol.
//
// It joins the session through the coordinator and subscribes to routed
// messages. It then runs TSS keygen with the listed participants and
// encrypts the resulting local save data with cryptoService.EncryptShare.
// Finally it disposes of the share according to the PARTY_ROLE environment
// variable:
//
//   - "persistent" (or unset): saved through keyShareRepo
//   - "delegate": not saved; returned to the caller in ShareForUser
//   - "temporary": neither saved nor returned in ShareForUser
//   - any other value: saved as for "persistent", with a warning logged
//
// Completion is then reported to the coordinator. A reporting failure is
// logged but does not fail the call.
//
// Errors returned by JoinSession, SubscribeMessages, the keygen protocol
// and EncryptShare are passed through unchanged. ErrInvalidSession is
// returned when the session is not a "keygen" session or does not list
// input.PartyID as a participant. ErrShareSaveFailed is returned when
// persisting the share fails; the underlying error is logged.
func (uc *ParticipateKeygenUseCase) Execute(
	ctx context.Context,
	input ParticipateKeygenInput,
) (*ParticipateKeygenOutput, error) {
	// Bind everything started here to a context that is cancelled when
	// Execute returns. This matters most for the message subscription, so
	// it does not outlive the call when the caller's context is long-lived.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// 1. Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	if sessionInfo == nil || sessionInfo.SessionType != "keygen" {
		return nil, ErrInvalidSession
	}

	// 2. Find self in participants and build the party index map.
	selfIndex, found := 0, false
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
		if p.PartyID == input.PartyID {
			selfIndex, found = p.PartyIndex, true
		}
	}
	if !found {
		// Fail fast. Otherwise selfIndex would stay 0 and this party would
		// run keygen under an index that may belong to another participant.
		return nil, ErrInvalidSession
	}

	// 3. Subscribe to messages.
	msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 4. Run TSS keygen protocol.
	saveData, publicKey, err := uc.runKeygenProtocol(
		ctx,
		input.SessionID,
		input.PartyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 5. Encrypt the share.
	encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
	if err != nil {
		return nil, err
	}

	keyShare := entities.NewPartyKeyShare(
		input.PartyID,
		selfIndex,
		input.SessionID,
		sessionInfo.ThresholdN,
		sessionInfo.ThresholdT,
		encryptedShare,
		publicKey,
	)

	// saveShare persists keyShare. The repository error is logged here
	// because the caller only receives the ErrShareSaveFailed sentinel.
	saveShare := func() error {
		if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
			logger.Error("failed to save share",
				zap.String("party_id", input.PartyID),
				zap.String("session_id", input.SessionID.String()),
				zap.Error(err))
			return ErrShareSaveFailed
		}
		return nil
	}

	// 6. Handle share based on party role.
	partyRole := uc.getPartyRole()
	var shareForUser []byte

	switch partyRole {
	case "persistent":
		// Persistent Party: save share to database.
		if err := saveShare(); err != nil {
			return nil, err
		}
		logger.Info("Share saved to database (persistent party)",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()))

	case "delegate":
		// Delegate Party: do NOT save to database, return to user.
		shareForUser = encryptedShare
		logger.Info("Share NOT saved, will be returned to user (delegate party)",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()),
			zap.Int("share_size", len(shareForUser)))

	case "temporary":
		// Temporary Party: optionally save to temp storage (not implemented yet).
		logger.Info("Temporary party - share not saved",
			zap.String("party_id", input.PartyID))

	default:
		// Unknown roles fall back to persistent for safety. Warn before
		// saving so the misconfiguration is visible even if the save fails.
		logger.Warn("Unknown party role, defaulting to persistent",
			zap.String("party_id", input.PartyID),
			zap.String("role", partyRole))
		if err := saveShare(); err != nil {
			return nil, err
		}
	}

	// 7. Report completion to the coordinator. This is best effort: the
	// share has already been handled, so a failure here is only logged.
	if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
		logger.Error("failed to report completion",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()),
			zap.Error(err))
	}

	return &ParticipateKeygenOutput{
		Success:      true,
		KeyShare:     keyShare,
		PublicKey:    publicKey,
		ShareForUser: shareForUser, // Only populated for delegate parties
	}, nil
}
// runKeygenProtocol runs TSS keygen through the project's tss package for
// partyID at selfIndex. The session has n total parties and threshold t.
//
// Incoming router messages from msgChan are converted into
// tss.ReceivedMessage values by a helper goroutine. Messages from senders
// absent from partyIndexMap are dropped. Outgoing messages are sent through
// uc.messageRouter. The tss.KeygenConfig is given a 10 minute Timeout; how
// it is enforced is up to the tss package.
//
// It returns the local party's save data and the public key bytes from the
// keygen result, or the first error from session creation or execution.
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	n, t int,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, []byte, error) {
	const keygenTimeout = 10 * time.Minute

	logger.Info("Running keygen protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("n", n),
		zap.Int("t", t))

	// Scope a context to this run. Cancelling it on return stops the
	// convertMessages goroutine below. Without it, that goroutine keeps
	// running until the caller's context ends or msgChan is closed.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Create message handler adapter.
	msgHandler := &keygenMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine; it exits when ctx is cancelled.
	go msgHandler.convertMessages(ctx, msgChan)

	config := tss.KeygenConfig{
		Threshold:    t,
		TotalParties: n,
		Timeout:      keygenTimeout,
	}

	allParties := make([]tss.KeygenParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.KeygenParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}

	selfParty := tss.KeygenParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
	if err != nil {
		return nil, nil, err
	}

	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, err
	}

	logger.Info("Keygen completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))

	return result.LocalPartySaveData, result.PublicKeyBytes, nil
}
// keygenMessageHandler bridges the message router and the tss keygen
// session. Outgoing messages are forwarded through messageRouter, and
// incoming router messages are exposed as tss.ReceivedMessage values
// through ReceiveMessages.
type keygenMessageHandler struct {
	sessionID     uuid.UUID                 // session whose messages are relayed
	partyID       string                    // this party; used as the sender ID
	messageRouter MessageRouterClient       // outbound transport
	msgChan       chan *tss.ReceivedMessage // inbound queue, fed by convertMessages
	partyIndexMap map[string]int            // party ID -> party index
}
// SendMessage forwards an outgoing TSS message from this party to the
// message router. Every message is routed with round number 0.
//
// NOTE(review): isBroadcast is not forwarded. Presumably the router treats
// an empty toParties list as a broadcast — confirm against the router.
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	const roundNumber = 0
	sender := h.partyID
	return h.messageRouter.RouteMessage(ctx, h.sessionID, sender, toParties, roundNumber, msgBytes)
}
// ReceiveMessages returns the inbound queue of converted protocol messages
// as a receive-only channel. convertMessages is the only writer.
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	var inbound <-chan *tss.ReceivedMessage = h.msgChan
	return inbound
}
// convertMessages pumps messages from the router's inChan into h.msgChan,
// translating each sender ID into its party index via h.partyIndexMap.
// Messages from unknown senders and nil messages are dropped.
//
// It returns when ctx is cancelled or inChan is closed. In every case it
// closes h.msgChan, so consumers of ReceiveMessages observe the end of the
// stream instead of blocking forever.
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	// This goroutine is the only sender on h.msgChan, so it owns closing it.
	// A single deferred close covers every exit path, including cancellation
	// while blocked on a send.
	defer close(h.msgChan)

	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-inChan:
			if !ok {
				return
			}
			if msg == nil {
				// Guard against a nil entry, which would otherwise panic below.
				continue
			}

			fromIndex, known := h.partyIndexMap[msg.FromParty]
			if !known {
				// Sender is not a participant of this session; drop it.
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}

			select {
			case h.msgChan <- tssMsg:
			case <-ctx.Done():
				return
			}
		}
	}
}
// getPartyRole reads this party's role from the PARTY_ROLE environment
// variable. An unset or empty variable yields "persistent"; any other value
// is returned verbatim, so unrecognised roles reach the caller unchanged.
// Expected values are "persistent", "delegate" and "temporary".
func (uc *ParticipateKeygenUseCase) getPartyRole() string {
	if role := os.Getenv("PARTY_ROLE"); role != "" {
		return role
	}
	return "persistent" // Default to persistent for safety
}