student stuff
This commit is contained in:
parent
c6f91352d9
commit
1200907c31
2 changed files with 119 additions and 9 deletions
|
@ -18,12 +18,61 @@ from opacus.validators import ModuleValidator
|
|||
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
||||
from WideResNet import WideResNet
|
||||
from equations import get_eps_audit
|
||||
|
||||
import student_model
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
DEVICE = None
|
||||
STUDENTBOOL = False
|
||||
|
||||
def train_knowledge_distillation(teacher, train_dl, epochs, device, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
|
||||
#instantiate istudent
|
||||
student = student_model.Model(num_classes=10).to(device)
|
||||
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||
student_init = copy.deepcopy(student)
|
||||
student.to(device)
|
||||
teacher.to(device)
|
||||
teacher.eval() # Teacher set to evaluation mode
|
||||
student.train() # Student to train mode
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for inputs, labels in train_dl:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
|
||||
with torch.no_grad():
|
||||
teacher_logits, _, _, _ = teacher(inputs)
|
||||
|
||||
# Forward pass with the student model
|
||||
student_logits = student(inputs)
|
||||
#Soften the student logits by applying softmax first and log() second
|
||||
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
|
||||
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
|
||||
|
||||
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
|
||||
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
|
||||
|
||||
# Calculate the true label loss
|
||||
label_loss = ce_loss(student_logits, labels)
|
||||
|
||||
# Weighted sum of the two losses
|
||||
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
if epoch % 10 == 0:
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")
|
||||
|
||||
return student_init, student
|
||||
|
||||
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||
seed = np.random.randint(0, 1e9)
|
||||
|
@ -90,7 +139,10 @@ def evaluate_on(model, dataloader):
|
|||
labels = labels.to(DEVICE)
|
||||
|
||||
wrn_outputs = model(images)
|
||||
outputs = wrn_outputs[0]
|
||||
if STUDENTBOOL:
|
||||
outputs = wrn_outputs
|
||||
else:
|
||||
outputs = wrn_outputs[0]
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
@ -209,6 +261,8 @@ def main():
|
|||
parser.add_argument('--cuda', type=int, help='gpu index', required=False)
|
||||
parser.add_argument('--epsilon', type=float, help='dp epsilon', required=False, default=None)
|
||||
parser.add_argument('--m', type=int, help='number of target points', required=True)
|
||||
parser.add_argument('--auditmodel', type=str, help='type of model to audit', default="teacher")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available() and args.cuda:
|
||||
|
@ -227,8 +281,8 @@ def main():
|
|||
"norm": args.norm,
|
||||
"batch_size": 4096,
|
||||
"epochs": 100,
|
||||
"k+": 300,
|
||||
"k-": 300,
|
||||
"k+": 200,
|
||||
"k-": 200,
|
||||
"p_value": 0.05,
|
||||
}
|
||||
|
||||
|
@ -250,11 +304,35 @@ def main():
|
|||
print(f"Got x_m: {x_m.shape}")
|
||||
print(f"Got y_m: {y_m.shape}")
|
||||
|
||||
model_init, model_trained = train(hp, train_dl, test_dl)
|
||||
|
||||
# torch.save(model_init.state_dict(), "data/init_model.pt")
|
||||
# torch.save(model_trained.state_dict(), "data/trained_model.pt")
|
||||
|
||||
if args.auditmodel == "student":
|
||||
global STUDENTBOOL
|
||||
teacher_init, teacher_trained = train(hp, train_dl, test_dl)
|
||||
STUDENTBOOL = True
|
||||
# torch.save(model_init.state_dict(), "data/init_model.pt")
|
||||
# torch.save(model_trained.state_dict(), "data/trained_model.pt")
|
||||
|
||||
|
||||
#train student model
|
||||
print("Training Student Model")
|
||||
model_init, model_trained = train_knowledge_distillation(
|
||||
teacher=teacher_trained,
|
||||
train_dl=train_dl,
|
||||
epochs=100,
|
||||
device=DEVICE,
|
||||
learning_rate=0.001,
|
||||
T=2,
|
||||
soft_target_loss_weight=0.25,
|
||||
ce_loss_weight=0.75,
|
||||
)
|
||||
stcorrect, sttotal = evaluate_on(model_trained, test_dl)
|
||||
stacc = stcorrect/sttotal*100
|
||||
print(f"Student Accuracy: {stacc}%")
|
||||
else:
|
||||
model_init, model_trained = train(hp, train_dl, test_dl)
|
||||
|
||||
scores = list()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
with torch.no_grad():
|
||||
|
@ -266,9 +344,12 @@ def main():
|
|||
x_point = x_m[i].unsqueeze(0)
|
||||
y_point = y_m[i].unsqueeze(0)
|
||||
is_in = S_m[i]
|
||||
|
||||
init_loss = criterion(model_init(x_point)[0], y_point)
|
||||
trained_loss = criterion(model_trained(x_point)[0], y_point)
|
||||
if STUDENTBOOL:
|
||||
init_loss = criterion(model_init(x_point), y_point)
|
||||
trained_loss = criterion(model_trained(x_point), y_point)
|
||||
else:
|
||||
init_loss = criterion(model_init(x_point)[0], y_point)
|
||||
trained_loss = criterion(model_trained(x_point)[0], y_point)
|
||||
|
||||
scores.append(((init_loss - trained_loss).item(), is_in))
|
||||
|
||||
|
|
29
one_run_audit/student_model.py
Normal file
29
one_run_audit/student_model.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
|
||||
class ModifiedLightNNCosine(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(ModifiedLightNNCosine, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(1024, 256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
flattened_conv_output = torch.flatten(x, 1)
|
||||
x = self.classifier(flattened_conv_output)
|
||||
return x
|
||||
|
||||
Model = ModifiedLightNNCosine
|
Loading…
Reference in a new issue