|
@@ -20,7 +20,7 @@ class GeneratorModule(ABC, nn.Module):
|
|
for l in c[n:]: set_trainable(l, True)
|
|
for l in c[n:]: set_trainable(l, True)
|
|
|
|
|
|
def get_device(self):
|
|
def get_device(self):
|
|
- next(self.parameters()).device
|
|
|
|
|
|
+ return next(self.parameters()).device
|
|
|
|
|
|
|
|
|
|
class Unet34(GeneratorModule):
|
|
class Unet34(GeneratorModule):
|
|
@@ -51,7 +51,7 @@ class Unet34(GeneratorModule):
|
|
self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
|
|
self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
|
|
|
|
|
|
#Gets around irritating inconsistent halving coming from resnet
|
|
#Gets around irritating inconsistent halving coming from resnet
|
|
- def _pad(self, x, target):
|
|
|
|
|
|
+ def _pad(self, x: torch.Tensor, target: torch.Tensor)-> torch.Tensor:
|
|
h = x.shape[2]
|
|
h = x.shape[2]
|
|
w = x.shape[3]
|
|
w = x.shape[3]
|
|
|
|
|
|
@@ -59,9 +59,9 @@ class Unet34(GeneratorModule):
|
|
target_w = target.shape[3]*2
|
|
target_w = target.shape[3]*2
|
|
|
|
|
|
if h<target_h or w<target_w:
|
|
if h<target_h or w<target_w:
|
|
- target = Variable(torch.zeros(x.shape[0], x.shape[1], target_h, target_w))
|
|
|
|
- target[:,:,:h,:w]=x
|
|
|
|
- return to_gpu(target)
|
|
|
|
|
|
+ padh = target_h-h if target_h > h else 0
|
|
|
|
+ padw = target_w-w if target_w > w else 0
|
|
|
|
+ return F.pad(x, (0,padw,0,padh), "constant",0)
|
|
|
|
|
|
return x
|
|
return x
|
|
|
|
|