254 lines
7.9 KiB
Python
254 lines
7.9 KiB
Python
import time
|
|
import copy
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
import model
|
|
|
|
|
|
def load_model(model_path, device, dtype, train_data):
    """Rebuild a ``model.Model`` and load the state dict saved at ``model_path``.

    Args:
        model_path: File written by ``torch.save(model.state_dict(), ...)``.
        device: Device the returned model is placed on. The checkpoint is also
            loaded onto this device, so a CUDA-saved file can be read on a
            CPU-only machine.
        dtype: Precision for the model weights. BatchNorm layers are converted
            back to float32.
        train_data: Image tensor. ``train_data[:10000, :, 4:-4, 4:-4]`` is
            passed to ``model.patch_whitening`` to construct the model.
            NOTE(review): the resulting initial weights are presumably
            overwritten by the strict ``load_state_dict`` below — confirm.

    Returns:
        The model with loaded weights, on ``device``.
    """
    weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
    train_model = model.Model(weights, c_in=3, c_out=10, scale_out=0.125)

    # map_location avoids failing on machines without the device the
    # checkpoint was saved from (e.g. CUDA checkpoint, CPU-only host).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    train_model.load_state_dict(state_dict)

    # Convert model weights to the requested (possibly half) precision.
    train_model.to(dtype)

    # Convert BatchNorm back to single precision for better accuracy.
    for module in train_model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.float()

    # Move the model to the target device.
    train_model.to(device)

    return train_model
|
|
|
|
|
|
def eval_model(smodel, device, dtype, data, labels, batch_size):
    """Return the accuracy of ``smodel`` with horizontal-flip test-time augmentation.

    Each batch is run through the model twice: once as-is and once flipped
    along the last axis. The two logit tensors are averaged before taking the
    arg-max. The model is switched to eval mode and left there.

    Args:
        smodel: Module mapping an input batch to (batch, classes) logits.
        device: Device each batch (and its labels) is moved to.
        dtype: Precision each input batch is cast to.
        data: Input tensor, batched along dim 0 (presumably NCHW images).
        labels: Integer class labels aligned with ``data``.
        batch_size: Number of samples evaluated per forward pass.

    Returns:
        Fraction of correct predictions as a Python float.
    """
    smodel.eval()
    hits = []

    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            stop = start + batch_size
            batch = data[start:stop].to(device, dtype)
            mirrored = torch.flip(batch, [-1])

            # Average the logits of the plain and mirrored views.
            views = torch.stack([smodel(batch), smodel(mirrored)], dim=0)
            avg_logits = torch.mean(views, dim=0)

            predicted = avg_logits.max(dim=1).indices
            is_right = predicted == labels[start:stop].to(device)
            hits.append(is_right.detach().to(torch.float64))

    # Mean of the 0/1 indicators over every sample is the accuracy.
    return torch.mean(torch.cat(hits)).item()
|
|
|
|
|
|
def run_shadow_model(shadow_path, device, dtype, data, labels, batch_size):
    """Load the model saved at ``shadow_path``, evaluate it, and report accuracy.

    The accuracy is printed and also returned so callers can use the value.

    Args:
        shadow_path: State-dict file produced by ``train_shadow``.
        device: Device used for the model and the evaluation batches.
        dtype: Precision used for the model weights and inputs.
        data: Evaluation images. Also passed to ``load_model``, which uses a
            border-trimmed slice of them to construct the model.
        labels: Integer class labels aligned with ``data``.
        batch_size: Evaluation batch size.

    Returns:
        The evaluation accuracy as a float.
    """
    smodel = load_model(shadow_path, device, dtype, data)
    eval_acc = eval_model(smodel, device, dtype, data, labels, batch_size)

    print(f"Evaluation Accuracy: {eval_acc:.4f}")
    return eval_acc
