353 lines
11 KiB
Go
353 lines
11 KiB
Go
package use_cases
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"math/big"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"github.com/rwadurian/mpc-system/pkg/tss"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Sentinel errors for the signing use case. Compare against them with
// errors.Is.
var (
	// ErrInvalidSignSession is returned by Execute when the joined session is
	// not a signing session.
	ErrInvalidSignSession = errors.New("invalid sign session")

	// ErrKeyShareNotFound is returned by Execute when a persistent party has no
	// stored key share to sign with.
	ErrKeyShareNotFound = errors.New("key share not found")

	// ErrSigningFailed reports a failed signing attempt.
	ErrSigningFailed = errors.New("signing failed")

	// ErrSigningTimeout reports a signing attempt that did not finish in time.
	ErrSigningTimeout = errors.New("signing timeout")
)
|
|
|
|
// ParticipateSigningInput contains input for signing participation
|
|
type ParticipateSigningInput struct {
|
|
SessionID uuid.UUID
|
|
PartyID string
|
|
JoinToken string
|
|
MessageHash []byte
|
|
// For delegate parties: encrypted share provided by user (not loaded from DB)
|
|
UserShareData []byte
|
|
}
|
|
|
|
// ParticipateSigningOutput is the result of a completed signing
// participation.
type ParticipateSigningOutput struct {
	// Success reports whether this party finished the signing protocol.
	Success bool

	// Signature holds the signature bytes produced by the TSS session.
	Signature []byte

	// R and S are the signature components returned by the TSS session
	// (presumably the ECDSA r and s values — confirm against pkg/tss).
	R, S *big.Int
}
|
|
|
|
// ParticipateSigningUseCase handles signing participation
|
|
type ParticipateSigningUseCase struct {
|
|
keyShareRepo repositories.KeyShareRepository
|
|
sessionClient SessionCoordinatorClient
|
|
messageRouter MessageRouterClient
|
|
cryptoService *crypto.CryptoService
|
|
}
|
|
|
|
// NewParticipateSigningUseCase creates a new participate signing use case
|
|
func NewParticipateSigningUseCase(
|
|
keyShareRepo repositories.KeyShareRepository,
|
|
sessionClient SessionCoordinatorClient,
|
|
messageRouter MessageRouterClient,
|
|
cryptoService *crypto.CryptoService,
|
|
) *ParticipateSigningUseCase {
|
|
return &ParticipateSigningUseCase{
|
|
keyShareRepo: keyShareRepo,
|
|
sessionClient: sessionClient,
|
|
messageRouter: messageRouter,
|
|
cryptoService: cryptoService,
|
|
}
|
|
}
|
|
|
|
// Execute participates in a signing session using real TSS protocol
|
|
func (uc *ParticipateSigningUseCase) Execute(
|
|
ctx context.Context,
|
|
input ParticipateSigningInput,
|
|
) (*ParticipateSigningOutput, error) {
|
|
// 1. Join session via coordinator
|
|
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if sessionInfo.SessionType != "sign" {
|
|
return nil, ErrInvalidSignSession
|
|
}
|
|
|
|
// 2. Get share data - either from user input (delegate) or from database (persistent)
|
|
var shareData []byte
|
|
var keyShareForUpdate *entities.PartyKeyShare
|
|
var originalThresholdN int // Original total parties from keygen
|
|
|
|
if len(input.UserShareData) > 0 {
|
|
// Delegate party: use share provided by user
|
|
shareData, err = uc.cryptoService.DecryptShare(input.UserShareData, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// For delegate party, get threshold info from session
|
|
originalThresholdN = sessionInfo.ThresholdN
|
|
logger.Info("Using user-provided share (delegate party)",
|
|
zap.String("party_id", input.PartyID),
|
|
zap.String("session_id", input.SessionID.String()))
|
|
} else {
|
|
// Persistent party: load from database
|
|
// If KeygenSessionID is provided, use it to load the specific share
|
|
// Otherwise, use the most recent share (fallback for backward compatibility)
|
|
if sessionInfo.KeygenSessionID != uuid.Nil {
|
|
// Load the specific share for this keygen session
|
|
keyShareForUpdate, err = uc.keyShareRepo.FindBySessionAndParty(ctx, sessionInfo.KeygenSessionID, input.PartyID)
|
|
if err != nil {
|
|
logger.Error("Failed to find keyshare for keygen session",
|
|
zap.String("party_id", input.PartyID),
|
|
zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()),
|
|
zap.Error(err))
|
|
return nil, ErrKeyShareNotFound
|
|
}
|
|
logger.Info("Using specific keyshare by keygen_session_id",
|
|
zap.String("party_id", input.PartyID),
|
|
zap.String("keygen_session_id", sessionInfo.KeygenSessionID.String()))
|
|
} else {
|
|
// Fallback: use the most recent key share
|
|
// TODO: This should be removed once all signing sessions provide keygen_session_id
|
|
keyShares, err := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
|
|
if err != nil || len(keyShares) == 0 {
|
|
return nil, ErrKeyShareNotFound
|
|
}
|
|
keyShareForUpdate = keyShares[len(keyShares)-1]
|
|
logger.Warn("Using most recent keyshare (keygen_session_id not provided)",
|
|
zap.String("party_id", input.PartyID),
|
|
zap.String("fallback_session_id", keyShareForUpdate.SessionID.String()))
|
|
}
|
|
|
|
// Get original threshold_n from keygen
|
|
originalThresholdN = keyShareForUpdate.ThresholdN
|
|
|
|
// Decrypt share data
|
|
shareData, err = uc.cryptoService.DecryptShare(keyShareForUpdate.ShareData, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
logger.Info("Using database share (persistent party)",
|
|
zap.String("party_id", input.PartyID),
|
|
zap.String("session_id", input.SessionID.String()),
|
|
zap.String("keygen_session_id", keyShareForUpdate.SessionID.String()),
|
|
zap.Int("original_threshold_n", originalThresholdN),
|
|
zap.Int("threshold_t", keyShareForUpdate.ThresholdT))
|
|
}
|
|
|
|
// 4. Find self in participants and build party index map
|
|
var selfIndex int
|
|
partyIndexMap := make(map[string]int)
|
|
for _, p := range sessionInfo.Participants {
|
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
|
if p.PartyID == input.PartyID {
|
|
selfIndex = p.PartyIndex
|
|
}
|
|
}
|
|
|
|
// 5. Subscribe to messages
|
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Use message hash from session if not provided
|
|
messageHash := input.MessageHash
|
|
if len(messageHash) == 0 {
|
|
messageHash = sessionInfo.MessageHash
|
|
}
|
|
|
|
// 6. Run TSS Signing protocol
|
|
signature, r, s, err := uc.runSigningProtocol(
|
|
ctx,
|
|
input.SessionID,
|
|
input.PartyID,
|
|
selfIndex,
|
|
sessionInfo.Participants,
|
|
sessionInfo.ThresholdT,
|
|
originalThresholdN,
|
|
shareData,
|
|
messageHash,
|
|
msgChan,
|
|
partyIndexMap,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 7. Update key share last used (only for persistent parties)
|
|
if keyShareForUpdate != nil {
|
|
keyShareForUpdate.MarkUsed()
|
|
if err := uc.keyShareRepo.Update(ctx, keyShareForUpdate); err != nil {
|
|
logger.Warn("failed to update key share last used", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// 8. Report completion to coordinator
|
|
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
|
|
logger.Error("failed to report signing completion", zap.Error(err))
|
|
}
|
|
|
|
return &ParticipateSigningOutput{
|
|
Success: true,
|
|
Signature: signature,
|
|
R: r,
|
|
S: s,
|
|
}, nil
|
|
}
|
|
|
|
// runSigningProtocol runs the TSS signing protocol using tss-lib
|
|
func (uc *ParticipateSigningUseCase) runSigningProtocol(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
selfIndex int,
|
|
participants []ParticipantInfo,
|
|
t int,
|
|
n int, // Original total parties from keygen
|
|
shareData []byte,
|
|
messageHash []byte,
|
|
msgChan <-chan *MPCMessage,
|
|
partyIndexMap map[string]int,
|
|
) ([]byte, *big.Int, *big.Int, error) {
|
|
logger.Info("Running signing protocol",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.Int("self_index", selfIndex),
|
|
zap.Int("t", t),
|
|
zap.Int("n", n),
|
|
zap.Int("current_signers", len(participants)),
|
|
zap.Int("message_hash_len", len(messageHash)))
|
|
|
|
// Create message handler adapter
|
|
msgHandler := &signingMessageHandler{
|
|
sessionID: sessionID,
|
|
partyID: partyID,
|
|
messageRouter: uc.messageRouter,
|
|
msgChan: make(chan *tss.ReceivedMessage, 100),
|
|
partyIndexMap: partyIndexMap,
|
|
}
|
|
|
|
// Start message conversion goroutine
|
|
go msgHandler.convertMessages(ctx, msgChan)
|
|
|
|
// Create signing config
|
|
// IMPORTANT: TotalParties must be the original n from keygen, not current signers
|
|
// For 2-of-3: t=2, n=3, but only 2 parties participate in signing
|
|
config := tss.SigningConfig{
|
|
Threshold: t,
|
|
TotalParties: n, // Original total from keygen
|
|
TotalSigners: len(participants),
|
|
Timeout: 5 * time.Minute,
|
|
}
|
|
|
|
// Create party list
|
|
allParties := make([]tss.SigningParty, len(participants))
|
|
for i, p := range participants {
|
|
allParties[i] = tss.SigningParty{
|
|
PartyID: p.PartyID,
|
|
PartyIndex: p.PartyIndex,
|
|
}
|
|
}
|
|
|
|
selfParty := tss.SigningParty{
|
|
PartyID: partyID,
|
|
PartyIndex: selfIndex,
|
|
}
|
|
|
|
// Create signing session
|
|
session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
// Run signing
|
|
result, err := session.Start(ctx)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
logger.Info("Signing completed successfully",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID))
|
|
|
|
return result.Signature, result.R, result.S, nil
|
|
}
|
|
|
|
// signingMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
|
type signingMessageHandler struct {
|
|
sessionID uuid.UUID
|
|
partyID string
|
|
messageRouter MessageRouterClient
|
|
msgChan chan *tss.ReceivedMessage
|
|
partyIndexMap map[string]int
|
|
}
|
|
|
|
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
|
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
|
}
|
|
|
|
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
|
return h.msgChan
|
|
}
|
|
|
|
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
|
logger.Debug("convertMessages started, waiting for messages",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("party_id", h.partyID))
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Debug("convertMessages context cancelled", zap.String("session_id", h.sessionID.String()))
|
|
close(h.msgChan)
|
|
return
|
|
case msg, ok := <-inChan:
|
|
if !ok {
|
|
logger.Debug("convertMessages inChan closed", zap.String("session_id", h.sessionID.String()))
|
|
close(h.msgChan)
|
|
return
|
|
}
|
|
|
|
logger.Debug("Received MPC message for conversion",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Bool("is_broadcast", msg.IsBroadcast),
|
|
zap.Int("payload_size", len(msg.Payload)))
|
|
|
|
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
|
if !exists {
|
|
logger.Warn("Message from unknown party - dropping",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Any("known_parties", h.partyIndexMap))
|
|
continue
|
|
}
|
|
|
|
tssMsg := &tss.ReceivedMessage{
|
|
FromPartyIndex: fromIndex,
|
|
IsBroadcast: msg.IsBroadcast,
|
|
MsgBytes: msg.Payload,
|
|
}
|
|
|
|
logger.Debug("Converted message, sending to TSS",
|
|
zap.String("session_id", h.sessionID.String()),
|
|
zap.String("from_party", msg.FromParty),
|
|
zap.Int("from_index", fromIndex))
|
|
|
|
select {
|
|
case h.msgChan <- tssMsg:
|
|
logger.Debug("Message sent to TSS successfully",
|
|
zap.String("session_id", h.sessionID.String()))
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|