package use_cases

import (
	"context"

	"github.com/rwadurian/mpc-system/services/account/application/ports"
	"github.com/rwadurian/mpc-system/services/account/domain/entities"
	"github.com/rwadurian/mpc-system/services/account/domain/repositories"
	"github.com/rwadurian/mpc-system/services/account/domain/services"
)

// InitiateRecoveryUseCase opens a new recovery session for an account. It
// refuses to open a second session while the recovery repository reports an
// active one for the same account.
type InitiateRecoveryUseCase struct {
	accountRepo    repositories.AccountRepository
	recoveryRepo   repositories.RecoverySessionRepository
	domainService  *services.AccountDomainService
	eventPublisher ports.EventPublisher
}

// NewInitiateRecoveryUseCase builds an InitiateRecoveryUseCase from its
// dependencies. Passing a nil eventPublisher disables event publication.
func NewInitiateRecoveryUseCase(
	accountRepo repositories.AccountRepository,
	recoveryRepo repositories.RecoverySessionRepository,
	domainService *services.AccountDomainService,
	eventPublisher ports.EventPublisher,
) *InitiateRecoveryUseCase {
	uc := &InitiateRecoveryUseCase{
		eventPublisher: eventPublisher,
		domainService:  domainService,
		recoveryRepo:   recoveryRepo,
		accountRepo:    accountRepo,
	}
	return uc
}

// Execute opens a recovery session for input.AccountID.
//
// If the repository lookup for an active session succeeds and yields a
// session, an *entities.AccountError with code RECOVERY_ALREADY_IN_PROGRESS is
// returned. Otherwise the domain service creates the session; any error it
// reports is returned unchanged. On success a RecoveryStarted event is
// published when a publisher is configured.
func (uc *InitiateRecoveryUseCase) Execute(ctx context.Context, input ports.InitiateRecoveryInput) (*ports.InitiateRecoveryOutput, error) {
	// NOTE(review): a lookup error is not treated as fatal and simply falls
	// through to initiation — presumably because the repository reports "no
	// active session" as an error. A genuine storage failure is therefore not
	// surfaced here; confirm the repository's not-found contract.
	active, err := uc.recoveryRepo.GetActiveByAccountID(ctx, input.AccountID)
	if err == nil && active != nil {
		return nil, &entities.AccountError{
			Code:    "RECOVERY_ALREADY_IN_PROGRESS",
			Message: "there is already an active recovery session for this account",
		}
	}

	session, err := uc.domainService.InitiateRecovery(ctx, input.AccountID, input.RecoveryType, input.OldShareType)
	if err != nil {
		return nil, err
	}

	out := &ports.InitiateRecoveryOutput{RecoverySession: session}
	if uc.eventPublisher == nil {
		return out, nil
	}

	event := ports.AccountEvent{
		Type:      ports.EventTypeRecoveryStarted,
		AccountID: input.AccountID.String(),
		Data: map[string]interface{}{
			"recoverySessionId": session.ID.String(),
			"recoveryType":      input.RecoveryType.String(),
		},
	}
	// The publish error is deliberately discarded: the session has already
	// been created, so the use case reports success regardless.
	_ = uc.eventPublisher.Publish(ctx, event)
	return out, nil
}

// CompleteRecoveryUseCase finalizes a recovery session once a new key has
// been generated, and returns the account the session belongs to.
type CompleteRecoveryUseCase struct {
	accountRepo    repositories.AccountRepository
	shareRepo      repositories.AccountShareRepository
	recoveryRepo   repositories.RecoverySessionRepository
	domainService  *services.AccountDomainService
	eventPublisher ports.EventPublisher
}

// NewCompleteRecoveryUseCase builds a CompleteRecoveryUseCase from its
// dependencies. Passing a nil eventPublisher disables event publication.
func NewCompleteRecoveryUseCase(
	accountRepo repositories.AccountRepository,
	shareRepo repositories.AccountShareRepository,
	recoveryRepo repositories.RecoverySessionRepository,
	domainService *services.AccountDomainService,
	eventPublisher ports.EventPublisher,
) *CompleteRecoveryUseCase {
	return &CompleteRecoveryUseCase{
		domainService:  domainService,
		eventPublisher: eventPublisher,
		recoveryRepo:   recoveryRepo,
		shareRepo:      shareRepo,
		accountRepo:    accountRepo,
	}
}

