O1: add fast training code

parent 86d16e53d7 · commit ce3a848eb7
2 changed files with 254 additions and 11 deletions
First changed file (+113 −11):

@@ -20,11 +20,14 @@ from opacus.utils.batch_memory_manager import BatchMemoryManager
 from WideResNet import WideResNet
 from equations import get_eps_audit
 import student_model
+import fast_model
 import warnings
 warnings.filterwarnings("ignore")
 
 
 DEVICE = None
+DTYPE = None
+DATADIR = Path("./data")
 
 
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
@@ -46,9 +49,8 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     # Original dataset
     x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))  # Applies transforms
@@ -106,11 +108,9 @@ def get_dataloaders2(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    trainp_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    trainp_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     mask = random.sample(range(len(trainp_ds)), m)
     S = np.random.choice([True, False], size=m)
@@ -148,9 +148,8 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     # Original dataset
     x_train = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))
@@ -192,6 +191,39 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
     return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
 
 
+def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
+    def preprocess_data(data):
+        data = torch.tensor(data)  # .to(DTYPE)
+        data = data / 255.0
+        data = data.permute(0, 3, 1, 2)
+        data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
+        data = nn.ReflectionPad2d(4)(data)
+        data = transforms.RandomCrop(size=(32, 32))(data)
+        data = transforms.RandomHorizontalFlip()(data)
+        return data
+
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True)
+
+    train_x = preprocess_data(train_ds.data)
+    test_x = preprocess_data(test_ds.data)
+    train_y = torch.tensor(train_ds.targets)
+    test_y = torch.tensor(test_ds.targets)
+
+    train_dl = DataLoader(
+        TensorDataset(train_x, train_y.long()),
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+    test_dl = DataLoader(
+        TensorDataset(test_x, test_y.long()),
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+    return train_dl, test_dl, train_x
+
+
 def evaluate_on(model, dataloader):
     correct = 0
     total = 0
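Unlike the transform-based loaders above, get_dataloaders_raw applies normalization, reflection padding, random crop, and horizontal flip once to the entire dataset tensor. The "random" augmentations are therefore sampled a single time (and, for batched tensor transforms, with one draw shared across all images) rather than per batch per epoch; the m and test_batch_size parameters are also unused, and the test loader reuses train_batch_size. A minimal sketch of the same pipeline on stand-in data, assuming torchvision's tensor-mode transforms:

```python
import torch
import torch.nn as nn
from torchvision import transforms

fake = torch.randint(0, 256, (8, 32, 32, 3), dtype=torch.uint8)  # stand-in for train_ds.data
x = fake / 255.0                       # uint8 -> float in [0, 1]
x = x.permute(0, 3, 1, 2)              # NHWC -> NCHW
x = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(x)
x = nn.ReflectionPad2d(4)(x)           # 32x32 -> 40x40
x = transforms.RandomCrop(size=(32, 32))(x)       # one crop location for the whole batch
x = transforms.RandomHorizontalFlip()(x)          # one coin flip for the whole batch
assert x.shape == (8, 3, 32, 32)
```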
@@ -404,6 +436,67 @@ def train_small(hp, train_dl, test_dl):
     return model_init, model
 
 
+def train_fast(hp):
+    epochs = hp['epochs']
+    momentum = 0.9
+    weight_decay = 0.256
+    weight_decay_bias = 0.004
+    ema_update_freq = 5
+    ema_rho = 0.99**ema_update_freq
+    dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
+
+    print("=========================")
+    print("Training a fast model")
+    print("=========================")
+    train_dl, test_dl, train_x = get_dataloaders_raw(hp['target_points'])
+
+    weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
+    model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
+
+    # train_model.to(DTYPE)
+    # for module in train_model.modules():
+    #     if isinstance(module, nn.BatchNorm2d):
+    #         module.float()
+    model.to(DEVICE)
+    init_model = copy.deepcopy(model)
+
+    # weights = [
+    #     (w, torch.zeros_like(w))
+    #     for w in train_model.parameters()
+    #     if w.requires_grad and len(w.shape) > 1
+    # ]
+    # biases = [
+    #     (w, torch.zeros_like(w))
+    #     for w in train_model.parameters()
+    #     if w.requires_grad and len(w.shape) <= 1
+    # ]
+
+    # lr_schedule = torch.cat(
+    #     [
+    #         torch.linspace(0e0, 2e-3, 194),
+    #         torch.linspace(2e-3, 2e-4, 582),
+    #     ]
+    # )
+    # lr_schedule_bias = 64.0 * lr_schedule
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(),
+        lr=0.1,
+        momentum=0.9,
+        nesterov=True,
+        weight_decay=5e-4
+    )
+    scheduler = MultiStepLR(
+        optimizer,
+        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
+        gamma=0.2
+    )
+
+    train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
+
+    return init_model, model
+
+
 def train(hp, train_dl, test_dl):
     model = WideResNet(
         d=hp["wrn_depth"],
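Several locals in train_fast (epochs, momentum, weight_decay, weight_decay_bias, ema_update_freq, ema_rho, dtype) are computed but never read by the SGD-plus-MultiStepLR path that replaced the commented-out custom update and EMA schedule. A minimal sketch of the keys the function itself consumes (values here are assumptions, not from the commit; train_no_cap may read further keys):

```python
# Hypothetical hp for train_fast(); DEVICE must already be set by main().
hp = {
    "epochs": 10,           # scales the MultiStepLR milestones [0.3, 0.6, 0.8]
    "target_points": 1000,  # forwarded to get_dataloaders_raw
}
# init_model, model = train_fast(hp)
```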
@@ -481,6 +574,7 @@ def train(hp, train_dl, test_dl):
 
 def main():
     global DEVICE
+    global DTYPE
 
     parser = argparse.ArgumentParser(description='WideResNet O1 audit')
     parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
@@ -492,14 +586,18 @@ def main():
     parser.add_argument('--load', type=Path, help='number of epochs', required=False)
     parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
     parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
+    parser.add_argument('--fast', action='store_true', help='train the fast model', required=False)
     args = parser.parse_args()
 
     if torch.cuda.is_available() and args.cuda:
         DEVICE = torch.device(f'cuda:{args.cuda}')
+        DTYPE = torch.float16
     elif torch.cuda.is_available():
         DEVICE = torch.device('cuda:0')
+        DTYPE = torch.float16
     else:
         DEVICE = torch.device('cpu')
+        DTYPE = torch.float32
 
     hp = {
         "target_points": args.m,
@@ -530,6 +628,10 @@ def main():
         train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
         model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
         test_dl = None
+    elif args.fast:
+        train_dl, test_dl, _ = get_dataloaders_raw(hp['target_points'])
+        model_init, model_trained = train_fast(hp)
+        exit(1)
     else:
         train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
         if args.studentraw:
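Two details of the new branch stand out: it builds dataloaders that train_fast() immediately rebuilds internally, and it exits with status 1 even on success. Invocation would look roughly like `python <training script> --fast --norm 1.0 ...` (the script name is not shown in this diff, and --norm is the only flag the diff marks as required).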
one_run_audit/fast_model.py (new file, +141):

@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def label_smoothing_loss(inputs, targets, alpha):
+    log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
+    kl = -log_probs.mean(dim=1)
+    xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
+    loss = (1 - alpha) * xent + alpha * kl
+    return loss
+
+
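The kl term above is the cross-entropy against a uniform distribution over classes, so with alpha=0 the function reduces to ordinary per-example cross-entropy. A quick sanity check, assuming fast_model is importable:

```python
import torch
import torch.nn.functional as F
from fast_model import label_smoothing_loss

inputs = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
# alpha=0 disables smoothing: the result is plain per-example cross-entropy
assert torch.allclose(
    label_smoothing_loss(inputs, targets, alpha=0.0),
    F.cross_entropy(inputs, targets, reduction="none"),
    atol=1e-6,
)
```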
+class GhostBatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, num_splits, **kw):
+        super().__init__(num_features, **kw)
+
+        running_mean = torch.zeros(num_features * num_splits)
+        running_var = torch.ones(num_features * num_splits)
+
+        self.weight.requires_grad = False
+        self.num_splits = num_splits
+        self.register_buffer("running_mean", running_mean)
+        self.register_buffer("running_var", running_var)
+
+    def train(self, mode=True):
+        if (self.training is True) and (mode is False):
+            # lazily collate stats when we are going to use them
+            self.running_mean = torch.mean(
+                self.running_mean.view(self.num_splits, self.num_features), dim=0
+            ).repeat(self.num_splits)
+            self.running_var = torch.mean(
+                self.running_var.view(self.num_splits, self.num_features), dim=0
+            ).repeat(self.num_splits)
+        return super().train(mode)
+
+    def forward(self, input):
+        n, c, h, w = input.shape
+        if self.training or not self.track_running_stats:
+            assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
+            return F.batch_norm(
+                input.view(-1, c * self.num_splits, h, w),
+                self.running_mean,
+                self.running_var,
+                self.weight.repeat(self.num_splits),
+                self.bias.repeat(self.num_splits),
+                True,
+                self.momentum,
+                self.eps,
+            ).view(n, c, h, w)
+        else:
+            return F.batch_norm(
+                input,
+                self.running_mean[: self.num_features],
+                self.running_var[: self.num_features],
+                self.weight,
+                self.bias,
+                False,
+                self.momentum,
+                self.eps,
+            )
+
+
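GhostBatchNorm reshapes the batch into num_splits "ghost" batches that are normalized independently, keeping per-split running statistics that are averaged back together on the switch to eval mode; the scale (weight) is frozen at 1 while the bias stays trainable. A small shape sketch (in training mode the batch size must divide evenly by num_splits):

```python
import torch
from fast_model import GhostBatchNorm  # assumes this module is importable

gbn = GhostBatchNorm(16, num_splits=4)  # running-stat buffers hold 16 * 4 entries
x = torch.randn(32, 16, 8, 8)           # 32 % 4 == 0: normalized as 4 ghost batches of 8
assert gbn(x).shape == x.shape
gbn.eval()                              # collapses per-split stats by averaging
assert gbn(x).shape == x.shape
```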
+def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
+    return nn.Sequential(
+        nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
+        GhostBatchNorm(c_out, num_splits=16),
+        nn.CELU(alpha=0.3),
+    )
+
+
+def conv_pool_norm_act(c_in, c_out):
+    return nn.Sequential(
+        nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
+        nn.MaxPool2d(kernel_size=2, stride=2),
+        GhostBatchNorm(c_out, num_splits=16),
+        nn.CELU(alpha=0.3),
+    )
+
+
+def patch_whitening(data, patch_size=(3, 3)):
+    # Compute weights from data such that
+    # torch.std(F.conv2d(data, weights), dim=(2, 3))
+    # is close to 1.
+    h, w = patch_size
+    c = data.size(1)
+    patches = data.unfold(2, h, 1).unfold(3, w, 1)
+    patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
+
+    n, c, h, w = patches.shape
+    X = patches.reshape(n, c * h * w)
+    X = X / (X.size(0) - 1) ** 0.5
+    covariance = X.t() @ X
+
+    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
+
+    eigenvalues = eigenvalues.flip(0)
+
+    eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
+
+    return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
+
+
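patch_whitening builds a PCA-whitening filter bank: each filter is an eigenvector of the patch covariance scaled by 1/sqrt(eigenvalue + 1e-2), so convolving the fitting data with the result yields decorrelated outputs with roughly unit variance. A sketch of the claimed property on random data:

```python
import torch
import torch.nn.functional as F
from fast_model import patch_whitening  # assumes this module is importable

data = torch.randn(512, 3, 32, 32)
w = patch_whitening(data)      # (27, 3, 3, 3): one filter per component of a 3x3x3 patch
out = F.conv2d(data, w)
# Per-filter std is close to 1 (slightly below for small eigenvalues,
# because of the 1e-2 regularizer in the denominator)
print(torch.std(out, dim=(0, 2, 3)))
```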
+class ResNetBagOfTricks(nn.Module):
+    def __init__(self, first_layer_weights, c_in, c_out, scale_out):
+        super().__init__()
+
+        c = first_layer_weights.size(0)
+
+        conv1 = nn.Conv2d(c_in, c, kernel_size=(3, 3), padding=(1, 1), bias=False)
+        conv1.weight.data = first_layer_weights
+        conv1.weight.requires_grad = False
+
+        self.conv1 = conv1
+        self.conv2 = conv_bn_relu(c, 64, kernel_size=(1, 1), padding=0)
+        self.conv3 = conv_pool_norm_act(64, 128)
+        self.conv4 = conv_bn_relu(128, 128)
+        self.conv5 = conv_bn_relu(128, 128)
+        self.conv6 = conv_pool_norm_act(128, 256)
+        self.conv7 = conv_pool_norm_act(256, 512)
+        self.conv8 = conv_bn_relu(512, 512)
+        self.conv9 = conv_bn_relu(512, 512)
+        self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
+        self.linear11 = nn.Linear(512, c_out, bias=False)
+        self.scale_out = scale_out
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = x + self.conv5(self.conv4(x))
+        x = self.conv6(x)
+        x = self.conv7(x)
+        x = x + self.conv9(self.conv8(x))
+        x = self.pool10(x)
+        x = x.reshape(x.size(0), x.size(1))
+        x = self.linear11(x)
+        x = self.scale_out * x
+        return x
+
+
+Model = ResNetBagOfTricks
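An end-to-end shape check for the new module, mirroring how train_fast wires it up (random data in place of CIFAR-10; the batch size is kept divisible by the GhostBatchNorm split count of 16):

```python
import torch
import fast_model

x = torch.randn(512, 3, 32, 32)                            # stand-in for preprocessed CIFAR-10
weights = fast_model.patch_whitening(x[:, :, 4:-4, 4:-4])  # fit on central 24x24 crops
model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
logits = model(x)                                          # train mode: 512 % 16 == 0
assert logits.shape == (512, 10)
```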