layers.py 568 B

12345678910111213141516171819
  1. import torch
  2. from torch import nn
  class AdaptiveConcatPool2d(nn.Module):
      """Adaptive max-pool and adaptive avg-pool, concatenated along dim 1.

      Args:
          sz: output spatial size passed to both adaptive pooling layers.
              Any falsy value (including the default ``None``) means ``(1, 1)``.

      The max-pooled result comes first, then the average-pooled result.
      Assuming a batched (N, C, H, W) input (not checked here), the output is
      (N, 2*C, *sz).
      """

      def __init__(self, sz=None):
          super().__init__()
          output_size = (1, 1) if not sz else sz
          # Registration order (ap before mp) kept as in the original.
          self.ap = nn.AdaptiveAvgPool2d(output_size)
          self.mp = nn.AdaptiveMaxPool2d(output_size)

      def forward(self, x):
          pooled = (self.mp(x), self.ap(x))
          return torch.cat(pooled, 1)
  class Lambda(nn.Module):
      """Expose an arbitrary callable as an ``nn.Module`` layer.

      Args:
          f: callable applied to the module's input in ``forward``; its return
             value is returned unchanged.
      """

      def __init__(self, f):
          super().__init__()
          self.f = f

      def forward(self, x):
          fn = self.f
          return fn(x)
  class Flatten(nn.Module):
      """Collapse every dimension after dim 0 into a single dimension.

      The result has shape ``(x.size(0), prod(x.shape[1:]))``. A 1-D input of
      shape (N,) becomes (N, 1), because the empty product is 1.
      """

      def __init__(self):
          super().__init__()

      def forward(self, x):
          # Use reshape rather than view, so non-contiguous inputs work
          # (e.g. after permute/transpose). Those inputs make view() raise.
          # The result may be a view or a copy.
          # Computing the trailing size explicitly instead of passing -1
          # keeps an empty batch (size(0) == 0) from being ambiguous.
          return x.reshape(x.size(0), x.shape[1:].numel())