rwadurian/backend/mpc-system/services/server-party/application/use_cases/participate_signing.go

353 lines
11 KiB
Go

package use_cases
import (
"context"
"errors"
"math/big"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/pkg/tss"
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
"go.uber.org/zap"
)
// Sentinel errors exposed by the signing participation use case.
// Callers can match them with errors.Is.
var (
	// ErrSigningFailed carries the message "signing failed".
	ErrSigningFailed = errors.New("signing failed")

	// ErrSigningTimeout carries the message "signing timeout".
	ErrSigningTimeout = errors.New("signing timeout")

	// ErrKeyShareNotFound is returned by Execute when no stored key share can
	// be located for the party.
	ErrKeyShareNotFound = errors.New("key share not found")

	// ErrInvalidSignSession is returned by Execute when the joined session is
	// not a session of type "sign".
	ErrInvalidSignSession = errors.New("invalid sign session")
)
// ParticipateSigningInput carries everything a party needs to take part in
// one signing session.
type ParticipateSigningInput struct {
	// SessionID identifies the signing session to join, subscribe to and
	// report completion for.
	SessionID uuid.UUID

	// PartyID is this party's identifier. It is also handed to the crypto
	// service when decrypting the share.
	PartyID string

	// JoinToken is forwarded to the session coordinator when joining.
	JoinToken string

	// MessageHash is the hash to sign. When empty, the hash carried by the
	// joined session is used instead.
	MessageHash []byte

	// UserShareData is the encrypted share supplied by the user for delegate
	// parties. When non-empty it is used in place of a share loaded from the
	// database.
	UserShareData []byte
}
// ParticipateSigningOutput is the result handed back after a signing run.
type ParticipateSigningOutput struct {
	// Success is set to true by Execute on the successful path.
	Success bool

	// Signature holds the signature bytes produced by the signing protocol.
	Signature []byte

	// R and S are the signature components returned by the signing protocol
	// alongside Signature.
	R, S *big.Int
}
// ParticipateSigningUseCase drives this party's participation in a signing
// session. Build it with NewParticipateSigningUseCase.
type ParticipateSigningUseCase struct {
	// keyShareRepo looks up stored key shares and persists their usage mark.
	keyShareRepo repositories.KeyShareRepository

	// sessionClient joins sessions and reports completion to the coordinator.
	sessionClient SessionCoordinatorClient

	// messageRouter subscribes to incoming protocol messages and routes
	// outgoing ones.
	messageRouter MessageRouterClient

	// cryptoService decrypts share data before it is used for signing.
	cryptoService *crypto.CryptoService
}
// NewParticipateSigningUseCase wires the given dependencies into a new
// ParticipateSigningUseCase. The arguments are stored as passed; nothing is
// validated or copied.
func NewParticipateSigningUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateSigningUseCase {
	uc := new(ParticipateSigningUseCase)
	uc.keyShareRepo = keyShareRepo
	uc.sessionClient = sessionClient
	uc.messageRouter = messageRouter
	uc.cryptoService = cryptoService
	return uc
}
// Execute participates in a signing session using real TSS protocol.
//
// It joins the session through the coordinator, obtains this party's key
// share (user-supplied for delegate parties, loaded from the repository for
// persistent parties), runs the signing protocol and reports the resulting
// signature back to the coordinator.
//
// Errors:
//   - ErrInvalidSignSession when the joined session is not of type "sign",
//     when neither the input nor the session carries a message hash, or when
//     this party is missing from the session's participant list.
//   - ErrKeyShareNotFound when no stored key share can be located.
//   - Errors from joining, share decryption, message subscription and the
//     signing protocol are returned unchanged.
//
// Failing to persist the key share's usage mark or to report completion is
// logged only; the signature is still returned in that case.
func (uc *ParticipateSigningUseCase) Execute(
	ctx context.Context,
	input ParticipateSigningInput,
) (*ParticipateSigningOutput, error) {
	// 1. Join session via coordinator
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}

	if sessionInfo.SessionType != "sign" {
		return nil, ErrInvalidSignSession
	}

	// 2. Get share data - either from user input (delegate) or from database (persistent)
	var (
		shareData          []byte
		keyShareForUpdate  *entities.PartyKeyShare // stays nil for delegate parties
		originalThresholdN int                     // original total parties from keygen
	)

	if len(input.UserShareData) > 0 {
		// Delegate party: use share provided by user
		shareData, err = uc.cryptoService.DecryptShare(input.UserShareData, input.PartyID)
		if err != nil {
			return nil, err
		}

		// For delegate party, get threshold info from session
		originalThresholdN = sessionInfo.ThresholdN

		logger.Info("Using user-provided share (delegate party)",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()))
	} else {
		// Persistent party: load from database.
		// If KeygenSessionID is provided, use it to load the specific share.
		// Otherwise, use the most recent share (fallback for backward compatibility).
		if sessionInfo.KeygenSessionID != uuid.Nil {
			// Load the specific share for this keygen session
			keyShareForUpdate, err = uc.keyShareRepo.FindBySessionAndParty(ctx, sessionInfo.KeygenSessionID, input.PartyID)
			if err != nil {
				logger.Error("Failed to find keyshare for keygen session",
					zap.String("party_id", input.PartyID),
					zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()),
					zap.Error(err))
				return nil, ErrKeyShareNotFound
			}
			if keyShareForUpdate == nil {
				// A repository may signal "no row" with a nil result and nil error.
				return nil, ErrKeyShareNotFound
			}

			logger.Info("Using specific keyshare by keygen_session_id",
				zap.String("party_id", input.PartyID),
				zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()))
		} else {
			// Fallback: use the most recent key share
			// TODO: This should be removed once all signing sessions provide keygen_session_id
			keyShares, listErr := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
			if listErr != nil {
				logger.Error("Failed to list keyshares for party",
					zap.String("party_id", input.PartyID),
					zap.Error(listErr))
				return nil, ErrKeyShareNotFound
			}
			if len(keyShares) == 0 {
				return nil, ErrKeyShareNotFound
			}

			// NOTE(review): taking the last element assumes ListByParty returns
			// shares oldest-first - confirm against the repository implementation.
			keyShareForUpdate = keyShares[len(keyShares)-1]
			if keyShareForUpdate == nil {
				return nil, ErrKeyShareNotFound
			}

			logger.Warn("Using most recent keyshare (keygen_session_id not provided)",
				zap.String("party_id", input.PartyID),
				zap.String("fallback_session_id", keyShareForUpdate.SessionID.String()))
		}

		// Get original threshold_n from keygen
		originalThresholdN = keyShareForUpdate.ThresholdN

		// Decrypt share data
		shareData, err = uc.cryptoService.DecryptShare(keyShareForUpdate.ShareData, input.PartyID)
		if err != nil {
			return nil, err
		}

		logger.Info("Using database share (persistent party)",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()),
			zap.String("keygen_session_id", keyShareForUpdate.SessionID.String()),
			zap.Int("original_threshold_n", originalThresholdN),
			zap.Int("threshold_t", keyShareForUpdate.ThresholdT))
	}

	// 3. Resolve the message hash: the caller's value wins, the session's is the fallback
	messageHash := input.MessageHash
	if len(messageHash) == 0 {
		messageHash = sessionInfo.MessageHash
	}
	if len(messageHash) == 0 {
		// There is nothing to sign; refuse instead of signing an empty hash.
		logger.Error("No message hash available for signing",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()))
		return nil, ErrInvalidSignSession
	}

	// 4. Find self in participants and build party index map
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
	}
	selfIndex, isParticipant := partyIndexMap[input.PartyID]
	if !isParticipant {
		// Without this check selfIndex would silently default to 0, which may
		// be the index of a different party.
		logger.Error("Party is not listed among the signing session participants",
			zap.String("party_id", input.PartyID),
			zap.String("session_id", input.SessionID.String()))
		return nil, ErrInvalidSignSession
	}

	// 5. Subscribe to messages.
	// The subscription and the protocol run share a context that is cancelled
	// when Execute returns, so neither keeps running after this call.
	runCtx, cancelRun := context.WithCancel(ctx)
	defer cancelRun()

	msgChan, err := uc.messageRouter.SubscribeMessages(runCtx, input.SessionID, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 6. Run TSS Signing protocol
	signature, r, s, err := uc.runSigningProtocol(
		runCtx,
		input.SessionID,
		input.PartyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdT,
		originalThresholdN,
		shareData,
		messageHash,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 7. Update key share last used (only for persistent parties); best effort
	if keyShareForUpdate != nil {
		keyShareForUpdate.MarkUsed()
		if err := uc.keyShareRepo.Update(ctx, keyShareForUpdate); err != nil {
			logger.Warn("failed to update key share last used", zap.Error(err))
		}
	}

	// 8. Report completion to coordinator; best effort, the signature is already valid locally
	if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
		logger.Error("failed to report signing completion", zap.Error(err))
	}

	return &ParticipateSigningOutput{
		Success:   true,
		Signature: signature,
		R:         r,
		S:         s,
	}, nil
}
// runSigningProtocol runs the TSS signing protocol using tss-lib.
//
// Parameters:
//   - sessionID, partyID, selfIndex: identify the session and this party.
//   - participants: the parties taking part in this signing run.
//   - t: signing threshold handed to the TSS configuration.
//   - n: original total number of parties from keygen (not the number of
//     current signers).
//   - shareData: this party's decrypted key share.
//   - messageHash: the hash to sign.
//   - msgChan: incoming router messages for this party.
//   - partyIndexMap: party ID to party index, used to translate senders.
//
// It returns the signature bytes together with the R and S values reported by
// the signing session, or the error from creating or running that session.
//
// The goroutine that converts incoming messages is tied to a context that is
// cancelled when this function returns, so it does not outlive the run.
func (uc *ParticipateSigningUseCase) runSigningProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	t int,
	n int, // Original total parties from keygen
	shareData []byte,
	messageHash []byte,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, *big.Int, *big.Int, error) {
	// Upper bound handed to the TSS session for a single signing run.
	const signingTimeout = 5 * time.Minute

	logger.Info("Running signing protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("t", t),
		zap.Int("n", n),
		zap.Int("current_signers", len(participants)),
		zap.Int("message_hash_len", len(messageHash)))

	// Create message handler adapter
	msgHandler := &signingMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine.
	// It gets its own cancellable context: previously it kept running after
	// the session finished unless the caller's ctx was cancelled or msgChan
	// was closed, which leaked one goroutine per signing run.
	convertCtx, stopConverting := context.WithCancel(ctx)
	defer stopConverting()
	go msgHandler.convertMessages(convertCtx, msgChan)

	// Create signing config
	// IMPORTANT: TotalParties must be the original n from keygen, not current signers
	// For 2-of-3: t=2, n=3, but only 2 parties participate in signing
	config := tss.SigningConfig{
		Threshold:    t,
		TotalParties: n, // Original total from keygen
		TotalSigners: len(participants),
		Timeout:      signingTimeout,
	}

	// Create party list
	allParties := make([]tss.SigningParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.SigningParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}

	selfParty := tss.SigningParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	// Create signing session
	session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
	if err != nil {
		return nil, nil, nil, err
	}

	// Run signing
	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	logger.Info("Signing completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))

	return result.Signature, result.R, result.S, nil
}
// signingMessageHandler adapts MPCMessage channel to tss.MessageHandler.
// Outgoing messages go through the message router; incoming ones are
// converted by convertMessages and exposed via ReceiveMessages.
type signingMessageHandler struct {
	// sessionID and partyID identify the sender on every routed message.
	sessionID uuid.UUID
	partyID   string

	// messageRouter delivers outgoing protocol messages.
	messageRouter MessageRouterClient

	// msgChan carries converted incoming messages to the TSS session.
	// convertMessages is its only writer.
	msgChan chan *tss.ReceivedMessage

	// partyIndexMap maps a sender's party ID to its party index.
	partyIndexMap map[string]int
}
// SendMessage forwards an outgoing protocol message to the message router on
// behalf of this handler's session and party, and returns the router's error.
//
// isBroadcast is not consulted here; only toParties is handed to the router.
// The fifth RouteMessage argument is always the literal 0.
// NOTE(review): that argument is presumably a round number - confirm against
// the MessageRouterClient definition.
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	router := h.messageRouter
	return router.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
}
// ReceiveMessages exposes the handler's channel of converted incoming
// messages as a receive-only channel.
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	var inbound <-chan *tss.ReceivedMessage = h.msgChan
	return inbound
}
// convertMessages translates router messages read from inChan into
// tss.ReceivedMessage values and forwards them on h.msgChan.
//
// It runs until ctx is cancelled or inChan is closed. h.msgChan is closed
// exactly once on every exit path, including cancellation while a forward is
// blocked, so a reader of ReceiveMessages always observes the end of the
// stream. Nil messages and messages whose sender is not in h.partyIndexMap are
// logged and dropped.
//
// This method is the only writer of h.msgChan and must be started at most
// once per handler.
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	// A single deferred close keeps all exit paths consistent.
	defer close(h.msgChan)

	logger.Debug("convertMessages started, waiting for messages",
		zap.String("session_id", h.sessionID.String()),
		zap.String("party_id", h.partyID))

	for {
		select {
		case <-ctx.Done():
			logger.Debug("convertMessages context cancelled", zap.String("session_id", h.sessionID.String()))
			return

		case msg, ok := <-inChan:
			if !ok {
				logger.Debug("convertMessages inChan closed", zap.String("session_id", h.sessionID.String()))
				return
			}
			if msg == nil {
				// Guard against a nil pointer on the channel; dereferencing it would panic.
				logger.Warn("Received nil MPC message - dropping",
					zap.String("session_id", h.sessionID.String()))
				continue
			}

			logger.Debug("Received MPC message for conversion",
				zap.String("session_id", h.sessionID.String()),
				zap.String("from_party", msg.FromParty),
				zap.Bool("is_broadcast", msg.IsBroadcast),
				zap.Int("payload_size", len(msg.Payload)))

			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				logger.Warn("Message from unknown party - dropping",
					zap.String("session_id", h.sessionID.String()),
					zap.String("from_party", msg.FromParty),
					zap.Any("known_parties", h.partyIndexMap))
				continue
			}

			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}

			logger.Debug("Converted message, sending to TSS",
				zap.String("session_id", h.sessionID.String()),
				zap.String("from_party", msg.FromParty),
				zap.Int("from_index", fromIndex))

			// Forward, but do not block forever if the consumer has gone away.
			select {
			case h.msgChan <- tssMsg:
				logger.Debug("Message sent to TSS successfully",
					zap.String("session_id", h.sessionID.String()))
			case <-ctx.Done():
				logger.Debug("convertMessages context cancelled", zap.String("session_id", h.sessionID.String()))
				return
			}
		}
	}
}