O1: with student model
This commit is contained in:
parent
a697d4687c
commit
ebfbd88332
2 changed files with 185 additions and 8 deletions
|
@ -19,6 +19,7 @@ from opacus.validators import ModuleValidator
|
||||||
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
||||||
from WideResNet import WideResNet
|
from WideResNet import WideResNet
|
||||||
from equations import get_eps_audit
|
from equations import get_eps_audit
|
||||||
|
import student_model
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
@ -185,9 +186,10 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
|
||||||
td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
|
td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
|
||||||
td = ConcatDataset([td, td2])
|
td = ConcatDataset([td, td2])
|
||||||
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
|
pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
|
||||||
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
|
||||||
|
|
||||||
return train_dl, test_dl, adv_points, adv_labels, S
|
return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
|
||||||
|
|
||||||
|
|
||||||
def evaluate_on(model, dataloader):
|
def evaluate_on(model, dataloader):
|
||||||
|
@ -203,7 +205,11 @@ def evaluate_on(model, dataloader):
|
||||||
labels = labels.to(DEVICE)
|
labels = labels.to(DEVICE)
|
||||||
|
|
||||||
wrn_outputs = model(images)
|
wrn_outputs = model(images)
|
||||||
|
if len(wrn_outputs) == 4:
|
||||||
outputs = wrn_outputs[0]
|
outputs = wrn_outputs[0]
|
||||||
|
else:
|
||||||
|
outputs = wrn_outputs
|
||||||
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
total += labels.size(0)
|
total += labels.size(0)
|
||||||
correct += (predicted == labels).sum().item()
|
correct += (predicted == labels).sum().item()
|
||||||
|
@ -211,6 +217,53 @@ def evaluate_on(model, dataloader):
|
||||||
return correct, total
|
return correct, total
|
||||||
|
|
||||||
|
|
||||||
|
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
|
||||||
|
#instantiate istudent
|
||||||
|
student = student_model.Model(num_classes=10).to(device)
|
||||||
|
|
||||||
|
ce_loss = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||||
|
student_init = copy.deepcopy(student)
|
||||||
|
student.to(device)
|
||||||
|
teacher.to(device)
|
||||||
|
teacher.eval() # Teacher set to evaluation mode
|
||||||
|
student.train() # Student to train mode
|
||||||
|
for epoch in range(epochs):
|
||||||
|
running_loss = 0.0
|
||||||
|
for inputs, labels in train_dl:
|
||||||
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_logits, _, _, _ = teacher(inputs)
|
||||||
|
|
||||||
|
# Forward pass with the student model
|
||||||
|
student_logits = student(inputs)
|
||||||
|
#Soften the student logits by applying softmax first and log() second
|
||||||
|
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
|
||||||
|
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
|
||||||
|
|
||||||
|
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
|
||||||
|
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
|
||||||
|
|
||||||
|
# Calculate the true label loss
|
||||||
|
label_loss = ce_loss(student_logits, labels)
|
||||||
|
|
||||||
|
# Weighted sum of the two losses
|
||||||
|
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
if epoch % 10 == 0:
|
||||||
|
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||||
|
|
||||||
|
return student_init, student
|
||||||
|
|
||||||
|
|
||||||
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
|
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
|
||||||
best_test_set_accuracy = 0
|
best_test_set_accuracy = 0
|
||||||
|
|
||||||
|
@ -224,7 +277,10 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
wrn_outputs = model(inputs)
|
wrn_outputs = model(inputs)
|
||||||
|
if len(wrn_outputs) == 4:
|
||||||
outputs = wrn_outputs[0]
|
outputs = wrn_outputs[0]
|
||||||
|
else:
|
||||||
|
outputs = wrn_outputs
|
||||||
loss = criterion(outputs, labels)
|
loss = criterion(outputs, labels)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
@ -287,6 +343,67 @@ def load(hp, model_path, train_dl):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def train_small(hp, train_dl, test_dl):
|
||||||
|
model = student_model.Model(num_classes=10).to(DEVICE)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
model = ModuleValidator.fix(model)
|
||||||
|
ModuleValidator.validate(model, strict=True)
|
||||||
|
|
||||||
|
model_init = copy.deepcopy(model)
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
scheduler = MultiStepLR(
|
||||||
|
optimizer,
|
||||||
|
milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
|
||||||
|
gamma=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs")
|
||||||
|
|
||||||
|
if hp['epsilon'] is not None:
|
||||||
|
privacy_engine = opacus.PrivacyEngine()
|
||||||
|
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
|
||||||
|
module=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
data_loader=train_dl,
|
||||||
|
epochs=hp['epochs'],
|
||||||
|
target_epsilon=hp['epsilon'],
|
||||||
|
target_delta=hp['delta'],
|
||||||
|
max_grad_norm=hp['norm'],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
|
||||||
|
print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")
|
||||||
|
|
||||||
|
with BatchMemoryManager(
|
||||||
|
data_loader=train_loader,
|
||||||
|
max_physical_batch_size=2000, # 1000 ~= 9.4GB vram
|
||||||
|
optimizer=optimizer
|
||||||
|
) as memory_safe_data_loader:
|
||||||
|
best_test_set_accuracy = train_no_cap(
|
||||||
|
model,
|
||||||
|
hp,
|
||||||
|
memory_safe_data_loader,
|
||||||
|
test_dl,
|
||||||
|
optimizer,
|
||||||
|
criterion,
|
||||||
|
scheduler,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Training without differential privacy")
|
||||||
|
best_test_set_accuracy = train_no_cap(
|
||||||
|
model,
|
||||||
|
hp,
|
||||||
|
train_dl,
|
||||||
|
test_dl,
|
||||||
|
optimizer,
|
||||||
|
criterion,
|
||||||
|
scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_init, model
|
||||||
|
|
||||||
def train(hp, train_dl, test_dl):
|
def train(hp, train_dl, test_dl):
|
||||||
model = WideResNet(
|
model = WideResNet(
|
||||||
d=hp["wrn_depth"],
|
d=hp["wrn_depth"],
|
||||||
|
@ -373,6 +490,8 @@ def main():
|
||||||
parser.add_argument('--k', type=int, help='number of symmetric guesses', required=True)
|
parser.add_argument('--k', type=int, help='number of symmetric guesses', required=True)
|
||||||
parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
|
parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
|
||||||
parser.add_argument('--load', type=Path, help='number of epochs', required=False)
|
parser.add_argument('--load', type=Path, help='number of epochs', required=False)
|
||||||
|
parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
|
||||||
|
parser.add_argument('--distill', action='store_true', help='train a raw student', required=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if torch.cuda.is_available() and args.cuda:
|
if torch.cuda.is_available() and args.cuda:
|
||||||
|
@ -408,11 +527,35 @@ def main():
|
||||||
))
|
))
|
||||||
|
|
||||||
if args.load:
|
if args.load:
|
||||||
train_dl, test_dl, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
||||||
model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
|
model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
|
||||||
test_dl = None
|
test_dl = None
|
||||||
else:
|
else:
|
||||||
train_dl, test_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
|
||||||
|
if args.studentraw:
|
||||||
|
print("=========================")
|
||||||
|
print("Training a raw student model")
|
||||||
|
print("=========================")
|
||||||
|
model_init, model_trained = train_small(hp, train_dl, test_dl)
|
||||||
|
elif args.distill:
|
||||||
|
print("=========================")
|
||||||
|
print("Training a distilled student model")
|
||||||
|
print("=========================")
|
||||||
|
teacher_init, teacher_trained = train(hp, train_dl, test_dl)
|
||||||
|
model_init, model_trained = train_knowledge_distillation(
|
||||||
|
teacher=teacher_trained,
|
||||||
|
train_dl=train_dl,
|
||||||
|
epochs=100,
|
||||||
|
device=DEVICE,
|
||||||
|
learning_rate=0.001,
|
||||||
|
T=2,
|
||||||
|
soft_target_loss_weight=0.25,
|
||||||
|
ce_loss_weight=0.75,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("=========================")
|
||||||
|
print("Training teacher model")
|
||||||
|
print("=========================")
|
||||||
model_init, model_trained = train(hp, train_dl, test_dl)
|
model_init, model_trained = train(hp, train_dl, test_dl)
|
||||||
|
|
||||||
np.save("data/adv_points", adv_points)
|
np.save("data/adv_points", adv_points)
|
||||||
|
@ -433,8 +576,13 @@ def main():
|
||||||
y_point = y_m[i].unsqueeze(0).to(DEVICE)
|
y_point = y_m[i].unsqueeze(0).to(DEVICE)
|
||||||
is_in = S[i]
|
is_in = S[i]
|
||||||
|
|
||||||
init_loss = criterion(model_init(x_point)[0], y_point)
|
wrn_outputs = model_init(x_point)
|
||||||
trained_loss = criterion(model_trained(x_point)[0], y_point)
|
outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
|
||||||
|
init_loss = criterion(outputs, y_point)
|
||||||
|
|
||||||
|
wrn_outputs = model_trained(x_point)
|
||||||
|
outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
|
||||||
|
trained_loss = criterion(outputs, y_point)
|
||||||
|
|
||||||
scores.append(((init_loss - trained_loss).item(), is_in))
|
scores.append(((init_loss - trained_loss).item(), is_in))
|
||||||
|
|
||||||
|
|
29
one_run_audit/student_model.py
Normal file
29
one_run_audit/student_model.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
|
||||||
|
class ModifiedLightNNCosine(nn.Module):
|
||||||
|
def __init__(self, num_classes=10):
|
||||||
|
super(ModifiedLightNNCosine, self).__init__()
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.Linear(1024, 256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(256, num_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.features(x)
|
||||||
|
flattened_conv_output = torch.flatten(x, 1)
|
||||||
|
x = self.classifier(flattened_conv_output)
|
||||||
|
return x
|
||||||
|
|
||||||
|
Model = ModifiedLightNNCosine
|
Loading…
Reference in a new issue