// Source: rwadurian/backend/mpc-system/services/server-party/application/use_cases/participate_signing.go
// (Go; code-viewer header reported 271 lines, 7.1 KiB)

package use_cases
import (
"context"
"errors"
"math/big"
"time"
"github.com/google/uuid"
"github.com/rwadurian/mpc-system/pkg/crypto"
"github.com/rwadurian/mpc-system/pkg/logger"
"github.com/rwadurian/mpc-system/pkg/tss"
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
"go.uber.org/zap"
)
// Sentinel errors for the signing use case; callers should compare with errors.Is.
var (
	// ErrSigningFailed is a generic signing failure.
	ErrSigningFailed = errors.New("signing failed")

	// ErrSigningTimeout signals that signing did not complete in time.
	ErrSigningTimeout = errors.New("signing timeout")

	// ErrKeyShareNotFound is returned when no key share is available for the party.
	ErrKeyShareNotFound = errors.New("key share not found")

	// ErrInvalidSignSession is returned when the joined session cannot be used for signing.
	ErrInvalidSignSession = errors.New("invalid sign session")
)
// ParticipateSigningInput carries everything Execute needs to take part in a
// signing session.
type ParticipateSigningInput struct {
	// SessionID identifies the signing session to join.
	SessionID uuid.UUID

	// PartyID is this server party's identifier; it selects the key share and
	// this party's entry in the session participants.
	PartyID string

	// JoinToken is forwarded to the session coordinator when joining.
	JoinToken string

	// MessageHash is the hash to sign. When empty, the message hash carried by
	// the joined session's info is used instead.
	MessageHash []byte
}
// ParticipateSigningOutput is the result of a completed signing participation.
type ParticipateSigningOutput struct {
	// Success is set to true by Execute whenever it returns a non-nil output.
	Success bool

	// Signature holds the signature bytes produced by the TSS signing session.
	Signature []byte

	// R and S are the signature components reported by the signing session.
	R *big.Int
	S *big.Int
}
// ParticipateSigningUseCase drives a server party's participation in a TSS
// signing session. Build it with NewParticipateSigningUseCase.
type ParticipateSigningUseCase struct {
	// keyShareRepo lists this party's key shares and persists their updates.
	keyShareRepo repositories.KeyShareRepository

	// sessionClient joins sessions and reports completion to the coordinator.
	sessionClient SessionCoordinatorClient

	// messageRouter subscribes to inbound and routes outbound protocol messages.
	messageRouter MessageRouterClient

	// cryptoService decrypts stored key share data.
	cryptoService *crypto.CryptoService
}
// NewParticipateSigningUseCase wires the repository, coordinator client,
// message router and crypto service into a ready-to-use
// ParticipateSigningUseCase. The dependencies are stored as given; none are
// validated here.
func NewParticipateSigningUseCase(
	keyShareRepo repositories.KeyShareRepository,
	sessionClient SessionCoordinatorClient,
	messageRouter MessageRouterClient,
	cryptoService *crypto.CryptoService,
) *ParticipateSigningUseCase {
	uc := new(ParticipateSigningUseCase)
	uc.keyShareRepo = keyShareRepo
	uc.sessionClient = sessionClient
	uc.messageRouter = messageRouter
	uc.cryptoService = cryptoService
	return uc
}
// Execute participates in a signing session using the real TSS protocol.
//
// It does the following, in order:
//   - joins the session through the coordinator;
//   - checks that the session is a "sign" session that lists input.PartyID
//     as a participant;
//   - resolves the hash to sign;
//   - loads and decrypts this party's last-listed key share;
//   - runs the TSS signing protocol;
//   - reports the resulting signature back to the coordinator.
//
// The hash signed is input.MessageHash, or the session's MessageHash when the
// input leaves it empty.
//
// Errors:
//   - errors from JoinSession, DecryptShare, SubscribeMessages and the
//     signing protocol are returned as-is;
//   - ErrInvalidSignSession when the session type is not "sign", the party is
//     not among the session participants, or neither the input nor the
//     session provides a message hash;
//   - ErrKeyShareNotFound when the party has no key share or listing key
//     shares fails (the underlying error is logged).
//
// Failing to persist the key share's last-used marker or to report
// completion is logged but does not fail the call.
func (uc *ParticipateSigningUseCase) Execute(
	ctx context.Context,
	input ParticipateSigningInput,
) (*ParticipateSigningOutput, error) {
	// 1. Join session via coordinator.
	sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
	if err != nil {
		return nil, err
	}
	if sessionInfo.SessionType != "sign" {
		return nil, ErrInvalidSignSession
	}

	// 2. Find self in participants and build the party index map.
	var selfIndex int
	selfFound := false
	partyIndexMap := make(map[string]int, len(sessionInfo.Participants))
	for _, p := range sessionInfo.Participants {
		partyIndexMap[p.PartyID] = p.PartyIndex
		if p.PartyID == input.PartyID {
			selfIndex = p.PartyIndex
			selfFound = true
		}
	}
	if !selfFound {
		// Without this check selfIndex would silently stay 0 and this party
		// would run the protocol under another participant's index.
		return nil, ErrInvalidSignSession
	}

	// 3. Resolve the hash to sign: explicit input wins, otherwise the session's.
	messageHash := input.MessageHash
	if len(messageHash) == 0 {
		messageHash = sessionInfo.MessageHash
	}
	if len(messageHash) == 0 {
		return nil, ErrInvalidSignSession
	}

	// 4. Load key share for this party.
	keyShares, err := uc.keyShareRepo.ListByParty(ctx, input.PartyID)
	if err != nil {
		// Keep the sentinel callers match on, but don't lose the real cause.
		logger.Error("failed to list key shares",
			zap.String("party_id", input.PartyID),
			zap.Error(err))
		return nil, ErrKeyShareNotFound
	}
	if len(keyShares) == 0 {
		return nil, ErrKeyShareNotFound
	}
	// Use the most recent key share (in production, would match by public key or session reference).
	// NOTE(review): "most recent" assumes ListByParty returns shares oldest-first — confirm.
	keyShare := keyShares[len(keyShares)-1]

	// 5. Decrypt share data.
	shareData, err := uc.cryptoService.DecryptShare(keyShare.ShareData, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 6. Subscribe to messages. The subscription is scoped to this call so it
	// is released when Execute returns instead of living as long as the
	// caller's ctx (presumably SubscribeMessages stops on ctx cancellation —
	// confirm against the router client).
	subCtx, cancelSub := context.WithCancel(ctx)
	defer cancelSub()
	msgChan, err := uc.messageRouter.SubscribeMessages(subCtx, input.SessionID, input.PartyID)
	if err != nil {
		return nil, err
	}

	// 7. Run TSS Signing protocol.
	signature, r, s, err := uc.runSigningProtocol(
		subCtx,
		input.SessionID,
		input.PartyID,
		selfIndex,
		sessionInfo.Participants,
		sessionInfo.ThresholdT,
		shareData,
		messageHash,
		msgChan,
		partyIndexMap,
	)
	if err != nil {
		return nil, err
	}

	// 8. Update key share last used (best effort).
	keyShare.MarkUsed()
	if err := uc.keyShareRepo.Update(ctx, keyShare); err != nil {
		logger.Warn("failed to update key share last used", zap.Error(err))
	}

	// 9. Report completion to coordinator (best effort).
	if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, signature); err != nil {
		logger.Error("failed to report signing completion", zap.Error(err))
	}

	return &ParticipateSigningOutput{
		Success:   true,
		Signature: signature,
		R:         r,
		S:         s,
	}, nil
}
// runSigningProtocol runs one TSS signing session for this party and returns
// the signature together with its R and S components.
//
// Inbound router messages from msgChan are converted into tss messages by a
// signingMessageHandler goroutine. partyIndexMap maps a sender's party ID to
// its party index, and senders missing from the map are dropped. Outbound
// messages go through uc.messageRouter. The session uses a fixed five-minute
// timeout and signs with threshold t over all listed participants.
//
// Errors from tss.NewSigningSession and session.Start are returned unchanged.
func (uc *ParticipateSigningUseCase) runSigningProtocol(
	ctx context.Context,
	sessionID uuid.UUID,
	partyID string,
	selfIndex int,
	participants []ParticipantInfo,
	t int,
	shareData []byte,
	messageHash []byte,
	msgChan <-chan *MPCMessage,
	partyIndexMap map[string]int,
) ([]byte, *big.Int, *big.Int, error) {
	logger.Info("Running signing protocol",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID),
		zap.Int("self_index", selfIndex),
		zap.Int("t", t),
		zap.Int("message_hash_len", len(messageHash)))

	// Bound the converter goroutine to this call. Without this it keeps
	// running until the caller's ctx ends or msgChan closes, which can be long
	// after signing has finished.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Create message handler adapter.
	msgHandler := &signingMessageHandler{
		sessionID:     sessionID,
		partyID:       partyID,
		messageRouter: uc.messageRouter,
		msgChan:       make(chan *tss.ReceivedMessage, 100),
		partyIndexMap: partyIndexMap,
	}

	// Start message conversion goroutine; it exits when ctx is cancelled above.
	go msgHandler.convertMessages(ctx, msgChan)

	// Create signing config.
	config := tss.SigningConfig{
		Threshold:    t,
		TotalSigners: len(participants),
		Timeout:      5 * time.Minute,
	}

	// Create party list.
	allParties := make([]tss.SigningParty, len(participants))
	for i, p := range participants {
		allParties[i] = tss.SigningParty{
			PartyID:    p.PartyID,
			PartyIndex: p.PartyIndex,
		}
	}
	selfParty := tss.SigningParty{
		PartyID:    partyID,
		PartyIndex: selfIndex,
	}

	// Create signing session.
	session, err := tss.NewSigningSession(config, selfParty, allParties, messageHash, shareData, msgHandler)
	if err != nil {
		return nil, nil, nil, err
	}

	// Run signing.
	result, err := session.Start(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	logger.Info("Signing completed successfully",
		zap.String("session_id", sessionID.String()),
		zap.String("party_id", partyID))
	return result.Signature, result.R, result.S, nil
}
// signingMessageHandler bridges the message router and a tss signing session.
// Inbound MPCMessages are converted into tss.ReceivedMessage values on
// msgChan, and outbound messages are sent through messageRouter.
type signingMessageHandler struct {
	// sessionID is the signing session these messages belong to.
	sessionID uuid.UUID

	// partyID is this party, used as the sender of outbound messages.
	partyID string

	// messageRouter carries outbound messages to the other parties.
	messageRouter MessageRouterClient

	// msgChan receives converted inbound messages from convertMessages.
	msgChan chan *tss.ReceivedMessage

	// partyIndexMap maps a sender's party ID to its party index.
	partyIndexMap map[string]int
}
// SendMessage hands an outbound TSS message to the message router. The
// message is sent from this handler's party, within its session, to
// toParties.
//
// NOTE(review): isBroadcast is not forwarded, and the fifth RouteMessage
// argument is always 0. Confirm that the router infers broadcast from
// toParties and that 0 is the intended value for that argument.
func (h *signingMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
	router := h.messageRouter
	return router.RouteMessage(
		ctx,
		h.sessionID,
		h.partyID,
		toParties,
		0,
		msgBytes,
	)
}
// ReceiveMessages returns a receive-only view of the channel that
// convertMessages fills with inbound TSS messages. convertMessages may close
// it when conversion stops.
func (h *signingMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
	var inbound <-chan *tss.ReceivedMessage = h.msgChan
	return inbound
}
// convertMessages moves messages from the router subscription inChan into
// h.msgChan as tss.ReceivedMessage values. It runs until ctx is cancelled or
// inChan is closed.
//
// Nil messages, and messages whose sender is not in h.partyIndexMap, are
// dropped. h.msgChan is closed on every exit path, including cancellation
// while a send is blocked, so a consumer of ReceiveMessages() is always
// released. Within this file, this function is the only sender on h.msgChan.
func (h *signingMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
	defer close(h.msgChan)
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-inChan:
			if !ok {
				return
			}
			if msg == nil {
				continue
			}
			fromIndex, exists := h.partyIndexMap[msg.FromParty]
			if !exists {
				continue
			}
			tssMsg := &tss.ReceivedMessage{
				FromPartyIndex: fromIndex,
				IsBroadcast:    msg.IsBroadcast,
				MsgBytes:       msg.Payload,
			}
			select {
			case h.msgChan <- tssMsg:
			case <-ctx.Done():
				return
			}
		}
	}
}