|
|
|
|
|
|
def train_shadow(shadow_path, train_data, train_targets, batch_size):
    """Train a shadow model on a random half of ``train_data`` and save it.

    The model is trained with hand-written Nesterov SGD (``update_nesterov``)
    on random crops and horizontal flips. Its ``state_dict`` is written to
    ``shadow_path`` with ``torch.save``. The indices of the sampled half are
    not returned.

    Args:
        shadow_path: Destination file for the trained model's state dict.
        train_data: Training images, batched along dim 0. Expected to be the
            40x40 reflection-padded images from ``load_cifar10``, since random
            32x32 crops are taken from them. The model is created on this
            tensor's device and in this tensor's dtype.
        train_targets: Class labels aligned with ``train_data`` along dim 0.
        batch_size: Samples per optimisation step. A trailing partial batch
            is dropped each epoch.
    """
    # Configurable parameters
    epochs = 10
    momentum = 0.9
    weight_decay = 0.256
    weight_decay_bias = 0.004

    # Build the model on the data's own device and precision so the model
    # and its inputs can never disagree.
    device = train_data.device
    dtype = train_data.dtype

    torch.backends.cudnn.benchmark = True

    whitening = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
    train_model = model.Model(whitening, c_in=3, c_out=10, scale_out=0.125)
    train_model.to(dtype)

    # Convert BatchNorm back to single precision for better accuracy.
    for module in train_model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.float()

    train_model.to(device)

    # Pair each trainable tensor with a zero Nesterov velocity buffer.
    # Multi-dimensional tensors use the main weight decay. 0-/1-D tensors
    # use the smaller bias decay and the 64x bias learning rate.
    trainable = [p for p in train_model.parameters() if p.requires_grad]
    weights = [(p, torch.zeros_like(p)) for p in trainable if p.dim() > 1]
    biases = [(p, torch.zeros_like(p)) for p in trainable if p.dim() <= 1]

    # Randomly sample half the data for this shadow model.
    nb_rows = train_data.shape[0]
    subset = torch.randperm(nb_rows)[: nb_rows // 2]
    train_data = train_data[subset]
    train_targets = train_targets[subset]

    # Learning-rate schedule, expressed in epochs so it fits the data size:
    # - rise linearly from 0 to 2e-3 over the first 2 epochs;
    # - fall linearly to 2e-4 over the next 6 epochs;
    # - then hold at 2e-4 (the index is clamped to the last entry below).
    # At 97 steps/epoch (50000 images, batch 512) this is exactly the former
    # fixed 194 + 582 steps. On the half-size set used here there are about
    # 48 steps/epoch, so a fixed step count would never finish decaying.
    steps_per_epoch = len(train_data) // batch_size
    lr_schedule = torch.cat(
        [
            torch.linspace(0e0, 2e-3, 2 * steps_per_epoch),
            torch.linspace(2e-3, 2e-4, 6 * steps_per_epoch),
        ]
    )
    lr_schedule_bias = 64.0 * lr_schedule

    train_model.train(True)
    batch_count = 0

    for _ in range(epochs):
        # Randomly shuffle the training subset.
        perm = torch.randperm(len(train_data), device=device)
        data = train_data[perm]
        targets = train_targets[perm]

        # Crop random 32x32 patches from the 40x40 training data
        # (random_crop picks one offset per batch).
        data = torch.cat(
            [
                random_crop(data[i : i + batch_size], crop_size=(32, 32))
                for i in range(0, len(data), batch_size)
            ]
        )

        # Flip the first half of the already-shuffled data horizontally.
        half = len(data) // 2
        data[:half] = torch.flip(data[:half], [-1])

        for i in range(0, len(data), batch_size):
            # Discard the trailing partial batch.
            if i + batch_size > len(data):
                break

            inputs = data[i : i + batch_size]
            target = targets[i : i + batch_size]
            batch_count += 1

            # Compute new gradients.
            train_model.zero_grad()
            logits = train_model(inputs)
            loss = model.label_smoothing_loss(logits, target, alpha=0.2)
            loss.sum().backward()

            lr_index = min(batch_count, len(lr_schedule) - 1)
            lr = lr_schedule[lr_index]
            lr_bias = lr_schedule_bias[lr_index]

            # Update weights and biases of the training model.
            update_nesterov(weights, lr, weight_decay, momentum)
            update_nesterov(biases, lr_bias, weight_decay_bias, momentum)

    torch.save(train_model.state_dict(), shadow_path)
|
|
|
|
|
|
def preprocess_data(data, device, dtype):
    """Convert channel-last RGB images into normalised channel-first tensors.

    Args:
        data: Images indexed as (N, H, W, 3), as a NumPy array or tensor.
            The per-channel constants below are on a 0-255 scale, so raw
            pixel values are presumably expected (e.g. ``CIFAR10.data``).
        device: Device the result is created on.
        dtype: Floating-point dtype of the result (the caller picks float16 on
            GPU and float32 on CPU).

    Returns:
        Tensor of shape (N, 3, H, W) with each channel shifted by its mean
        and divided by its standard deviation.
    """
    # Convert to a tensor of the requested precision on the target device.
    # as_tensor accepts NumPy arrays and tensors alike (torch.tensor warns
    # when handed an existing tensor).
    data = torch.as_tensor(data, device=device).to(dtype)

    # Normalize with fixed per-channel mean/std.
    mean = torch.tensor([125.31, 122.95, 113.87], device=device).to(dtype)
    std = torch.tensor([62.99, 62.09, 66.70], device=device).to(dtype)
    data = (data - mean) / std

    # Permute data from NHWC to NCHW format.
    data = data.permute(0, 3, 1, 2)

    return data
|
|
|
|
|
|
def load_cifar10(device, dtype, data_dir="~/data"):
    """Load CIFAR-10 via torchvision and return normalised tensors on ``device``.

    The training split is downloaded into ``data_dir`` if missing. Both
    splits go through ``preprocess_data``. The training images are then
    reflection-padded by 4 pixels per side (32x32 -> 40x40), which lets
    ``train_shadow`` take random 32x32 crops.

    Returns:
        ``(train_data, train_targets, valid_data, valid_targets)``.
    """
    train_set = torchvision.datasets.CIFAR10(root=data_dir, download=True)
    test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False)

    train_data, valid_data = (
        preprocess_data(split.data, device, dtype) for split in (train_set, test_set)
    )

    train_targets = torch.tensor(train_set.targets).to(device)
    valid_targets = torch.tensor(test_set.targets).to(device)

    # 32x32 -> 40x40 with mirrored borders.
    reflect_pad = nn.ReflectionPad2d(4)
    train_data = reflect_pad(train_data)

    return train_data, train_targets, valid_data, valid_targets
