From 91c61df0a8570939f53cb2345ab9c0c084682659 Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Sat, 23 Nov 2024 22:37:07 -0700 Subject: [PATCH] Lira: save load shadow models --- cifar10-fast-simple/train.py | 93 +++++++++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 13 deletions(-) diff --git a/cifar10-fast-simple/train.py b/cifar10-fast-simple/train.py index 0dba02a..07a3555 100644 --- a/cifar10-fast-simple/train.py +++ b/cifar10-fast-simple/train.py @@ -6,6 +6,62 @@ import torchvision import model +def load_model(model_path, device, dtype, train_data): + weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4]) + train_model = model.Model(weights, c_in=3, c_out=10, scale_out=0.125) + train_model.load_state_dict(torch.load(model_path, weights_only=True)) + + # Convert model weights to half precision + train_model.to(dtype) + + # Convert BatchNorm back to single precision for better accuracy + for module in train_model.modules(): + if isinstance(module, nn.BatchNorm2d): + module.float() + + # Upload model to GPU + train_model.to(device) + + return train_model + + +def eval_model(smodel, device, dtype, data, labels, batch_size): + smodel.eval() + eval_correct = [] + + with torch.no_grad(): + for i in range(0, len(data), batch_size): + regular_inputs = data[i : i + batch_size].to(device, dtype) + flipped_inputs = torch.flip(regular_inputs, [-1]) + + logits1 = smodel(regular_inputs) + logits2 = smodel(flipped_inputs) + + # Final logits are average of augmented logits + logits = torch.mean(torch.stack([logits1, logits2], dim=0), dim=0) + + # Compute correct predictions + correct = logits.max(dim=1)[1] == labels[i : i + batch_size].to(device) + eval_correct.append(correct.detach().type(torch.float64)) + + # Accuracy is average number of correct predictions + eval_acc = torch.mean(torch.cat(eval_correct)).item() + + return eval_acc + + +def run_shadow_model(): + batch_size = 512 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if device.type != "cpu" else torch.float32 + train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype) + + smodel = load_model("shadow.pt", device, dtype, train_data) + eval_acc = eval_model(smodel, device, dtype, train_data, train_targets, batch_size) + + print(f"Evaluation Accuracy: {eval_acc:.4f}") + + def train(seed=0): # Configurable parameters epochs = 10 @@ -14,16 +70,18 @@ def train(seed=0): weight_decay = 0.256 weight_decay_bias = 0.004 ema_update_freq = 5 - ema_rho = 0.99 ** ema_update_freq + ema_rho = 0.99**ema_update_freq device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float16 if device.type != "cpu" else torch.float32 # First, the learning rate rises from 0 to 0.002 for the first 194 batches. # Next, the learning rate shrinks down to 0.0002 over the next 582 batches. - lr_schedule = torch.cat([ - torch.linspace(0e+0, 2e-3, 194), - torch.linspace(2e-3, 2e-4, 582), - ]) + lr_schedule = torch.cat( + [ + torch.linspace(0e0, 2e-3, 194), + torch.linspace(2e-3, 2e-4, 582), + ] + ) lr_schedule_bias = 64.0 * lr_schedule @@ -79,12 +137,17 @@ def train(seed=0): # Copy the model for validation valid_model = copy.deepcopy(train_model) - print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds") - # Train and validate - print("\nepoch batch train time [sec] validation accuracy") train_time = 0.0 batch_count = 0 + + # Randomly sample half the data per model + nb_rows = train_data.shape[0] + indices = torch.randperm(nb_rows)[: nb_rows // 2] + indices_in = indices[: nb_rows // 2] + train_data = train_data[indices_in] + train_targets = train_targets[indices_in] + for epoch in range(1, epochs + 1): # Flush CUDA pipeline for more accurate time measurement if torch.cuda.is_available(): @@ -169,8 +232,10 @@ def train(seed=0): print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}") + torch.save(train_model.state_dict(), "shadow.pt") return valid_acc + def preprocess_data(data, device, dtype): # Convert to torch float16 tensor data = torch.tensor(data, device=device).to(dtype) @@ -235,12 +300,14 @@ def random_crop(data, crop_size): def sha256(path): import hashlib + with open(path, "rb") as f: return hashlib.sha256(f.read()).hexdigest() def getrelpath(abspath): import os + return os.path.relpath(abspath, os.getcwd()) @@ -255,24 +322,24 @@ def main(): print_info() accuracies = [] - threshold = 0.94 - for run in range(100): + for run in range(1): valid_acc = train(seed=run) accuracies.append(valid_acc) # Print accumulated results - within_threshold = sum(acc >= threshold for acc in accuracies) - acc = threshold * 100.0 + within_threshold = sum(acc >= 0.94 for acc in accuracies) + acc = 0.94 * 100.0 print() print(f"{within_threshold} of {run + 1} runs >= {acc} % accuracy") mean = sum(accuracies) / len(accuracies) - variance = sum((acc - mean)**2 for acc in accuracies) / len(accuracies) + variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies) std = variance**0.5 print(f"Min accuracy: {min(accuracies)}") print(f"Max accuracy: {max(accuracies)}") print(f"Mean accuracy: {mean} +- {std}") print() + run_shadow_model() if __name__ == "__main__": main()