student stuff

This commit is contained in:
ARVP 2024-12-04 23:32:30 -07:00
parent c6f91352d9
commit 1200907c31
2 changed files with 119 additions and 9 deletions

View file

@ -18,12 +18,61 @@ from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet from WideResNet import WideResNet
from equations import get_eps_audit from equations import get_eps_audit
import student_model
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
DEVICE = None DEVICE = None
STUDENTBOOL = False
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
#instantiate istudent
student = student_model.Model(num_classes=10).to(device)
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
student_init = copy.deepcopy(student)
student.to(device)
teacher.to(device)
teacher.eval() # Teacher set to evaluation mode
student.train() # Student to train mode
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_dl:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
with torch.no_grad():
teacher_logits, _, _, _ = teacher(inputs)
# Forward pass with the student model
student_logits = student(inputs)
#Soften the student logits by applying softmax first and log() second
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)
# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
running_loss += loss.item()
if epoch % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
return student_init, student
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
seed = np.random.randint(0, 1e9) seed = np.random.randint(0, 1e9)
@ -90,7 +139,10 @@ def evaluate_on(model, dataloader):
labels = labels.to(DEVICE) labels = labels.to(DEVICE)
wrn_outputs = model(images) wrn_outputs = model(images)
outputs = wrn_outputs[0] if STUDENTBOOL:
outputs = wrn_outputs
else:
outputs = wrn_outputs[0]
_, predicted = torch.max(outputs.data, 1) _, predicted = torch.max(outputs.data, 1)
total += labels.size(0) total += labels.size(0)
correct += (predicted == labels).sum().item() correct += (predicted == labels).sum().item()
@ -209,6 +261,8 @@ def main():
parser.add_argument('--cuda', type=int, help='gpu index', required=False) parser.add_argument('--cuda', type=int, help='gpu index', required=False)
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None) parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
parser.add_argument('--m', type=int, help='number of target points', required=True) parser.add_argument('--m', type=int, help='number of target points', required=True)
parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher")
args = parser.parse_args() args = parser.parse_args()
if torch.cuda.is_available() and args.cuda: if torch.cuda.is_available() and args.cuda:
@ -227,8 +281,8 @@ def main():
"norm": args.norm, "norm": args.norm,
"batch_size": 4096, "batch_size": 4096,
"epochs": 100, "epochs": 100,
"k+": 300, "k+": 200,
"k-": 300, "k-": 200,
"p_value": 0.05, "p_value": 0.05,
} }
@ -250,11 +304,35 @@ def main():
print(f"Got x_m: {x_m.shape}") print(f"Got x_m: {x_m.shape}")
print(f"Got y_m: {y_m.shape}") print(f"Got y_m: {y_m.shape}")
model_init, model_trained = train(hp, train_dl, test_dl)
# torch.save(model_init.state_dict(), "data/init_model.pt") # torch.save(model_init.state_dict(), "data/init_model.pt")
# torch.save(model_trained.state_dict(), "data/trained_model.pt") # torch.save(model_trained.state_dict(), "data/trained_model.pt")
if args.auditmodel == "student":
global STUDENTBOOL
teacher_init, teacher_trained = train(hp, train_dl, test_dl)
STUDENTBOOL = True
# torch.save(model_init.state_dict(), "data/init_model.pt")
# torch.save(model_trained.state_dict(), "data/trained_model.pt")
#train student model
print("Training Student Model")
model_init, model_trained = train_knowledge_distillation(
teacher=teacher_trained,
train_dl=train_dl,
epochs=100,
device=DEVICE,
learning_rate=0.001,
T=2,
soft_target_loss_weight=0.25,
ce_loss_weight=0.75,
)
stcorrect, sttotal = evaluate_on(model_trained, test_dl)
stacc = stcorrect/sttotal*100
print(f"Student Accuracy: {stacc}%")
else:
model_init, model_trained = train(hp, train_dl, test_dl)
scores = list() scores = list()
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
with torch.no_grad(): with torch.no_grad():
@ -266,9 +344,12 @@ def main():
x_point = x_m[i].unsqueeze(0) x_point = x_m[i].unsqueeze(0)
y_point = y_m[i].unsqueeze(0) y_point = y_m[i].unsqueeze(0)
is_in = S_m[i] is_in = S_m[i]
if STUDENTBOOL:
init_loss = criterion(model_init(x_point)[0], y_point) init_loss = criterion(model_init(x_point), y_point)
trained_loss = criterion(model_trained(x_point)[0], y_point) trained_loss = criterion(model_trained(x_point), y_point)
else:
init_loss = criterion(model_init(x_point)[0], y_point)
trained_loss = criterion(model_trained(x_point)[0], y_point)
scores.append(((init_loss - trained_loss).item(), is_in)) scores.append(((init_loss - trained_loss).item(), is_in))
@ -290,7 +371,7 @@ def main():
print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
print(f"p[ε < {eps_lb}] < {hp['p_value']}") print(f"p[ε < {eps_lb}] < {hp['p_value']}")
correct, total = evaluate_on(model_init, train_dl) correct, total = evaluate_on(model_init, train_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")
correct, total = evaluate_on(model_trained, test_dl) correct, total = evaluate_on(model_trained, test_dl)

View file

@ -0,0 +1,29 @@
import torch
import torch.nn as nn
# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
def __init__(self, num_classes=10):
super(ModifiedLightNNCosine, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
return x
Model = ModifiedLightNNCosine