import math
from typing import Tuple

import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F


def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
    """Build a batch transform that randomly applies MixUp or CutMix, or return None if both are disabled."""
    transforms_module = get_module(use_v2)

    mixup_cutmix = []
    if mixup_alpha > 0:
        mixup_cutmix.append(
            transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
            if use_v2
            else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
        )
    if cutmix_alpha > 0:
        mixup_cutmix.append(
            transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
            if use_v2
            else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
        )
    if not mixup_cutmix:
        return None

    return transforms_module.RandomChoice(mixup_cutmix)

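# Usage sketch (an assumption for illustration, not part of the original module):
# the returned transform operates on whole batches, so it is typically applied
# inside a DataLoader collate_fn. `default_collate` is the standard collator from
# torch.utils.data.dataloader; the alpha values and helper name are illustrative.
def _example_collate_with_mixup_cutmix(num_classes):
    from torch.utils.data.dataloader import default_collate

    mixup_cutmix = get_mixup_cutmix(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=num_classes, use_v2=False)

    def collate_fn(batch):
        # Collate first, then mix the whole batch (images and targets together).
        collated = default_collate(batch)
        return mixup_cutmix(*collated) if mixup_cutmix is not None else collated

    return collate_fn
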
class RandomMixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default value is False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(f"Please provide a valid positive value for num_classes. Got num_classes={num_classes}")
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # Convert integer class indices to one-hot soft targets so they can be mixed.
        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the mixup paper, page 3: lambda ~ Beta(alpha, alpha).
        # A two-component Dirichlet(alpha, alpha) marginal is equivalent to Beta(alpha, alpha).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s

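# Minimal usage sketch (illustrative; the helper name and shapes are assumptions,
# not part of the original script). RandomMixUp consumes an already-collated
# batch and returns mixed images plus soft targets of shape (B, num_classes).
def _demo_random_mixup():
    mixup = RandomMixUp(num_classes=10, p=1.0, alpha=0.2)
    images = torch.rand(8, 3, 32, 32)    # (B, C, H, W) float batch
    labels = torch.randint(0, 10, (8,))  # (B,) int64 class indices
    mixed_images, soft_targets = mixup(images, labels)
    assert mixed_images.shape == (8, 3, 32, 32)
    assert soft_targets.shape == (8, 10)  # one-hot targets blended by lambda
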
class RandomCutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default value is False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"Please provide a valid positive value for num_classes. Got num_classes={num_classes}")
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # Convert integer class indices to one-hot soft targets so they can be mixed.
        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos):
        # lambda ~ Beta(alpha, alpha), sampled via an equivalent two-component Dirichlet.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        # Sample the patch center uniformly over the image ...
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        # ... and derive half-width/half-height so the patch covers a
        # (1 - lambda) fraction of the image area before clamping.
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        # Paste the patch from the rolled batch, then recompute lambda from the
        # actual (clamped) patch area so the target mixing matches the pixels mixed.
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s

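# Sketch (illustrative; the helper name is hypothetical, not part of the original
# script). Because RandomCutMix recomputes lambda from the clamped patch area,
# the label mixing stays consistent with the pixels actually replaced: each
# target row is lambda * one_hot(y_i) + (1 - lambda) * one_hot(y_{i-1}), so the
# soft targets always sum to 1 per sample.
def _demo_random_cutmix():
    cutmix = RandomCutMix(num_classes=5, p=1.0, alpha=1.0)
    images = torch.rand(4, 3, 64, 64)   # (B, C, H, W) float batch
    labels = torch.randint(0, 5, (4,))  # (B,) int64 class indices
    out_images, soft_targets = cutmix(images, labels)
    assert torch.allclose(soft_targets.sum(dim=1), torch.ones(4))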