diff --git a/cifar10-fast-simple/model.py b/cifar10-fast-simple/model.py index fe33495..602849c 100644 --- a/cifar10-fast-simple/model.py +++ b/cifar10-fast-simple/model.py @@ -132,10 +132,11 @@ class ResNetBagOfTricks(nn.Module): x = self.conv6(x) x = self.conv7(x) x = x + self.conv9(self.conv8(x)) + feature_map = x x = self.pool10(x) x = x.reshape(x.size(0), x.size(1)) x = self.linear11(x) x = self.scale_out * x - return x + return x, feature_map Model = ResNetBagOfTricks diff --git a/cifar10-fast-simple/studentmodel.py b/cifar10-fast-simple/studentmodel.py index 071ffff..076f2cb 100644 --- a/cifar10-fast-simple/studentmodel.py +++ b/cifar10-fast-simple/studentmodel.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn