diff --git a/hardnet.py b/hardnet.py index cdb7a13..9f93b5d 100644 --- a/hardnet.py +++ b/hardnet.py @@ -3,14 +3,6 @@ import torch.nn as nn import torch.nn.functional as F -class Flatten(nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - return x.view(x.data.size(0),-1) - - - class CombConvLayer(nn.Sequential): def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.1, bias=False): super().__init__() @@ -196,7 +188,7 @@ def __init__(self, depth_wise=False, arch=85, pretrained=True, weight_path=''): self.base.append ( nn.Sequential( nn.AdaptiveAvgPool2d((1,1)), - Flatten(), + nn.Flatten(), nn.Dropout(drop_rate), nn.Linear(ch, 1000) ))