O1: add fast training code

This commit is contained in:
Akemi Izuko 2024-12-06 18:56:47 -07:00
parent 86d16e53d7
commit ce3a848eb7
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
2 changed files with 254 additions and 11 deletions

View file

@ -20,11 +20,14 @@ from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
from equations import get_eps_audit
import student_model
import fast_model
import warnings
warnings.filterwarnings("ignore")
DEVICE = None
DTYPE = None
DATADIR = Path("./data")
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
@ -46,9 +49,8 @@ def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datadir = Path("./data")
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
# Original dataset
x = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds))) # Applies transforms
@ -106,11 +108,9 @@ def get_dataloaders2(m=1000, train_batch_size=128, test_batch_size=10):
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datadir = Path("./data")
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
trainp_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
trainp_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
mask = random.sample(range(len(trainp_ds)), m)
S = np.random.choice([True, False], size=m)
@ -148,9 +148,8 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
datadir = Path("./data")
train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
train_ds = CIFAR10(root=DATADIR, train=True, download=True, transform=train_transform)
test_ds = CIFAR10(root=DATADIR, train=False, download=True, transform=test_transform)
# Original dataset
x_train = np.stack(train_ds[i][0].numpy() for i in range(len(train_ds)))
@ -192,6 +191,39 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
def get_dataloaders_raw(m=1000, train_batch_size=512, test_batch_size=10):
def preprocess_data(data):
data = torch.tensor(data)#.to(DTYPE)
data = data / 255.0
data = data.permute(0, 3, 1, 2)
data = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))(data)
data = nn.ReflectionPad2d(4)(data)
data = transforms.RandomCrop(size=(32, 32))(data)
data = transforms.RandomHorizontalFlip()(data)
return data
train_ds = CIFAR10(root=DATADIR, train=True, download=True)
test_ds = CIFAR10(root=DATADIR, train=False, download=True)
train_x = preprocess_data(train_ds.data)
test_x = preprocess_data(test_ds.data)
train_y = torch.tensor(train_ds.targets)
test_y = torch.tensor(test_ds.targets)
train_dl = DataLoader(
TensorDataset(train_x, train_y.long()),
batch_size=train_batch_size,
shuffle=True,
num_workers=4
)
test_dl = DataLoader(
TensorDataset(test_x, test_y.long()),
batch_size=train_batch_size,
shuffle=True,
num_workers=4
)
return train_dl, test_dl, train_x
def evaluate_on(model, dataloader):
correct = 0
total = 0
@ -404,6 +436,67 @@ def train_small(hp, train_dl, test_dl):
return model_init, model
def train_fast(hp):
epochs = hp['epochs']
momentum = 0.9
weight_decay = 0.256
weight_decay_bias = 0.004
ema_update_freq = 5
ema_rho = 0.99**ema_update_freq
dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
print("=========================")
print("Training a fast model")
print("=========================")
train_dl, test_dl, train_x = get_dataloaders_raw(hp['target_points'])
weights = fast_model.patch_whitening(train_x[:10000, :, 4:-4, 4:-4])
model = fast_model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
#train_model.to(DTYPE)
#for module in train_model.modules():
# if isinstance(module, nn.BatchNorm2d):
# module.float()
model.to(DEVICE)
init_model = copy.deepcopy(model)
# weights = [
# (w, torch.zeros_like(w))
# for w in train_model.parameters()
# if w.requires_grad and len(w.shape) > 1
# ]
# biases = [
# (w, torch.zeros_like(w))
# for w in train_model.parameters()
# if w.requires_grad and len(w.shape) <= 1
# ]
# lr_schedule = torch.cat(
# [
# torch.linspace(0e0, 2e-3, 194),
# torch.linspace(2e-3, 2e-4, 582),
# ]
# )
# lr_schedule_bias = 64.0 * lr_schedule
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9,
nesterov=True,
weight_decay=5e-4
)
scheduler = MultiStepLR(
optimizer,
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
gamma=0.2
)
train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler)
return init_model, model
def train(hp, train_dl, test_dl):
model = WideResNet(
d=hp["wrn_depth"],
@ -481,6 +574,7 @@ def train(hp, train_dl, test_dl):
def main():
global DEVICE
global DTYPE
parser = argparse.ArgumentParser(description='WideResNet O1 audit')
parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
@ -492,14 +586,18 @@ def main():
parser.add_argument('--load', type=Path, help='number of epochs', required=False)
parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
parser.add_argument('--fast', action='store_true', help='train a the fast model', required=False)
args = parser.parse_args()
if torch.cuda.is_available() and args.cuda:
DEVICE = torch.device(f'cuda:{args.cuda}')
DTYPE = torch.float16
elif torch.cuda.is_available():
DEVICE = torch.device('cuda:0')
DTYPE = torch.float16
else:
DEVICE = torch.device('cpu')
DTYPE = torch.float32
hp = {
"target_points": args.m,
@ -530,6 +628,10 @@ def main():
train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
test_dl = None
elif args.fast:
train_dl, test_dl, _ = get_dataloaders_raw(hp['target_points'])
model_init, model_trained = train_fast(hp)
exit(1)
else:
train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
if args.studentraw:

141
one_run_audit/fast_model.py Normal file
View file

@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def label_smoothing_loss(inputs, targets, alpha):
log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
kl = -log_probs.mean(dim=1)
xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
loss = (1 - alpha) * xent + alpha * kl
return loss
class GhostBatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, num_splits, **kw):
super().__init__(num_features, **kw)
running_mean = torch.zeros(num_features * num_splits)
running_var = torch.ones(num_features * num_splits)
self.weight.requires_grad = False
self.num_splits = num_splits
self.register_buffer("running_mean", running_mean)
self.register_buffer("running_var", running_var)
def train(self, mode=True):
if (self.training is True) and (mode is False):
# lazily collate stats when we are going to use them
self.running_mean = torch.mean(
self.running_mean.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
self.running_var = torch.mean(
self.running_var.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
return super().train(mode)
def forward(self, input):
n, c, h, w = input.shape
if self.training or not self.track_running_stats:
assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
return F.batch_norm(
input.view(-1, c * self.num_splits, h, w),
self.running_mean,
self.running_var,
self.weight.repeat(self.num_splits),
self.bias.repeat(self.num_splits),
True,
self.momentum,
self.eps,
).view(n, c, h, w)
else:
return F.batch_norm(
input,
self.running_mean[: self.num_features],
self.running_var[: self.num_features],
self.weight,
self.bias,
False,
self.momentum,
self.eps,
)
def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1)):
return nn.Sequential(
nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
GhostBatchNorm(c_out, num_splits=16),
nn.CELU(alpha=0.3),
)
def conv_pool_norm_act(c_in, c_out):
return nn.Sequential(
nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.MaxPool2d(kernel_size=2, stride=2),
GhostBatchNorm(c_out, num_splits=16),
nn.CELU(alpha=0.3),
)
def patch_whitening(data, patch_size=(3, 3)):
# Compute weights from data such that
# torch.std(F.conv2d(data, weights), dim=(2, 3))
# is close to 1.
h, w = patch_size
c = data.size(1)
patches = data.unfold(2, h, 1).unfold(3, w, 1)
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
n, c, h, w = patches.shape
X = patches.reshape(n, c * h * w)
X = X / (X.size(0) - 1) ** 0.5
covariance = X.t() @ X
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
class ResNetBagOfTricks(nn.Module):
def __init__(self, first_layer_weights, c_in, c_out, scale_out):
super().__init__()
c = first_layer_weights.size(0)
conv1 = nn.Conv2d(c_in, c, kernel_size=(3, 3), padding=(1, 1), bias=False)
conv1.weight.data = first_layer_weights
conv1.weight.requires_grad = False
self.conv1 = conv1
self.conv2 = conv_bn_relu(c, 64, kernel_size=(1, 1), padding=0)
self.conv3 = conv_pool_norm_act(64, 128)
self.conv4 = conv_bn_relu(128, 128)
self.conv5 = conv_bn_relu(128, 128)
self.conv6 = conv_pool_norm_act(128, 256)
self.conv7 = conv_pool_norm_act(256, 512)
self.conv8 = conv_bn_relu(512, 512)
self.conv9 = conv_bn_relu(512, 512)
self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
self.linear11 = nn.Linear(512, c_out, bias=False)
self.scale_out = scale_out
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + self.conv5(self.conv4(x))
x = self.conv6(x)
x = self.conv7(x)
x = x + self.conv9(self.conv8(x))
x = self.pool10(x)
x = x.reshape(x.size(0), x.size(1))
x = self.linear11(x)
x = self.scale_out * x
return x
Model = ResNetBagOfTricks