diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index 2f1c308..e3ceee5 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -19,6 +19,7 @@ from opacus.validators import ModuleValidator from opacus.utils.batch_memory_manager import BatchMemoryManager from WideResNet import WideResNet from equations import get_eps_audit +import student_model import warnings warnings.filterwarnings("ignore") @@ -185,9 +186,10 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10): td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long()) td = ConcatDataset([td, td2]) train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4) + pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4) test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4) - return train_dl, test_dl, adv_points, adv_labels, S + return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S def evaluate_on(model, dataloader): @@ -203,7 +205,11 @@ def evaluate_on(model, dataloader): labels = labels.to(DEVICE) wrn_outputs = model(images) - outputs = wrn_outputs[0] + if len(wrn_outputs) == 4: + outputs = wrn_outputs[0] + else: + outputs = wrn_outputs + _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() @@ -211,6 +217,53 @@ def evaluate_on(model, dataloader): return correct, total +def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75): + #instantiate istudent + student = student_model.Model(num_classes=10).to(device) + + ce_loss = nn.CrossEntropyLoss() + optimizer = optim.Adam(student.parameters(), lr=learning_rate) + student_init = copy.deepcopy(student) + student.to(device) + teacher.to(device) + teacher.eval() # Teacher set to evaluation mode + student.train() # Student to train mode + for epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_dl: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights + with torch.no_grad(): + teacher_logits, _, _, _ = teacher(inputs) + + # Forward pass with the student model + student_logits = student(inputs) + #Soften the student logits by applying softmax first and log() second + soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1) + soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1) + + # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network" + soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2) + + # Calculate the true label loss + label_loss = ce_loss(student_logits, labels) + + # Weighted sum of the two losses + loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss + + loss.backward() + optimizer.step() + + running_loss += loss.item() + if epoch % 10 == 0: + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}") + + return student_init, student + + def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler): best_test_set_accuracy = 0 @@ -224,7 +277,10 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler): optimizer.zero_grad() wrn_outputs = model(inputs) - outputs = wrn_outputs[0] + if len(wrn_outputs) == 4: + outputs = wrn_outputs[0] + else: + outputs = wrn_outputs loss = criterion(outputs, labels) loss.backward() optimizer.step() @@ -287,6 +343,67 @@ def load(hp, model_path, train_dl): +def train_small(hp, train_dl, test_dl): + model = student_model.Model(num_classes=10).to(DEVICE) + model = model.to(DEVICE) + model = ModuleValidator.fix(model) + ModuleValidator.validate(model, strict=True) + + model_init = copy.deepcopy(model) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + scheduler = MultiStepLR( + optimizer, + milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]], + gamma=0.2 + ) + + print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs") + + if hp['epsilon'] is not None: + privacy_engine = opacus.PrivacyEngine() + model, optimizer, train_loader = privacy_engine.make_private_with_epsilon( + module=model, + optimizer=optimizer, + data_loader=train_dl, + epochs=hp['epochs'], + target_epsilon=hp['epsilon'], + target_delta=hp['delta'], + max_grad_norm=hp['norm'], + ) + + print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}") + print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}") + + with BatchMemoryManager( + data_loader=train_loader, + max_physical_batch_size=2000, # 1000 ~= 9.4GB vram + optimizer=optimizer + ) as memory_safe_data_loader: + best_test_set_accuracy = train_no_cap( + model, + hp, + memory_safe_data_loader, + test_dl, + optimizer, + criterion, + scheduler, + ) + else: + print("Training without differential privacy") + best_test_set_accuracy = train_no_cap( + model, + hp, + train_dl, + test_dl, + optimizer, + criterion, + scheduler, + ) + + return model_init, model + def train(hp, train_dl, test_dl): model = WideResNet( d=hp["wrn_depth"], @@ -373,6 +490,8 @@ def main(): parser.add_argument('--k', type=int, help='number of symmetric guesses', required=True) parser.add_argument('--epochs', type=int, help='number of epochs', required=True) parser.add_argument('--load', type=Path, help='number of epochs', required=False) + parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False) + parser.add_argument('--distill', action='store_true', help='train a raw student', required=False) args = parser.parse_args() if torch.cuda.is_available() and args.cuda: @@ -408,12 +527,36 @@ def main(): )) if args.load: - train_dl, test_dl, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size']) + train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size']) model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl) test_dl = None else: - train_dl, test_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size']) - model_init, model_trained = train(hp, train_dl, test_dl) + train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size']) + if args.studentraw: + print("=========================") + print("Training a raw student model") + print("=========================") + model_init, model_trained = train_small(hp, train_dl, test_dl) + elif args.distill: + print("=========================") + print("Training a distilled student model") + print("=========================") + teacher_init, teacher_trained = train(hp, train_dl, test_dl) + model_init, model_trained = train_knowledge_distillation( + teacher=teacher_trained, + train_dl=train_dl, + epochs=100, + device=DEVICE, + learning_rate=0.001, + T=2, + soft_target_loss_weight=0.25, + ce_loss_weight=0.75, + ) + else: + print("=========================") + print("Training teacher model") + print("=========================") + model_init, model_trained = train(hp, train_dl, test_dl) np.save("data/adv_points", adv_points) np.save("data/adv_labels", adv_labels) @@ -433,8 +576,13 @@ def main(): y_point = y_m[i].unsqueeze(0).to(DEVICE) is_in = S[i] - init_loss = criterion(model_init(x_point)[0], y_point) - trained_loss = criterion(model_trained(x_point)[0], y_point) + wrn_outputs = model_init(x_point) + outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs + init_loss = criterion(outputs, y_point) + + wrn_outputs = model_trained(x_point) + outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs + trained_loss = criterion(outputs, y_point) scores.append(((init_loss - trained_loss).item(), is_in)) diff --git a/one_run_audit/student_model.py b/one_run_audit/student_model.py new file mode 100644 index 0000000..fa6d723 --- /dev/null +++ b/one_run_audit/student_model.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +# Create a similar student class where we return a tuple. We do not apply pooling after flattening. +class ModifiedLightNNCosine(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedLightNNCosine, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(16, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Sequential( + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(256, num_classes) + ) + + def forward(self, x): + x = self.features(x) + flattened_conv_output = torch.flatten(x, 1) + x = self.classifier(flattened_conv_output) + return x + +Model = ModifiedLightNNCosine