91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
# Configurable hyper-parameters shared by the model definitions.
bsz, imgsz = 64, 64  # presumably training batch size; image side length in pixels
# NOTE: the _netG/_netD layer stacks do not read imgsz; they are hard-wired for 64x64.
nz = 100  # length of the latent vector Z (input channels of the generator)
ngf, ndf = 64, 64  # base feature-map counts for the generator / discriminator
nc = 3  # image channels: generator output and discriminator input
|
|
|
|
|
|
# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialize one module's parameters in place using the DCGAN scheme.

    Meant to be passed to ``nn.Module.apply`` so that it visits every submodule.

    * Modules whose class name contains ``"Conv"`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) get their weight drawn from N(0, 0.02). A
      convolution bias, if present, is left untouched.
    * Modules whose class name contains ``"BatchNorm"`` get their weight drawn
      from N(1, 0.02) and their bias set to zero.
    * Every other module is left unchanged.

    Parameters that are missing or ``None`` are skipped rather than raising.
    Examples are a ``BatchNorm`` built with ``affine=False``, or a custom
    container whose class name merely contains "Conv".

    Args:
        m: the module to initialize.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    # Substring match on the class name so every Conv*/BatchNorm* variant is covered.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)
|
|
|
|
|
|
class _netG(nn.Module):
    """DCGAN generator: maps a latent vector to an ``nc``-channel image.

    The layer sizes come from the module-level ``nz``, ``ngf`` and ``nc``
    constants, which are read when the generator is constructed. The
    per-layer size comments assume an input of shape ``(N, nz, 1, 1)``; this
    is not checked. Four transposed convolutions grow that to
    ``(N, nc, 64, 64)``, and a final ``Tanh`` maps values into [-1, 1].

    Args:
        ngpu: number of GPUs to split the batch across in ``forward``. Values
            greater than 1 enable ``nn.parallel.data_parallel`` for CUDA input.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the generator on a batch of latent vectors.

        When ``ngpu > 1`` and ``input`` is on a CUDA device, the batch is
        scattered across devices ``0 .. ngpu-1`` with
        ``nn.parallel.data_parallel``. Otherwise the layer stack runs directly.

        Returns:
            The output of ``self.main``, with values in [-1, 1] from ``Tanh``.
        """
        # ``is_cuda`` covers every CUDA dtype. The legacy check
        # ``isinstance(input.data, torch.cuda.FloatTensor)`` only matched
        # float32, so half/double CUDA input silently ran on a single GPU.
        if self.ngpu > 1 and input.is_cuda:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
|
|
|
|
|
|
class _netD(nn.Module):
    """DCGAN discriminator: scores ``nc x 64 x 64`` images as real or fake.

    The layer sizes come from the module-level ``nc`` and ``ndf`` constants,
    which are read when the discriminator is constructed.

    Architecture:
    * Four stride-2 convolutions downsample the image to ``(ndf*8) x 4 x 4``.
    * A final 4x4 convolution with no padding reduces that to one value per
      image.
    * ``Sigmoid`` maps that value into (0, 1).

    Args:
        ngpu: number of GPUs to split the batch across in ``forward``. Values
            greater than 1 enable ``nn.parallel.data_parallel`` for CUDA input.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images.

        When ``ngpu > 1`` and ``input`` is on a CUDA device, the batch is
        scattered across devices ``0 .. ngpu-1`` with
        ``nn.parallel.data_parallel``. Otherwise the layer stack runs directly.

        Returns:
            The sigmoid outputs reshaped with ``view(-1, 1)``. For 64x64 input
            this is one probability per image, shape ``(N, 1)``.
        """
        # ``is_cuda`` covers every CUDA dtype. The legacy check
        # ``isinstance(input.data, torch.cuda.FloatTensor)`` only matched
        # float32, so half/double CUDA input silently ran on a single GPU.
        if self.ngpu > 1 and input.is_cuda:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        # NOTE(review): for inputs larger than 64x64 the last conv leaves more
        # than one spatial position, and view(-1, 1) folds those positions into
        # the batch dimension instead of failing. Confirm inputs are always 64x64.
        return output.view(-1, 1)
|