diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py index 48012f3..d25de1a 100644 --- a/one_run_audit/audit.py +++ b/one_run_audit/audit.py @@ -18,12 +18,61 @@ from opacus.validators import ModuleValidator from opacus.utils.batch_memory_manager import BatchMemoryManager from WideResNet import WideResNet from equations import get_eps_audit + +import student_model + import warnings warnings.filterwarnings("ignore") DEVICE = None +STUDENTBOOL = False +def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75): + #instantiate istudent + student = student_model.Model(num_classes=10).to(device) + + ce_loss = nn.CrossEntropyLoss() + optimizer = optim.Adam(student.parameters(), lr=learning_rate) + student_init = copy.deepcopy(student) + student.to(device) + teacher.to(device) + teacher.eval() # Teacher set to evaluation mode + student.train() # Student to train mode + for epoch in range(epochs): + running_loss = 0.0 + for inputs, labels in train_dl: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights + with torch.no_grad(): + teacher_logits, _, _, _ = teacher(inputs) + + # Forward pass with the student model + student_logits = student(inputs) + #Soften the student logits by applying softmax first and log() second + soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1) + soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1) + + # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network" + soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2) + + # Calculate the true label loss + label_loss = ce_loss(student_logits, labels) + + # Weighted sum of the two losses + loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss + + loss.backward() + optimizer.step() + + running_loss += loss.item() + if epoch % 10 == 0: + print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}") + + return student_init, student def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): seed = np.random.randint(0, 1e9) @@ -90,7 +139,10 @@ def evaluate_on(model, dataloader): labels = labels.to(DEVICE) wrn_outputs = model(images) - outputs = wrn_outputs[0] + if STUDENTBOOL: + outputs = wrn_outputs + else: + outputs = wrn_outputs[0] _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() @@ -209,6 +261,8 @@ def main(): parser.add_argument('--cuda', type=int, help='gpu index', required=False) parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None) parser.add_argument('--m', type=int, help='number of target points', required=True) + parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher") + args = parser.parse_args() if torch.cuda.is_available() and args.cuda: @@ -227,8 +281,8 @@ def main(): "norm": args.norm, "batch_size": 4096, "epochs": 100, - "k+": 300, - "k-": 300, + "k+": 200, + "k-": 200, "p_value": 0.05, } @@ -250,11 +304,35 @@ def main(): print(f"Got x_m: {x_m.shape}") print(f"Got y_m: {y_m.shape}") - model_init, model_trained = train(hp, train_dl, test_dl) - # torch.save(model_init.state_dict(), "data/init_model.pt") # torch.save(model_trained.state_dict(), "data/trained_model.pt") + if args.auditmodel == "student": + global STUDENTBOOL + teacher_init, teacher_trained = train(hp, train_dl, test_dl) + STUDENTBOOL = True + # torch.save(model_init.state_dict(), "data/init_model.pt") + # torch.save(model_trained.state_dict(), "data/trained_model.pt") + + + #train student model + print("Training Student Model") + model_init, model_trained = train_knowledge_distillation( + teacher=teacher_trained, + train_dl=train_dl, + epochs=100, + device=DEVICE, + learning_rate=0.001, + T=2, + soft_target_loss_weight=0.25, + ce_loss_weight=0.75, + ) + stcorrect, sttotal = evaluate_on(model_trained, test_dl) + stacc = stcorrect/sttotal*100 + print(f"Student Accuracy: {stacc}%") + else: + model_init, model_trained = train(hp, train_dl, test_dl) + scores = list() criterion = nn.CrossEntropyLoss() with torch.no_grad(): @@ -266,9 +344,12 @@ def main(): x_point = x_m[i].unsqueeze(0) y_point = y_m[i].unsqueeze(0) is_in = S_m[i] - - init_loss = criterion(model_init(x_point)[0], y_point) - trained_loss = criterion(model_trained(x_point)[0], y_point) + if STUDENTBOOL: + init_loss = criterion(model_init(x_point), y_point) + trained_loss = criterion(model_trained(x_point), y_point) + else: + init_loss = criterion(model_init(x_point)[0], y_point) + trained_loss = criterion(model_trained(x_point)[0], y_point) scores.append(((init_loss - trained_loss).item(), is_in)) @@ -290,7 +371,7 @@ def main(): print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}") print(f"p[ε < {eps_lb}] < {hp['p_value']}") - + correct, total = evaluate_on(model_init, train_dl) print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") correct, total = evaluate_on(model_trained, test_dl) diff --git a/one_run_audit/student_model.py b/one_run_audit/student_model.py new file mode 100644 index 0000000..fa6d723 --- /dev/null +++ b/one_run_audit/student_model.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +# Create a similar student class where we return a tuple. We do not apply pooling after flattening. +class ModifiedLightNNCosine(nn.Module): + def __init__(self, num_classes=10): + super(ModifiedLightNNCosine, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(16, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.classifier = nn.Sequential( + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(256, num_classes) + ) + + def forward(self, x): + x = self.features(x) + flattened_conv_output = torch.flatten(x, 1) + x = self.classifier(flattened_conv_output) + return x + +Model = ModifiedLightNNCosine