From ce3a848eb70ce3b11de2dfbb3d05ab8cbd728dbd Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Fri, 6 Dec 2024 18:56:47 -0700
Subject: [PATCH] O1: add fast training code

---
 one_run_audit/audit.py      | 124 ++++++++++++++++++++++++++++---
 one_run_audit/fast_model.py | 141 ++++++++++++++++++++++++++++++++++++
 2 files changed, 254 insertions(+), 11 deletions(-)
 create mode 100644 one_run_audit/fast_model.py

diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py
index 1083156..233f681 100644
--- a/one_run_audit/audit.py
+++ b/one_run_audit/audit.py
@@ -20,11 +20,14 @@ from opacus.utils.batch_memory_manager import BatchMemoryManager
 from WideResNet import WideResNet
 from equations import get_eps_audit
 import student_model
+import fast_model
 import warnings
 
 warnings.filterwarnings("ignore")
 
 DEVICE = None
+DTYPE = None
+DATADIR = Path("./data")
 
 
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
@@ -46,9 +49,8 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     # Original dataset
     x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))  # Applies transforms
@@ -106,11 +108,9 @@ def get_dataloaders2(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    trainp_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    trainp_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     mask = random.sample(range(len(trainp_ds)), m)
     S = np.random.choice([True, False], size=m)
@@ -148,9 +148,8 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
-    datadir = Path("./data")
-    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
-    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
 
     # Original dataset
     x_train = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))
@@ -192,6 +191,39 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
 
     return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
 
+def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
+    def preprocess_data(data):
+        data = torch.tensor(data)#.to(DTYPE)
+        data = data / 255.0
+        data = data.permute(0, 3, 1, 2)
+        data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
+        data = nn.ReflectionPad2d(4)(data)
+        data = transforms.RandomCrop(size=(32, 32))(data)
+        data = transforms.RandomHorizontalFlip()(data)
+        return data
+
+    train_ds = CIFAR10(root=DATADIR, train=True, download=True)
+    test_ds = CIFAR10(root=DATADIR, train=False, download=True)
+
+    train_x = preprocess_data(train_ds.data)
+    test_x = preprocess_data(test_ds.data)
+    train_y = torch.tensor(train_ds.targets)
+    test_y = torch.tensor(test_ds.targets)
+
+    train_dl = DataLoader(
+        TensorDataset(train_x, train_y.long()),
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+    test_dl = DataLoader(
+        TensorDataset(test_x, test_y.long()),
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+    return train_dl, test_dl, train_x
+
 def evaluate_on(model, dataloader):
     correct = 0
     total = 0
@@ -404,6 +436,67 @@ def train_small(hp, train_dl, test_dl):
 
     return model_init, model
 
+def train_fast(hp):
+    epochs = hp['epochs']
+    momentum = 0.9
+    weight_decay = 0.256
+    weight_decay_bias = 0.004
+    ema_update_freq = 5
+    ema_rho = 0.99**ema_update_freq
+    dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
+
+    print("=========================")
+    print("Training a fast model")
+    print("=========================")
+    train_dl, test_dl, train_x = get_dataloaders_raw(hp['target_points'])
+
+    weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
+    model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
+
+    #train_model.to(DTYPE)
+    #for module in train_model.modules():
+    #    if isinstance(module, nn.BatchNorm2d):
+    #        module.float()
+    model.to(DEVICE)
+    init_model = copy.deepcopy(model)
+
+    # weights = [
+    #     (w, torch.zeros_like(w))
+    #     for w in train_model.parameters()
+    #     if w.requires_grad and len(w.shape) > 1
+    # ]
+    # biases = [
+    #     (w, torch.zeros_like(w))
+    #     for w in train_model.parameters()
+    #     if w.requires_grad and len(w.shape) <= 1
+    # ]
+
+    # lr_schedule = torch.cat(
+    #     [
+    #         torch.linspace(0e0, 2e-3, 194),
+    #         torch.linspace(2e-3, 2e-4, 582),
+    #     ]
+    # )
+    # lr_schedule_bias = 64.0 * lr_schedule
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(),
+        lr=0.1,
+        momentum=0.9,
+        nesterov=True,
+        weight_decay=5e-4
+    )
+    scheduler = MultiStepLR(
+        optimizer,
+        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
+        gamma=0.2
+    )
+
+    train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
+
+    return init_model, model
+
 def train(hp, train_dl, test_dl):
     model = WideResNet(
         d=hp["wrn_depth"],
@@ -481,6 +574,7 @@ def train(hp, train_dl, test_dl):
 
 def main():
     global DEVICE
+    global DTYPE
 
     parser = argparse.ArgumentParser(description='WideResNet O1 audit')
     parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
@@ -492,14 +586,18 @@ def main():
     parser.add_argument('--load', type=Path, help='number of epochs', required=False)
    parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
     parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
+    parser.add_argument('--fast', action='store_true', help='train the fast model', required=False)
     args = parser.parse_args()
 
     if torch.cuda.is_available() and args.cuda:
         DEVICE = torch.device(f'cuda:{args.cuda}')
+        DTYPE = torch.float16
     elif torch.cuda.is_available():
         DEVICE = torch.device('cuda:0')
+        DTYPE = torch.float16
     else:
         DEVICE = torch.device('cpu')
+        DTYPE = torch.float32
 
     hp = {
         "target_points": args.m,
@@ -530,6 +628,10 @@ def main():
         train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
         model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
         test_dl = None
+    elif args.fast:
+        train_dl, test_dl, _ = get_dataloaders_raw(hp['target_points'])
+        model_init, model_trained = train_fast(hp)
+        exit(1)
     else:
         train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
         if args.studentraw:
diff --git a/one_run_audit/fast_model.py b/one_run_audit/fast_model.py
new file mode 100644
index 0000000..fe33495
--- /dev/null
+++ b/one_run_audit/fast_model.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def label_smoothing_loss(inputs, targets, alpha):
+    log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
+    kl = -log_probs.mean(dim=1)
+    xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
+    loss = (1 - alpha) * xent + alpha * kl
+    return loss
+
+
+class GhostBatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, num_splits, **kw):
+        super().__init__(num_features, **kw)
+
+        running_mean = torch.zeros(num_features * num_splits)
+        running_var = torch.ones(num_features * num_splits)
+
+        self.weight.requires_grad = False
+        self.num_splits = num_splits
+        self.register_buffer("running_mean", running_mean)
+        self.register_buffer("running_var", running_var)
+
+    def train(self, mode=True):
+        if (self.training is True) and (mode is False):
+            # lazily collate stats when we are going to use them
+            self.running_mean = torch.mean(
+                self.running_mean.view(self.num_splits, self.num_features), dim=0
+            ).repeat(self.num_splits)
+            self.running_var = torch.mean(
+                self.running_var.view(self.num_splits, self.num_features), dim=0
+            ).repeat(self.num_splits)
+        return super().train(mode)
+
+    def forward(self, input):
+        n, c, h, w = input.shape
+        if self.training or not self.track_running_stats:
+            assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
+            return F.batch_norm(
+                input.view(-1, c * self.num_splits, h, w),
+                self.running_mean,
+                self.running_var,
+                self.weight.repeat(self.num_splits),
+                self.bias.repeat(self.num_splits),
+                True,
+                self.momentum,
+                self.eps,
+            ).view(n, c, h, w)
+        else:
+            return F.batch_norm(
+                input,
+                self.running_mean[: self.num_features],
+                self.running_var[: self.num_features],
+                self.weight,
+                self.bias,
+                False,
+                self.momentum,
+                self.eps,
+            )
+
+
+def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
+    return nn.Sequential(
+        nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
+        GhostBatchNorm(c_out, num_splits=16),
+        nn.CELU(alpha=0.3),
+    )
+
+
+def conv_pool_norm_act(c_in, c_out):
+    return nn.Sequential(
+        nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
+        nn.MaxPool2d(kernel_size=2, stride=2),
+        GhostBatchNorm(c_out, num_splits=16),
+        nn.CELU(alpha=0.3),
+    )
+
+
+def patch_whitening(data, patch_size=(3, 3)):
+    # Compute weights from data such that
+    # torch.std(F.conv2d(data, weights), dim=(2, 3))
+    # is close to 1.
+    h, w = patch_size
+    c = data.size(1)
+    patches = data.unfold(2, h, 1).unfold(3, w, 1)
+    patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
+
+    n, c, h, w = patches.shape
+    X = patches.reshape(n, c * h * w)
+    X = X / (X.size(0) - 1) ** 0.5
+    covariance = X.t() @ X
+
+    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
+
+    eigenvalues = eigenvalues.flip(0)
+
+    eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
+
+    return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
+
+
+class ResNetBagOfTricks(nn.Module):
+    def __init__(self, first_layer_weights, c_in, c_out, scale_out):
+        super().__init__()
+
+        c = first_layer_weights.size(0)
+
+        conv1 = nn.Conv2d(c_in, c, kernel_size=(3, 3), padding=(1, 1), bias=False)
+        conv1.weight.data = first_layer_weights
+        conv1.weight.requires_grad = False
+
+        self.conv1 = conv1
+        self.conv2 = conv_bn_relu(c, 64, kernel_size=(1, 1), padding=0)
+        self.conv3 = conv_pool_norm_act(64, 128)
+        self.conv4 = conv_bn_relu(128, 128)
+        self.conv5 = conv_bn_relu(128, 128)
+        self.conv6 = conv_pool_norm_act(128, 256)
+        self.conv7 = conv_pool_norm_act(256, 512)
+        self.conv8 = conv_bn_relu(512, 512)
+        self.conv9 = conv_bn_relu(512, 512)
+        self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
+        self.linear11 = nn.Linear(512, c_out, bias=False)
+        self.scale_out = scale_out
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = x + self.conv5(self.conv4(x))
+        x = self.conv6(x)
+        x = self.conv7(x)
+        x = x + self.conv9(self.conv8(x))
+        x = self.pool10(x)
+        x = x.reshape(x.size(0), x.size(1))
+        x = self.linear11(x)
+        x = self.scale_out * x
+        return x
+
+Model = ResNetBagOfTricks
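Not part of the patch itself: below is a minimal sketch of how the new pieces fit together, mirroring what train_fast() does in audit.py. patch_whitening() derives frozen first-layer weights from a central 32x32 crop of the training images, and fast_model.Model (ResNetBagOfTricks) consumes them. It assumes the patch is applied and the snippet is run from one_run_audit/ so that fast_model is importable; the random train_x tensor is only a stand-in for the preprocessed CIFAR-10 images returned by get_dataloaders_raw(), and the batch of 32 keeps GhostBatchNorm's divisible-by-num_splits assertion satisfied in training mode.

    # usage sketch (hypothetical, not in the patch)
    import torch
    import fast_model

    # Stand-in for the padded, preprocessed images from get_dataloaders_raw(): (N, 3, 40, 40).
    train_x = torch.rand(256, 3, 40, 40)

    # Whitening weights come from the central 32x32 crop, as in train_fast().
    weights = fast_model.patch_whitening(train_x[:, :, 4:-4, 4:-4])
    model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)

    # Batch of 32 is divisible by GhostBatchNorm's num_splits=16, so train mode works.
    logits = model(train_x[:32, :, 4:-4, 4:-4])
    print(logits.shape)  # torch.Size([32, 10])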