12345678910111213141516171819 |
- import torch
- from torch import nn
- class AdaptiveConcatPool2d(nn.Module):
- def __init__(self, sz=None):
- super().__init__()
- sz = sz or (1,1)
- self.ap = nn.AdaptiveAvgPool2d(sz)
- self.mp = nn.AdaptiveMaxPool2d(sz)
- def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
- class Lambda(nn.Module):
- def __init__(self, f): super().__init__(); self.f=f
- def forward(self, x): return self.f(x)
- class Flatten(nn.Module):
- def __init__(self): super().__init__()
- def forward(self, x): return x.view(x.size(0), -1)
|