From 3467c258825cf51f1db8173b16f5b297a13e55b3 Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Sat, 23 Nov 2024 23:32:39 -0700
Subject: [PATCH] Lira: remove dead code

Drop the unused single-model train() entry point, its update_ema()
helper, the multi-run accuracy summary in main(), and a stale
commented-out load_cifar10() call. Shadow models are now trained only
through train_shadow(), which also loses its unused valid_data and
valid_targets parameters.
---
 cifar10-fast-simple/train.py | 210 +----------------------------------
 1 file changed, 3 insertions(+), 207 deletions(-)

diff --git a/cifar10-fast-simple/train.py b/cifar10-fast-simple/train.py
index abde7a0..2e89920 100644
--- a/cifar10-fast-simple/train.py
+++ b/cifar10-fast-simple/train.py
@@ -57,7 +57,7 @@ def run_shadow_model(shadow_path, device, dtype, data, labels, batch_size):
     print(f"Evaluation Accuracy: {eval_acc:.4f}")
 
 
-def train_shadow(shadow_path, train_data, train_targets, valid_data, valid_targets, batch_size):
+def train_shadow(shadow_path, train_data, train_targets, batch_size):
     # Configurable parameters
     epochs = 10
     momentum = 0.9
@@ -81,8 +81,6 @@ def train_shadow(shadow_path, train_data, train_targets, valid_data, valid_targe
 
     torch.backends.cudnn.benchmark = True
 
-    # train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
-
     weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
     train_model = model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
     train_model.to(dtype)
@@ -105,7 +103,6 @@ def train_shadow(shadow_path, train_data, train_targets, valid_data, valid_targe
         if w.requires_grad and len(w.shape) <= 1
     ]
 
-    # Train and validate
     batch_count = 0
 
     # Randomly sample half the data per model
@@ -168,180 +165,6 @@ def train_shadow(shadow_path, train_data, train_targets, valid_data, valid_targe
     torch.save(train_model.state_dict(), shadow_path)
 
 
-def train(seed=0):
-    # Configurable parameters
-    epochs = 10
-    batch_size = 512
-    momentum = 0.9
-    weight_decay = 0.256
-    weight_decay_bias = 0.004
-    ema_update_freq = 5
-    ema_rho = 0.99**ema_update_freq
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    dtype = torch.float16 if device.type != "cpu" else torch.float32
-
-    # First, the learning rate rises from 0 to 0.002 for the first 194 batches.
-    # Next, the learning rate shrinks down to 0.0002 over the next 582 batches.
-    lr_schedule = torch.cat(
-        [
-            torch.linspace(0e0, 2e-3, 194),
-            torch.linspace(2e-3, 2e-4, 582),
-        ]
-    )
-
-    lr_schedule_bias = 64.0 * lr_schedule
-
-    # Print information about hardware on first run
-    if seed == 0:
-        if device.type == "cuda":
-            print("Device :", torch.cuda.get_device_name(device.index))
-
-        print("Dtype :", dtype)
-        print()
-
-    # Start measuring time
-    start_time = time.perf_counter()
-
-    # Set random seed to increase chance of reproducability
-    torch.manual_seed(seed)
-
-    # Setting cudnn.benchmark to True hampers reproducability, but is faster
-    torch.backends.cudnn.benchmark = True
-
-    # Load dataset
-    train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
-
-    # Compute special weights for first layer
-    weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
-
-    # Construct the neural network
-    train_model = model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
-
-    # Convert model weights to half precision
-    train_model.to(dtype)
-
-    # Convert BatchNorm back to single precision for better accuracy
-    for module in train_model.modules():
-        if isinstance(module, nn.BatchNorm2d):
-            module.float()
-
-    # Upload model to GPU
-    train_model.to(device)
-
-    # Collect weights and biases and create nesterov velocity values
-    weights = [
-        (w, torch.zeros_like(w))
-        for w in train_model.parameters()
-        if w.requires_grad and len(w.shape) > 1
-    ]
-    biases = [
-        (w, torch.zeros_like(w))
-        for w in train_model.parameters()
-        if w.requires_grad and len(w.shape) <= 1
-    ]
-
-    # Copy the model for validation
-    valid_model = copy.deepcopy(train_model)
-
-    # Train and validate
-    train_time = 0.0
-    batch_count = 0
-
-    # Randomly sample half the data per model
-    nb_rows = train_data.shape[0]
-    indices = torch.randperm(nb_rows)[: nb_rows // 2]
-    indices_in = indices[: nb_rows // 2]
-    train_data = train_data[indices_in]
-    train_targets = train_targets[indices_in]
-
-    for epoch in range(1, epochs + 1):
-        # Flush CUDA pipeline for more accurate time measurement
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-
-        start_time = time.perf_counter()
-
-        # Randomly shuffle training data
-        indices = torch.randperm(len(train_data), device=device)
-        data = train_data[indices]
-        targets = train_targets[indices]
-
-        # Crop random 32x32 patches from 40x40 training data
-        data = [
-            random_crop(data[i : i + batch_size], crop_size=(32, 32))
-            for i in range(0, len(data), batch_size)
-        ]
-        data = torch.cat(data)
-
-        # Randomly flip half the training data
-        data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
-
-        for i in range(0, len(data), batch_size):
-            # discard partial batches
-            if i + batch_size > len(data):
-                break
-
-            # Slice batch from data
-            inputs = data[i : i + batch_size]
-            target = targets[i : i + batch_size]
-            batch_count += 1
-
-            # Compute new gradients
-            train_model.zero_grad()
-            train_model.train(True)
-
-            logits = train_model(inputs)
-
-            loss = model.label_smoothing_loss(logits, target, alpha=0.2)
-
-            loss.sum().backward()
-
-            lr_index = min(batch_count, len(lr_schedule) - 1)
-            lr = lr_schedule[lr_index]
-            lr_bias = lr_schedule_bias[lr_index]
-
-            # Update weights and biases of training model
-            update_nesterov(weights, lr, weight_decay, momentum)
-            update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
-
-            # Update validation model with exponential moving averages
-            if (i // batch_size % ema_update_freq) == 0:
-                update_ema(train_model, valid_model, ema_rho)
-
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-
-        # Add training time
-        train_time += time.perf_counter() - start_time
-
-        valid_correct = []
-        for i in range(0, len(valid_data), batch_size):
-            valid_model.train(False)
-
-            # Test time agumentation: Test model on regular and flipped data
-            regular_inputs = valid_data[i : i + batch_size]
-            flipped_inputs = torch.flip(regular_inputs, [-1])
-
-            logits1 = valid_model(regular_inputs).detach()
-            logits2 = valid_model(flipped_inputs).detach()
-
-            # Final logits are average of augmented logits
-            logits = torch.mean(torch.stack([logits1, logits2], dim=0), dim=0)
-
-            # Compute correct predictions
-            correct = logits.max(dim=1)[1] == valid_targets[i : i + batch_size]
-
-            valid_correct.append(correct.detach().type(torch.float64))
-
-        # Accuracy is average number of correct predictions
-        valid_acc = torch.mean(torch.cat(valid_correct)).item()
-
-        print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
-
-    torch.save(train_model.state_dict(), "shadow.pt")
-    return valid_acc
-
-
 def preprocess_data(data, device, dtype):
     # Convert to torch float16 tensor
     data = torch.tensor(data, device=device).to(dtype)
@@ -373,17 +196,6 @@ def load_cifar10(device, dtype, data_dir="~/data"):
     return train_data, train_targets, valid_data, valid_targets
 
 
-def update_ema(train_model, valid_model, rho):
-    # The trained model is not used for validation directly. Instead, the
-    # validation model weights are updated with exponential moving averages.
-    train_weights = train_model.state_dict().values()
-    valid_weights = valid_model.state_dict().values()
-    for train_weight, valid_weight in zip(train_weights, valid_weights):
-        if valid_weight.dtype in [torch.float16, torch.float32]:
-            valid_weight *= rho
-            valid_weight += (1 - rho) * train_weight
-
-
 def update_nesterov(weights, lr, weight_decay, momentum):
     for weight, velocity in weights:
         if weight.requires_grad:
@@ -427,32 +239,16 @@ def print_info():
 
 def main():
     print_info()
-    accuracies = []
-    for run in range(1):
-        valid_acc = train(seed=run)
-        accuracies.append(valid_acc)
-
-    # Print accumulated results
-    within_threshold = sum(acc >= 0.94 for acc in accuracies)
-    acc = 0.94 * 100.0
-    print()
-    print(f"{within_threshold} of {run + 1} runs >= {acc} % accuracy")
-    mean = sum(accuracies) / len(accuracies)
-    variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies)
-    std = variance**0.5
-    print(f"Min accuracy: {min(accuracies)}")
-    print(f"Max accuracy: {max(accuracies)}")
-    print(f"Mean accuracy: {mean} +- {std}")
-    print()
 
     batch_size = 512
     shadow_path = "shadow.pt"
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     dtype = torch.float16 if device.type != "cpu" else torch.float32
     train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
-    train_shadow(shadow_path, train_data, train_targets, valid_data, valid_targets, batch_size)
+    train_shadow(shadow_path, train_data, train_targets, batch_size)
     run_shadow_model(shadow_path, device, dtype, train_data, train_targets, batch_size)
     run_shadow_model(shadow_path, device, dtype, valid_data, valid_targets, batch_size)
 
+
 if __name__ == "__main__":
     main()
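
Note: the "Randomly sample half the data per model" step that survives in
train_shadow() is the core of LiRA-style shadow training: each shadow model
trains on a random half of the data (its IN examples), and the untouched half
serves as the OUT set. The patch saves only model weights, so the IN split
itself is not recorded anywhere. A minimal sketch of persisting it next to the
checkpoint; the helper name and the shadow_indices.pt side file are
hypothetical, not part of this patch:

    import torch

    def sample_half(nb_rows):
        # Random half of the rows: the shadow model's IN (member) examples.
        # The complement is the OUT (non-member) set.
        return torch.randperm(nb_rows)[: nb_rows // 2]

    indices_in = sample_half(train_data.shape[0])
    torch.save(indices_in, "shadow_indices.pt")  # hypothetical side file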
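Note: main() now trains a single shadow model into shadow.pt and scores it on
the train split (members) and the valid split (non-members) via
run_shadow_model(). LiRA-style attacks typically require many shadow models; a
sketch of an outer loop under that assumption, where the count of 16 and the
per-model seeding are illustrative, not part of this patch:

    # Hypothetical driver: train several shadow models, one checkpoint each.
    for k in range(16):
        torch.manual_seed(k)  # distinct random half-split per model
        train_shadow(f"shadow_{k}.pt", train_data, train_targets, batch_size)
        run_shadow_model(f"shadow_{k}.pt", device, dtype, train_data, train_targets, batch_size)
        run_shadow_model(f"shadow_{k}.pt", device, dtype, valid_data, valid_targets, batch_size)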
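Note: run_shadow_model() reloads the checkpoint written by train_shadow(); its
body is outside this diff. A sketch of what reloading presumably involves,
mirroring the construction visible in train_shadow(); this is a guess at the
surrounding code, not quoted from it:

    # The whitening weights are derived from the same train_data crops as in
    # train_shadow(), so the shapes match the saved state dict.
    weights = model.patch_whitening(train_data[:10000, :, 4:-4, 4:-4])
    shadow = model.Model(weights, c_in=3, c_out=10, scale_out=0.125)
    shadow.to(dtype)
    shadow.to(device)
    shadow.load_state_dict(torch.load("shadow.pt", map_location=device))
    shadow.train(False)  # evaluation mode for scoring members vs non-members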