O1: with student model

This commit is contained in:
Akemi Izuko 2024-12-05 01:04:35 -07:00
parent a697d4687c
commit ebfbd88332
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
2 changed files with 185 additions and 8 deletions

View file

@@ -19,6 +19,7 @@ from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet
from equations import get_eps_audit
import student_model
import warnings
warnings.filterwarnings("ignore")
@@ -185,9 +186,10 @@ def get_dataloaders3(m=1000, train_batch_size=128, test_batch_size=10):
    td2 = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
    td = ConcatDataset([td, td2])
    train_dl = DataLoader(td, batch_size=train_batch_size, shuffle=True, num_workers=4)
    pure_train_dl = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=test_batch_size, shuffle=True, num_workers=4)
    return train_dl, test_dl, adv_points, adv_labels, S
    return train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S
def evaluate_on(model, dataloader):
@@ -203,7 +205,11 @@ def evaluate_on(model, dataloader):
            labels = labels.to(DEVICE)
            wrn_outputs = model(images)
            if len(wrn_outputs) == 4:
                outputs = wrn_outputs[0]
            else:
                outputs = wrn_outputs
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
@@ -211,6 +217,53 @@ def evaluate_on(model, dataloader):
    return correct, total
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    # instantiate the student model
    student = student_model.Model(num_classes=10).to(device)
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    student_init = copy.deepcopy(student)
    student.to(device)
    teacher.to(device)
    teacher.eval()  # teacher is kept in evaluation mode
    student.train()  # student is set to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass with the teacher model; no gradients are saved since the teacher's weights do not change
            with torch.no_grad():
                teacher_logits, _, _, _ = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            # Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Soft-target loss, scaled by T**2 as suggested by the authors of "Distilling the Knowledge in a Neural Network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # True-label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if epoch % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

    return student_init, student
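For reference, the hand-written soft-target term above is the batch-averaged KL divergence between the temperature-softened teacher and student distributions, scaled by T**2. A minimal sketch with dummy logits (not part of this commit) verifying the equivalence against torch.nn.functional.kl_div:

import torch
import torch.nn.functional as F

T = 2.0
teacher_logits = torch.randn(8, 10)  # dummy batch: 8 examples, 10 classes
student_logits = torch.randn(8, 10)

soft_targets = F.softmax(teacher_logits / T, dim=-1)
soft_prob = F.log_softmax(student_logits / T, dim=-1)

manual = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size(0) * (T ** 2)
builtin = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (T ** 2)
assert torch.allclose(manual, builtin)  # same value, so the built-in form could be used interchangeably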
def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
    best_test_set_accuracy = 0
@@ -224,7 +277,10 @@ def train_no_cap(model, hp, train_dl, test_dl, optimizer, criterion, scheduler):
            optimizer.zero_grad()
            wrn_outputs = model(inputs)
            if len(wrn_outputs) == 4:
                outputs = wrn_outputs[0]
            else:
                outputs = wrn_outputs
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
@@ -287,6 +343,67 @@ def load(hp, model_path, train_dl):

def train_small(hp, train_dl, test_dl):
    model = student_model.Model(num_classes=10).to(DEVICE)
    model = model.to(DEVICE)
    model = ModuleValidator.fix(model)
    ModuleValidator.validate(model, strict=True)
    model_init = copy.deepcopy(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
        gamma=0.2
    )

    print(f"Training raw (no distill) STUDENT with {hp['epochs']} epochs")

    if hp['epsilon'] is not None:
        privacy_engine = opacus.PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_dl,
            epochs=hp['epochs'],
            target_epsilon=hp['epsilon'],
            target_delta=hp['delta'],
            max_grad_norm=hp['norm'],
        )

        print(f"DP epsilon = {hp['epsilon']}, delta = {hp['delta']}")
        print(f"Using sigma={optimizer.noise_multiplier} and C = norm = {hp['norm']}")

        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=2000,  # 1000 ~= 9.4GB vram
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(
                model,
                hp,
                memory_safe_data_loader,
                test_dl,
                optimizer,
                criterion,
                scheduler,
            )
    else:
        print("Training without differential privacy")
        best_test_set_accuracy = train_no_cap(
            model,
            hp,
            train_dl,
            test_dl,
            optimizer,
            criterion,
            scheduler,
        )

    return model_init, model
def train(hp, train_dl, test_dl):
    model = WideResNet(
        d=hp["wrn_depth"],
@@ -373,6 +490,8 @@ def main():
    parser.add_argument('--k', type=int, help='number of symmetric guesses', required=True)
    parser.add_argument('--epochs', type=int, help='number of epochs', required=True)
    parser.add_argument('--load', type=Path, help='path of a saved model to load', required=False)
    parser.add_argument('--studentraw', action='store_true', help='train a raw student', required=False)
    parser.add_argument('--distill', action='store_true', help='train a distilled student', required=False)
    args = parser.parse_args()

    if torch.cuda.is_available() and args.cuda:
@@ -408,11 +527,35 @@ def main():
    ))

    if args.load:
        train_dl, test_dl, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
        train_dl, test_dl, ____, _, __, ___ = get_dataloaders3(hp['target_points'], hp['batch_size'])
        model_init, model_trained, adv_points, adv_labels, S = load(hp, args.load, train_dl)
        test_dl = None
    else:
        train_dl, test_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])
        train_dl, test_dl, pure_train_dl, adv_points, adv_labels, S = get_dataloaders3(hp['target_points'], hp['batch_size'])

        if args.studentraw:
            print("=========================")
            print("Training a raw student model")
            print("=========================")
            model_init, model_trained = train_small(hp, train_dl, test_dl)
        elif args.distill:
            print("=========================")
            print("Training a distilled student model")
            print("=========================")
            teacher_init, teacher_trained = train(hp, train_dl, test_dl)
            model_init, model_trained = train_knowledge_distillation(
                teacher=teacher_trained,
                train_dl=train_dl,
                epochs=100,
                device=DEVICE,
                learning_rate=0.001,
                T=2,
                soft_target_loss_weight=0.25,
                ce_loss_weight=0.75,
            )
        else:
            print("=========================")
            print("Training teacher model")
            print("=========================")
            model_init, model_trained = train(hp, train_dl, test_dl)

    np.save("data/adv_points", adv_points)
@@ -433,8 +576,13 @@ def main():
        y_point = y_m[i].unsqueeze(0).to(DEVICE)
        is_in = S[i]
        init_loss = criterion(model_init(x_point)[0], y_point)
        trained_loss = criterion(model_trained(x_point)[0], y_point)
        wrn_outputs = model_init(x_point)
        outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
        init_loss = criterion(outputs, y_point)
        wrn_outputs = model_trained(x_point)
        outputs = wrn_outputs[0] if len(wrn_outputs) == 4 else wrn_outputs
        trained_loss = criterion(outputs, y_point)
        scores.append(((init_loss - trained_loss).item(), is_in))
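Each score pairs a canary's loss gap (initial-model loss minus trained-model loss) with its true membership bit S[i]. A hedged sketch of how such scores are commonly turned into the k symmetric guesses that the --k flag refers to; the actual call into get_eps_audit is outside this diff, so the helper below is hypothetical:

def symmetric_guesses(scores, k):
    # Hypothetical helper (not in this commit): guess "member" for the k largest
    # loss gaps and "non-member" for the k smallest, then count correct guesses.
    ordered = sorted(scores, key=lambda t: t[0], reverse=True)
    guesses = [(is_in, True) for _, is_in in ordered[:k]]     # top-k: guess member
    guesses += [(is_in, False) for _, is_in in ordered[-k:]]  # bottom-k: guess non-member
    correct = sum(1 for is_in, guess in guesses if bool(is_in) == guess)
    return len(guesses), correct  # total guesses made and how many were right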

View file

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn

# Small student CNN. Unlike the reference class it is adapted from, forward() returns
# only the logits (not a tuple), and no pooling is applied after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x

Model = ModifiedLightNNCosine
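A quick shape check (not part of the commit) confirms that the classifier's 1024-unit input matches what the convolutional stack produces for 32x32 CIFAR-10 images: the two stride-2 max-pools reduce 32x32 to 8x8, and 16 channels x 8 x 8 = 1024. A minimal sketch, assuming this file is saved as student_model.py as the training script's import suggests:

import torch
from student_model import Model

model = Model(num_classes=10)
dummy = torch.randn(2, 3, 32, 32)  # two CIFAR-10-sized images
features = model.features(dummy)
assert features.shape == (2, 16, 8, 8)               # 32 -> 16 -> 8 after the two max-pools
assert torch.flatten(features, 1).shape[1] == 1024   # matches nn.Linear(1024, 256)
assert model(dummy).shape == (2, 10)                 # final logits, one per class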