// Execute completes the recovery session input.RecoverySessionID.
//
// The port-level share descriptions are translated into services.ShareInfo
// values and handed, together with the new public key and keygen session ID,
// to the domain service. The session and then its account are re-read from
// the repositories, and the account is returned. A RecoveryComplete event is
// published when a publisher is configured. Any repository or domain error is
// returned unchanged.
func (uc *CompleteRecoveryUseCase) Execute(ctx context.Context, input ports.CompleteRecoveryInput) (*ports.CompleteRecoveryOutput, error) {
	shares := make([]services.ShareInfo, 0, len(input.NewShares))
	for _, s := range input.NewShares {
		shares = append(shares, services.ShareInfo{
			ShareType:  s.ShareType,
			PartyID:    s.PartyID,
			PartyIndex: s.PartyIndex,
			DeviceType: s.DeviceType,
			DeviceID:   s.DeviceID,
		})
	}

	if err := uc.domainService.CompleteRecovery(
		ctx,
		input.RecoverySessionID,
		input.NewPublicKey,
		input.NewKeygenSessionID,
		shares,
	); err != nil {
		return nil, err
	}

	// The domain call reports only success or failure, so reload the session
	// to find its account, then reload that account in its updated state.
	session, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
	if err != nil {
		return nil, err
	}
	account, err := uc.accountRepo.GetByID(ctx, session.AccountID)
	if err != nil {
		return nil, err
	}

	if uc.eventPublisher != nil {
		// The publish error is deliberately discarded: the domain service has
		// already completed the recovery, so the use case reports success.
		_ = uc.eventPublisher.Publish(ctx, ports.AccountEvent{
			Type:      ports.EventTypeRecoveryComplete,
			AccountID: account.ID.String(),
			Data: map[string]interface{}{
				"recoverySessionId":  input.RecoverySessionID,
				"newKeygenSessionId": input.NewKeygenSessionID.String(),
			},
		})
	}

	return &ports.CompleteRecoveryOutput{Account: account}, nil
}

// GetRecoveryStatusInput identifies the recovery session to look up.
type GetRecoveryStatusInput struct {
	RecoverySessionID string
}

// GetRecoveryStatusOutput carries the recovery session found by
// GetRecoveryStatusUseCase.
type GetRecoveryStatusOutput struct {
	RecoverySession *entities.RecoverySession
}

// GetRecoveryStatusUseCase loads a recovery session by ID and returns it as-is.
type GetRecoveryStatusUseCase struct {
	recoveryRepo repositories.RecoverySessionRepository
}

// NewGetRecoveryStatusUseCase builds a GetRecoveryStatusUseCase backed by the
// given recovery session repository.
func NewGetRecoveryStatusUseCase(recoveryRepo repositories.RecoverySessionRepository) *GetRecoveryStatusUseCase {
	return &GetRecoveryStatusUseCase{recoveryRepo: recoveryRepo}
}

// Execute returns the recovery session with ID input.RecoverySessionID, or the
// repository's error if the lookup fails.
func (uc *GetRecoveryStatusUseCase) Execute(ctx context.Context, input GetRecoveryStatusInput) (*GetRecoveryStatusOutput, error) {
	session, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID)
	if err != nil {
		return nil, err
	}
	return &GetRecoveryStatusOutput{RecoverySession: session}, nil
}

// CancelRecoveryInput identifies the recovery session to cancel.
type CancelRecoveryInput struct {
	RecoverySessionID string
}

// CancelRecoveryUseCase abandons an unfinished recovery session and
// re-activates the account it belongs to.
type CancelRecoveryUseCase struct {
	accountRepo  repositories.AccountRepository
	recoveryRepo repositories.RecoverySessionRepository
}

// NewCancelRecoveryUseCase builds a CancelRecoveryUseCase from the account and
// recovery session repositories.
func NewCancelRecoveryUseCase(
	accountRepo repositories.AccountRepository,
	recoveryRepo repositories.RecoverySessionRepository,
) *CancelRecoveryUseCase {
	return &CancelRecoveryUseCase{
		recoveryRepo: recoveryRepo,
		accountRepo:  accountRepo,
	}
}
Execute cancels a recovery session func (uc *CancelRecoveryUseCase) Execute(ctx context.Context, input CancelRecoveryInput) error { // Get recovery session recovery, err := uc.recoveryRepo.GetByID(ctx, input.RecoverySessionID) if err != nil { return err } // Check if recovery can be canceled if recovery.IsCompleted() { return &entities.AccountError{ Code: "RECOVERY_CANNOT_CANCEL", Message: "cannot cancel completed recovery", } } // Mark recovery as failed if err := recovery.Fail(); err != nil { return err } // Update recovery session if err := uc.recoveryRepo.Update(ctx, recovery); err != nil { return err } // Reactivate account account, err := uc.accountRepo.GetByID(ctx, recovery.AccountID) if err != nil { return err } account.Activate() if err := uc.accountRepo.Update(ctx, account); err != nil { return err } return nil }