O1: add fast training code
This commit is contained in:
parent
86d16e53d7
commit
ce3a848eb7
2 changed files with 254 additions and 11 deletions
|
@ -20,11 +20,14 @@ from opacus.utils.batch_memory_manager import BatchMemoryManager
|
|||
from WideResNet import WideResNet
|
||||
from equations import get_eps_audit
|
||||
import student_model
|
||||
import fast_model
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
DEVICE = None
|
||||
DTYPE = None
|
||||
DATADIR = Path("./data")
|
||||
|
||||
|
||||
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||
|
@ -46,9 +49,8 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
datadir = Path("./data")
|
||||
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
|
||||
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
|
||||
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
|
||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
|
||||
|
||||
# Original dataset
|
||||
x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
|
||||
|
@ -106,11 +108,9 @@ def get_dataloaders2(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
datadir = Path("./data")
|
||||
|
||||
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
|
||||
trainp_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
|
||||
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
|
||||
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
|
||||
trainp_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
|
||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
|
||||
|
||||
mask = random.sample(range(len(trainp_ds)), m)
|
||||
S = np.random.choice([True, False], size=m)
|
||||
|
@ -148,9 +148,8 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
datadir = Path("./data")
|
||||
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
|
||||
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
|
||||
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
|
||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
|
||||
|
||||
# Original dataset
|
||||
x_train = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))
|
||||
|
@ -192,6 +191,39 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
|
|||
return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
|
||||
|
||||
|
||||
def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
|
||||
def preprocess_data(data):
|
||||
data = torch.tensor(data)#.to(DTYPE)
|
||||
data = data / 255.0
|
||||
data = data.permute(0, 3, 1, 2)
|
||||
data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
|
||||
data = nn.ReflectionPad2d(4)(data)
|
||||
data = transforms.RandomCrop(size=(32, 32))(data)
|
||||
data = transforms.RandomHorizontalFlip()(data)
|
||||
return data
|
||||
|
||||
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
|
||||
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
|
||||
|
||||
train_x = preprocess_data(train_ds.data)
|
||||
test_x = preprocess_data(test_ds.data)
|
||||
train_y = torch.tensor(train_ds.targets)
|
||||
test_y = torch.tensor(test_ds.targets)
|
||||
|
||||
train_dl = DataLoader(
|
||||
TensorDataset(train_x, train_y.long()),
|
||||
batch_size=train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4
|
||||
)
|
||||
test_dl = DataLoader(
|
||||
TensorDataset(test_x, test_y.long()),
|
||||
batch_size=train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4
|
||||
)
|
||||
return train_dl, test_dl, train_x
|
||||
|
||||
def evaluate_on(model, dataloader):
|
||||
correct = 0
|
||||
total = 0
|
||||
|
@ -404,6 +436,67 @@ def train_small(hp, train_dl, test_dl):
|
|||
|
||||
return model_init, model
|
||||
|
||||
def train_fast(hp):
|
||||
epochs = hp['epochs']
|
||||
momentum = 0.9
|
||||
weight_decay = 0.256
|
||||
weight_decay_bias = 0.004
|
||||
ema_update_freq = 5
|
||||
ema_rho = 0.99**ema_update_freq
|
||||
dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
|
||||
|
||||
print("=========================")
|
||||
print("Training a fast model")
|
||||
print("=========================")
|
||||
train_dl, test_dl, train_x = get_dataloaders_raw(hp['target_points'])
|
||||
|
||||
weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
|
||||
model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
|
||||
|
||||
#train_model.to(DTYPE)
|
||||
#for module in train_model.modules():
|
||||
# if isinstance(module, nn.BatchNorm2d):
|
||||
# module.float()
|
||||
model.to(DEVICE)
|
||||
init_model = copy.deepcopy(model)
|
||||
|
||||
# weights = [
|
||||
# (w, torch.zeros_like(w))
|
||||
# for w in train_model.parameters()
|
||||
# if w.requires_grad and len(w.shape) > 1
|
||||
# ]
|
||||
# biases = [
|
||||
# (w, torch.zeros_like(w))
|
||||
# for w in train_model.parameters()
|
||||
# if w.requires_grad and len(w.shape) <= 1
|
||||
# ]
|
||||
|
||||
# lr_schedule = torch.cat(
|
||||
# [
|
||||
# torch.linspace(0e0, 2e-3, 194),
|
||||
# torch.linspace(2e-3, 2e-4, 582),
|
||||
# ]
|
||||
# )
|
||||
# lr_schedule_bias = 64.0 * lr_schedule
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(),
|
||||
lr=0.1,
|
||||
momentum=0.9,
|
||||
nesterov=True,
|
||||
weight_decay=5e-4
|
||||
)
|
||||
scheduler = MultiStepLR(
|
||||
optimizer,
|
||||
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
|
||||
gamma=0.2
|
||||
)
|
||||
|
||||
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
|
||||
|
||||
return init_model, model
|
||||
|
||||
def train(hp, train_dl, test_dl):
|
||||
model = WideResNet(
|
||||
d=hp["wrn_depth"],
|
||||
|
@ -481,6 +574,7 @@ def train(hp, train_dl, test_dl):
|
|||
|
||||
def main():
|
||||
global DEVICE
|
||||
global DTYPE
|
||||
|
||||
parser = argparse.ArgumentParser(description='WideResNet O1 audit')
|
||||
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
|
||||
|
@ -492,14 +586,18 @@ def main():
|
|||
parser.add_argument('--load', type=Path, help='number of epochs', required=False)
|
||||
parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
|
||||
parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
|
||||
parser.add_argument('--fast', action='store_true', help='train a the fast model', required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available() and args.cuda:
|
||||
DEVICE = torch.device(f'cuda:{args.cuda}')
|
||||
DTYPE = torch.float16
|
||||
elif torch.cuda.is_available():
|
||||
DEVICE = torch.device('cuda:0')
|
||||
DTYPE = torch.float16
|
||||
else:
|
||||
DEVICE = torch.device('cpu')
|
||||
DTYPE = torch.float32
|
||||
|
||||
hp = {
|
||||
"target_points": args.m,
|
||||
|
@ -530,6 +628,10 @@ def main():
|
|||
train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
||||
model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
|
||||
test_dl = None
|
||||
elif args.fast:
|
||||
train_dl, test_dl, _ = get_dataloaders_raw(hp['target_points'])
|
||||
model_init, model_trained = train_fast(hp)
|
||||
exit(1)
|
||||
else:
|
||||
train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
||||
if args.studentraw:
|
||||
|
|
141
one_run_audit/fast_model.py
Normal file
141
one_run_audit/fast_model.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def label_smoothing_loss(inputs, targets, alpha):
|
||||
log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
|
||||
kl = -log_probs.mean(dim=1)
|
||||
xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
|
||||
loss = (1 - alpha) * xent + alpha * kl
|
||||
return loss
|
||||
|
||||
|
||||
class GhostBatchNorm(nn.BatchNorm2d):
|
||||
def __init__(self, num_features, num_splits, **kw):
|
||||
super().__init__(num_features, **kw)
|
||||
|
||||
running_mean = torch.zeros(num_features * num_splits)
|
||||
running_var = torch.ones(num_features * num_splits)
|
||||
|
||||
self.weight.requires_grad = False
|
||||
self.num_splits = num_splits
|
||||
self.register_buffer("running_mean", running_mean)
|
||||
self.register_buffer("running_var", running_var)
|
||||
|
||||
def train(self, mode=True):
|
||||
if (self.training is True) and (mode is False):
|
||||
# lazily collate stats when we are going to use them
|
||||
self.running_mean = torch.mean(
|
||||
self.running_mean.view(self.num_splits, self.num_features), dim=0
|
||||
).repeat(self.num_splits)
|
||||
self.running_var = torch.mean(
|
||||
self.running_var.view(self.num_splits, self.num_features), dim=0
|
||||
).repeat(self.num_splits)
|
||||
return super().train(mode)
|
||||
|
||||
def forward(self, input):
|
||||
n, c, h, w = input.shape
|
||||
if self.training or not self.track_running_stats:
|
||||
assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
|
||||
return F.batch_norm(
|
||||
input.view(-1, c * self.num_splits, h, w),
|
||||
self.running_mean,
|
||||
self.running_var,
|
||||
self.weight.repeat(self.num_splits),
|
||||
self.bias.repeat(self.num_splits),
|
||||
True,
|
||||
self.momentum,
|
||||
self.eps,
|
||||
).view(n, c, h, w)
|
||||
else:
|
||||
return F.batch_norm(
|
||||
input,
|
||||
self.running_mean[: self.num_features],
|
||||
self.running_var[: self.num_features],
|
||||
self.weight,
|
||||
self.bias,
|
||||
False,
|
||||
self.momentum,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
|
||||
def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
|
||||
GhostBatchNorm(c_out, num_splits=16),
|
||||
nn.CELU(alpha=0.3),
|
||||
)
|
||||
|
||||
|
||||
def conv_pool_norm_act(c_in, c_out):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
GhostBatchNorm(c_out, num_splits=16),
|
||||
nn.CELU(alpha=0.3),
|
||||
)
|
||||
|
||||
|
||||
def patch_whitening(data, patch_size=(3, 3)):
|
||||
# Compute weights from data such that
|
||||
# torch.std(F.conv2d(data, weights), dim=(2, 3))
|
||||
# is close to 1.
|
||||
h, w = patch_size
|
||||
c = data.size(1)
|
||||
patches = data.unfold(2, h, 1).unfold(3, w, 1)
|
||||
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
|
||||
|
||||
n, c, h, w = patches.shape
|
||||
X = patches.reshape(n, c * h * w)
|
||||
X = X / (X.size(0) - 1) ** 0.5
|
||||
covariance = X.t() @ X
|
||||
|
||||
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
|
||||
|
||||
eigenvalues = eigenvalues.flip(0)
|
||||
|
||||
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
|
||||
|
||||
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
|
||||
|
||||
|
||||
class ResNetBagOfTricks(nn.Module):
|
||||
def __init__(self, first_layer_weights, c_in, c_out, scale_out):
|
||||
super().__init__()
|
||||
|
||||
c = first_layer_weights.size(0)
|
||||
|
||||
conv1 = nn.Conv2d(c_in, c, kernel_size=(3, 3), padding=(1, 1), bias=False)
|
||||
conv1.weight.data = first_layer_weights
|
||||
conv1.weight.requires_grad = False
|
||||
|
||||
self.conv1 = conv1
|
||||
self.conv2 = conv_bn_relu(c, 64, kernel_size=(1, 1), padding=0)
|
||||
self.conv3 = conv_pool_norm_act(64, 128)
|
||||
self.conv4 = conv_bn_relu(128, 128)
|
||||
self.conv5 = conv_bn_relu(128, 128)
|
||||
self.conv6 = conv_pool_norm_act(128, 256)
|
||||
self.conv7 = conv_pool_norm_act(256, 512)
|
||||
self.conv8 = conv_bn_relu(512, 512)
|
||||
self.conv9 = conv_bn_relu(512, 512)
|
||||
self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
|
||||
self.linear11 = nn.Linear(512, c_out, bias=False)
|
||||
self.scale_out = scale_out
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = x + self.conv5(self.conv4(x))
|
||||
x = self.conv6(x)
|
||||
x = self.conv7(x)
|
||||
x = x + self.conv9(self.conv8(x))
|
||||
x = self.pool10(x)
|
||||
x = x.reshape(x.size(0), x.size(1))
|
||||
x = self.linear11(x)
|
||||
x = self.scale_out * x
|
||||
return x
|
||||
|
||||
Model = ResNetBagOfTricks
|
Loading…
Reference in a new issue