295 lines
7.8 KiB
Go
295 lines
7.8 KiB
Go
package use_cases
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rwadurian/mpc-system/pkg/crypto"
|
|
"github.com/rwadurian/mpc-system/pkg/logger"
|
|
"github.com/rwadurian/mpc-system/pkg/tss"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/entities"
|
|
"github.com/rwadurian/mpc-system/services/server-party/domain/repositories"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Sentinel errors for the keygen participation use case. Callers should
// compare against them with errors.Is.
var (
	// ErrKeygenFailed signals that a keygen run did not succeed.
	// NOTE(review): not returned anywhere in this file; presumably used by
	// other callers — confirm before removing.
	ErrKeygenFailed = errors.New("keygen failed")

	// ErrKeygenTimeout signals that a keygen run did not finish in time.
	// NOTE(review): not returned anywhere in this file — confirm usage.
	ErrKeygenTimeout = errors.New("keygen timeout")

	// ErrInvalidSession is returned by Execute when the joined session cannot
	// be used for keygen (for example, its type is not "keygen").
	ErrInvalidSession = errors.New("invalid session")

	// ErrShareSaveFailed is returned by Execute when persisting the encrypted
	// key share through the repository fails.
	ErrShareSaveFailed = errors.New("failed to save share")
)
|
|
|
|
// ParticipateKeygenInput contains input for keygen participation
|
|
type ParticipateKeygenInput struct {
|
|
SessionID uuid.UUID
|
|
PartyID string
|
|
JoinToken string
|
|
}
|
|
|
|
// ParticipateKeygenOutput contains output from keygen participation
|
|
type ParticipateKeygenOutput struct {
|
|
Success bool
|
|
KeyShare *entities.PartyKeyShare
|
|
PublicKey []byte
|
|
}
|
|
|
|
// SessionCoordinatorClient defines the interface for session coordinator communication
|
|
type SessionCoordinatorClient interface {
|
|
JoinSession(ctx context.Context, sessionID uuid.UUID, partyID, joinToken string) (*SessionInfo, error)
|
|
ReportCompletion(ctx context.Context, sessionID uuid.UUID, partyID string, publicKey []byte) error
|
|
}
|
|
|
|
// MessageRouterClient defines the interface for message router communication
|
|
type MessageRouterClient interface {
|
|
RouteMessage(ctx context.Context, sessionID uuid.UUID, fromParty string, toParties []string, roundNumber int, payload []byte) error
|
|
SubscribeMessages(ctx context.Context, sessionID uuid.UUID, partyID string) (<-chan *MPCMessage, error)
|
|
}
|
|
|
|
// SessionInfo contains session information from coordinator
|
|
type SessionInfo struct {
|
|
SessionID uuid.UUID
|
|
SessionType string
|
|
ThresholdN int
|
|
ThresholdT int
|
|
MessageHash []byte
|
|
Participants []ParticipantInfo
|
|
}
|
|
|
|
// ParticipantInfo describes one party in a session, as reported by the
// coordinator.
type ParticipantInfo struct {
	// PartyID is the party's identifier.
	PartyID string

	// PartyIndex is the party's index in the session; keygen uses it as the
	// tss party index and to resolve the sender of inbound messages.
	PartyIndex int
}
|
|
|
|
// MPCMessage is one inbound MPC protocol message delivered by the router.
type MPCMessage struct {
	// FromParty is the sender's PartyID.
	FromParty string

	// IsBroadcast reports whether the message was sent to all parties.
	IsBroadcast bool

	// RoundNumber is the protocol round the message belongs to. The keygen
	// message conversion does not read it.
	RoundNumber int

	// Payload is the raw protocol message bytes.
	Payload []byte
}
|
|
|
|
// ParticipateKeygenUseCase handles keygen participation
|
|
type ParticipateKeygenUseCase struct {
|
|
keyShareRepo repositories.KeyShareRepository
|
|
sessionClient SessionCoordinatorClient
|
|
messageRouter MessageRouterClient
|
|
cryptoService *crypto.CryptoService
|
|
}
|
|
|
|
// NewParticipateKeygenUseCase creates a new participate keygen use case
|
|
func NewParticipateKeygenUseCase(
|
|
keyShareRepo repositories.KeyShareRepository,
|
|
sessionClient SessionCoordinatorClient,
|
|
messageRouter MessageRouterClient,
|
|
cryptoService *crypto.CryptoService,
|
|
) *ParticipateKeygenUseCase {
|
|
return &ParticipateKeygenUseCase{
|
|
keyShareRepo: keyShareRepo,
|
|
sessionClient: sessionClient,
|
|
messageRouter: messageRouter,
|
|
cryptoService: cryptoService,
|
|
}
|
|
}
|
|
|
|
// Execute participates in a keygen session using real TSS protocol
|
|
func (uc *ParticipateKeygenUseCase) Execute(
|
|
ctx context.Context,
|
|
input ParticipateKeygenInput,
|
|
) (*ParticipateKeygenOutput, error) {
|
|
// 1. Join session via coordinator
|
|
sessionInfo, err := uc.sessionClient.JoinSession(ctx, input.SessionID, input.PartyID, input.JoinToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if sessionInfo.SessionType != "keygen" {
|
|
return nil, ErrInvalidSession
|
|
}
|
|
|
|
// 2. Find self in participants and build party index map
|
|
var selfIndex int
|
|
partyIndexMap := make(map[string]int)
|
|
for _, p := range sessionInfo.Participants {
|
|
partyIndexMap[p.PartyID] = p.PartyIndex
|
|
if p.PartyID == input.PartyID {
|
|
selfIndex = p.PartyIndex
|
|
}
|
|
}
|
|
|
|
// 3. Subscribe to messages
|
|
msgChan, err := uc.messageRouter.SubscribeMessages(ctx, input.SessionID, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 4. Run TSS Keygen protocol
|
|
saveData, publicKey, err := uc.runKeygenProtocol(
|
|
ctx,
|
|
input.SessionID,
|
|
input.PartyID,
|
|
selfIndex,
|
|
sessionInfo.Participants,
|
|
sessionInfo.ThresholdN,
|
|
sessionInfo.ThresholdT,
|
|
msgChan,
|
|
partyIndexMap,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 5. Encrypt and save the share
|
|
encryptedShare, err := uc.cryptoService.EncryptShare(saveData, input.PartyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keyShare := entities.NewPartyKeyShare(
|
|
input.PartyID,
|
|
selfIndex,
|
|
input.SessionID,
|
|
sessionInfo.ThresholdN,
|
|
sessionInfo.ThresholdT,
|
|
encryptedShare,
|
|
publicKey,
|
|
)
|
|
|
|
if err := uc.keyShareRepo.Save(ctx, keyShare); err != nil {
|
|
return nil, ErrShareSaveFailed
|
|
}
|
|
|
|
// 6. Report completion to coordinator
|
|
if err := uc.sessionClient.ReportCompletion(ctx, input.SessionID, input.PartyID, publicKey); err != nil {
|
|
logger.Error("failed to report completion", zap.Error(err))
|
|
// Don't fail - share is saved
|
|
}
|
|
|
|
return &ParticipateKeygenOutput{
|
|
Success: true,
|
|
KeyShare: keyShare,
|
|
PublicKey: publicKey,
|
|
}, nil
|
|
}
|
|
|
|
// runKeygenProtocol runs the TSS keygen protocol using tss-lib
|
|
func (uc *ParticipateKeygenUseCase) runKeygenProtocol(
|
|
ctx context.Context,
|
|
sessionID uuid.UUID,
|
|
partyID string,
|
|
selfIndex int,
|
|
participants []ParticipantInfo,
|
|
n, t int,
|
|
msgChan <-chan *MPCMessage,
|
|
partyIndexMap map[string]int,
|
|
) ([]byte, []byte, error) {
|
|
logger.Info("Running keygen protocol",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID),
|
|
zap.Int("self_index", selfIndex),
|
|
zap.Int("n", n),
|
|
zap.Int("t", t))
|
|
|
|
// Create message handler adapter
|
|
msgHandler := &keygenMessageHandler{
|
|
sessionID: sessionID,
|
|
partyID: partyID,
|
|
messageRouter: uc.messageRouter,
|
|
msgChan: make(chan *tss.ReceivedMessage, 100),
|
|
partyIndexMap: partyIndexMap,
|
|
}
|
|
|
|
// Start message conversion goroutine
|
|
go msgHandler.convertMessages(ctx, msgChan)
|
|
|
|
// Create keygen config
|
|
config := tss.KeygenConfig{
|
|
Threshold: t,
|
|
TotalParties: n,
|
|
Timeout: 10 * time.Minute,
|
|
}
|
|
|
|
// Create party list
|
|
allParties := make([]tss.KeygenParty, len(participants))
|
|
for i, p := range participants {
|
|
allParties[i] = tss.KeygenParty{
|
|
PartyID: p.PartyID,
|
|
PartyIndex: p.PartyIndex,
|
|
}
|
|
}
|
|
|
|
selfParty := tss.KeygenParty{
|
|
PartyID: partyID,
|
|
PartyIndex: selfIndex,
|
|
}
|
|
|
|
// Create keygen session
|
|
session, err := tss.NewKeygenSession(config, selfParty, allParties, msgHandler)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Run keygen
|
|
result, err := session.Start(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
logger.Info("Keygen completed successfully",
|
|
zap.String("session_id", sessionID.String()),
|
|
zap.String("party_id", partyID))
|
|
|
|
return result.LocalPartySaveData, result.PublicKeyBytes, nil
|
|
}
|
|
|
|
// keygenMessageHandler adapts MPCMessage channel to tss.MessageHandler
|
|
type keygenMessageHandler struct {
|
|
sessionID uuid.UUID
|
|
partyID string
|
|
messageRouter MessageRouterClient
|
|
msgChan chan *tss.ReceivedMessage
|
|
partyIndexMap map[string]int
|
|
}
|
|
|
|
func (h *keygenMessageHandler) SendMessage(ctx context.Context, isBroadcast bool, toParties []string, msgBytes []byte) error {
|
|
return h.messageRouter.RouteMessage(ctx, h.sessionID, h.partyID, toParties, 0, msgBytes)
|
|
}
|
|
|
|
func (h *keygenMessageHandler) ReceiveMessages() <-chan *tss.ReceivedMessage {
|
|
return h.msgChan
|
|
}
|
|
|
|
func (h *keygenMessageHandler) convertMessages(ctx context.Context, inChan <-chan *MPCMessage) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
close(h.msgChan)
|
|
return
|
|
case msg, ok := <-inChan:
|
|
if !ok {
|
|
close(h.msgChan)
|
|
return
|
|
}
|
|
|
|
fromIndex, exists := h.partyIndexMap[msg.FromParty]
|
|
if !exists {
|
|
continue
|
|
}
|
|
|
|
tssMsg := &tss.ReceivedMessage{
|
|
FromPartyIndex: fromIndex,
|
|
IsBroadcast: msg.IsBroadcast,
|
|
MsgBytes: msg.Payload,
|
|
}
|
|
|
|
select {
|
|
case h.msgChan <- tssMsg:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|