32 lines
886 B
Python
32 lines
886 B
Python
import torch
|
|
from torch.optim import Adam
|
|
from torchvision.models import resnet50
|
|
from tqdm import tqdm
|
|
|
|
from opensora.utils.lr_scheduler import LinearWarmupLR
|
|
|
|
|
|
def test_lr_scheduler():
    """Check that ``LinearWarmupLR`` ramps the learning rate up, then holds it.

    A model is trained for ``2 * warmup_steps`` iterations, calling
    ``scheduler.step()`` after every optimizer step. For the first
    ``warmup_steps`` iterations the learning rate applied to the optimizer
    must strictly increase after each step; for the remaining iterations it
    must equal the base learning rate the optimizer was created with.
    """
    import math  # local import: only needed for the tolerant float comparison

    warmup_steps = 200
    base_lr = 0.01
    # Fall back to CPU so the test does not crash on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = resnet50().to(device)
    optimizer = Adam(model.parameters(), lr=base_lr)
    scheduler = LinearWarmupLR(optimizer, warmup_steps=warmup_steps)

    def applied_lr():
        # Read the lr the optimizer will actually use; the scheduler writes it
        # into the param groups on construction and on every step().
        return optimizer.param_groups[0]["lr"]

    current_lr = applied_lr()
    data = torch.rand(1, 3, 224, 224, device=device)

    for i in tqdm(range(warmup_steps * 2)):
        # Clear gradients so they do not accumulate across iterations.
        optimizer.zero_grad()
        out = model(data)
        out.mean().backward()
        optimizer.step()
        scheduler.step()

        lr = applied_lr()
        if i >= warmup_steps:
            assert math.isclose(lr, base_lr), f"{lr} != {base_lr} after warmup (step {i})"
        else:
            assert lr > current_lr, f"{lr} <= {current_lr}"
        current_lr = lr
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when this file is run as a script, not when pytest imports it.
    run_check = test_lr_scheduler
    run_check()
|