|
|
|
|
|
|
def update_nesterov(weights, lr, weight_decay, momentum):
    """Apply one in-place Nesterov-momentum SGD step with weight decay.

    For every ``(param, velocity)`` pair whose ``param.requires_grad`` is set:

        step     = -lr * (grad + weight_decay * param)
        velocity = momentum * velocity + step
        param   += step + momentum * velocity

    Side effect: ``param.grad`` is overwritten in place with the final step.

    Args:
        weights: Iterable of ``(param, velocity)`` tensor pairs, where
            ``velocity`` has the same shape as ``param``.
        lr: Learning rate (a float or 0-dim tensor).
        weight_decay: L2 coefficient added to the gradient.
        momentum: Momentum factor.
    """
    for param, velocity in weights:
        if not param.requires_grad:
            continue

        param_data = param.data
        step = param.grad.data

        step.add_(param_data, alpha=weight_decay).mul_(-lr)
        velocity.mul_(momentum).add_(step)
        # Nesterov look-ahead: also move along the updated velocity.
        step.add_(velocity, alpha=momentum)
        param_data.add_(step)
|
|
|
|
|
|
def random_crop(data, crop_size):
    """Cut one randomly placed ``crop_size`` window out of a batch of images.

    A single offset is drawn and shared by every image in the batch
    (x first, then y).

    Args:
        data: Batch tensor indexed as (N, C, H, W).
        crop_size: ``(crop_h, crop_w)``; must not exceed ``(H, W)``.

    Returns:
        A view of ``data`` with shape (N, C, crop_h, crop_w).

    Raises:
        ValueError: If the crop is larger than the images.
    """
    crop_h, crop_w = crop_size
    h = data.size(2)
    w = data.size(3)
    if crop_h > h or crop_w > w:
        raise ValueError(f"crop {crop_size} is larger than images of size {(h, w)}")

    # torch.randint's bound is exclusive. Valid offsets run from 0 to
    # w - crop_w inclusive, so +1 is required to ever draw the last one
    # (and to allow a crop that equals the image size).
    x = torch.randint(w - crop_w + 1, size=(1,))[0]
    y = torch.randint(h - crop_h + 1, size=(1,))[0]

    return data[:, :, y : y + crop_h, x : x + crop_w]
|
|
|
|
|
|
def sha256(path):
    """Return the hexadecimal SHA-256 digest of the file at ``path``."""
    import hashlib

    hasher = hashlib.sha256()
    with open(path, "rb") as stream:
        # Feed the file in fixed-size chunks; the digest is the same as
        # hashing the whole content at once.
        for chunk in iter(lambda: stream.read(1 << 16), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
|
|
|
|
|
|
def getrelpath(abspath):
    """Return ``abspath`` expressed relative to the current working directory."""
    import os

    working_dir = os.getcwd()
    return os.path.relpath(abspath, working_dir)
|
|
|
|
|
|
def print_info():
    """Print this script's and the model module's paths and hashes, plus the PyTorch version.

    Recording exactly which source files and library version produced a
    result improves the chance of reproducing it.
    """
    tracked_sources = (("File :", __file__), ("Model :", model.__file__))
    for label, source_path in tracked_sources:
        print(label, getrelpath(source_path), sha256(source_path))

    print("PyTorch:", torch.__version__)
|
|
|
|
|
|
def main():
    """Train a CIFAR-10 shadow model, then report its train and validation accuracy."""
    print_info()

    batch_size = 512
    shadow_path = "shadow.pt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type != "cpu" else torch.float32
    train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)

    train_shadow(shadow_path, train_data, train_targets, batch_size)

    # load_cifar10 reflection-pads the training images by 4 pixels per side
    # (32x32 -> 40x40) so train_shadow can take random crops. Trimming that
    # border recovers the original images. This way the training split is
    # evaluated at the same resolution as the validation split.
    # NOTE: train_shadow trains on a random half that it does not return,
    # so this accuracy mixes samples the model saw with ones it did not.
    unpadded_train_data = train_data[:, :, 4:-4, 4:-4]
    run_shadow_model(
        shadow_path, device, dtype, unpadded_train_data, train_targets, batch_size
    )
    run_shadow_model(shadow_path, device, dtype, valid_data, valid_targets, batch_size)
|
|
|
|
|
|
# Run the training/evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|