# Source: sglang_v0.5.2/vision_0.22.1/references/depth/stereo/utils/norm.py (14 lines, 278 B, Python)

import torch
def freeze_batch_norm(model, *, norm_types=(torch.nn.BatchNorm2d,)):
    """Put every normalization submodule of ``model`` into eval mode.

    Each submodule of ``model`` (including ``model`` itself, as yielded by
    ``model.modules()``) that is an instance of ``norm_types`` has ``.eval()``
    called on it. Other submodules keep their current mode.

    Only the train/eval mode is changed. The layers' parameters are not
    touched: ``requires_grad`` is left as-is, so affine weights can still
    receive gradients.

    Because ``torch.nn.Module.train()`` recurses into children, a later
    ``model.train()`` switches these layers back to training mode. Call this
    function after ``model.train()`` if the freeze should persist.

    Args:
        model: The ``torch.nn.Module`` whose submodules are walked.
        norm_types: Keyword-only. A type or tuple of types passed to
            ``isinstance``. Defaults to ``(torch.nn.BatchNorm2d,)``, which
            matches the original hard-coded behavior.

    Returns:
        None. ``model`` is modified in place.
    """
    for module in model.modules():
        if isinstance(module, norm_types):
            module.eval()


def unfreeze_batch_norm(model, *, norm_types=(torch.nn.BatchNorm2d,)):
    """Put every normalization submodule of ``model`` back into training mode.

    This is the inverse of freezing. Each submodule of ``model`` (including
    ``model`` itself, as yielded by ``model.modules()``) that is an instance
    of ``norm_types`` has ``.train()`` called on it. Other submodules keep
    their current mode.

    Args:
        model: The ``torch.nn.Module`` whose submodules are walked.
        norm_types: Keyword-only. A type or tuple of types passed to
            ``isinstance``. Defaults to ``(torch.nn.BatchNorm2d,)``, which
            matches the original hard-coded behavior.

    Returns:
        None. ``model`` is modified in place.
    """
    for module in model.modules():
        if isinstance(module, norm_types):
            module.train()