forked from 626_privacy/tensorflow_privacy
Compare commits
No commits in common. "akemi/lira_2021_fix" and "master" have entirely different histories.
akemi/lira
...
master
5 changed files with 31 additions and 562 deletions
500
distil.py
500
distil.py
|
@ -1,500 +0,0 @@
|
|||
#based off of https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html#prerequisites
|
||||
import torchvision
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(device)
|
||||
|
||||
|
||||
transforms_cifar = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
# Loading the CIFAR-10 dataset:
|
||||
train_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
|
||||
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
|
||||
|
||||
from torch.utils.data import Subset
|
||||
num_images_to_keep = 2000
|
||||
train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))
|
||||
test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))
|
||||
#Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
|
||||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
|
||||
|
||||
|
||||
# Deeper neural network class to be used as teacher:
|
||||
class DeepNN(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(DeepNN, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(2048, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(512, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
# Lightweight neural network class to be used as student:
|
||||
class LightNN(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(LightNN, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(1024, 256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def train(model, train_loader, epochs, learning_rate, device):
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
model.train()
|
||||
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for inputs, labels in train_loader:
|
||||
# inputs: A collection of batch_size images
|
||||
# labels: A vector of dimensionality batch_size with integers denoting class of each image
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
|
||||
# outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
|
||||
# labels: The actual labels of the images. Vector of dimensionality batch_size
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
|
||||
|
||||
def test(model, test_loader, device):
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, labels in test_loader:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
outputs = model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
||||
return accuracy
|
||||
|
||||
torch.manual_seed(42)
|
||||
nn_deep = DeepNN(num_classes=10).to(device)
|
||||
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
|
||||
test_accuracy_deep = test(nn_deep, test_loader, device)
|
||||
|
||||
# Instantiate the lightweight network:
|
||||
torch.manual_seed(42)
|
||||
nn_light = LightNN(num_classes=10).to(device)
|
||||
|
||||
|
||||
torch.manual_seed(42)
|
||||
new_nn_light = LightNN(num_classes=10).to(device)
|
||||
|
||||
# Print the norm of the first layer of the initial lightweight model
|
||||
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
|
||||
# Print the norm of the first layer of the new lightweight model
|
||||
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
|
||||
|
||||
|
||||
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
|
||||
print(f"DeepNN parameters: {total_params_deep}")
|
||||
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
|
||||
print(f"LightNN parameters: {total_params_light}")
|
||||
|
||||
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
|
||||
test_accuracy_light_ce = test(nn_light, test_loader, device)
|
||||
|
||||
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
|
||||
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")
|
||||
|
||||
|
||||
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||
|
||||
teacher.eval() # Teacher set to evaluation mode
|
||||
student.train() # Student to train mode
|
||||
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for inputs, labels in train_loader:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
|
||||
with torch.no_grad():
|
||||
teacher_logits = teacher(inputs)
|
||||
|
||||
# Forward pass with the student model
|
||||
student_logits = student(inputs)
|
||||
|
||||
#Soften the student logits by applying softmax first and log() second
|
||||
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
|
||||
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
|
||||
|
||||
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
|
||||
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
|
||||
|
||||
# Calculate the true label loss
|
||||
label_loss = ce_loss(student_logits, labels)
|
||||
|
||||
# Weighted sum of the two losses
|
||||
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
|
||||
|
||||
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)
|
||||
|
||||
# Compare the student test accuracy with and without the teacher, after distillation
|
||||
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
|
||||
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
|
||||
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
|
||||
|
||||
|
||||
class ModifiedDeepNNCosine(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(ModifiedDeepNNCosine, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(2048, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(512, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
flattened_conv_output = torch.flatten(x, 1)
|
||||
x = self.classifier(flattened_conv_output)
|
||||
flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
|
||||
return x, flattened_conv_output_after_pooling
|
||||
|
||||
# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
|
||||
class ModifiedLightNNCosine(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(ModifiedLightNNCosine, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(1024, 256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
flattened_conv_output = torch.flatten(x, 1)
|
||||
x = self.classifier(flattened_conv_output)
|
||||
return x, flattened_conv_output
|
||||
|
||||
# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
|
||||
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
|
||||
modified_nn_deep.load_state_dict(nn_deep.state_dict())
|
||||
|
||||
# Once again ensure the norm of the first layer is the same for both networks
|
||||
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
|
||||
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())
|
||||
|
||||
# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
|
||||
torch.manual_seed(42)
|
||||
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
|
||||
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())
|
||||
|
||||
# Create a sample input tensor
|
||||
sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Filters: 3, Image size: 32x32
|
||||
|
||||
# Pass the input through the student
|
||||
logits, hidden_representation = modified_nn_light(sample_input)
|
||||
|
||||
# Print the shapes of the tensors
|
||||
print("Student logits shape:", logits.shape) # batch_size x total_classes
|
||||
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size
|
||||
|
||||
# Pass the input through the teacher
|
||||
logits, hidden_representation = modified_nn_deep(sample_input)
|
||||
|
||||
# Print the shapes of the tensors
|
||||
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
|
||||
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size
|
||||
|
||||
def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
cosine_loss = nn.CosineEmbeddingLoss()
|
||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||
|
||||
teacher.to(device)
|
||||
student.to(device)
|
||||
teacher.eval() # Teacher set to evaluation mode
|
||||
student.train() # Student to train mode
|
||||
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for inputs, labels in train_loader:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass with the teacher model and keep only the hidden representation
|
||||
with torch.no_grad():
|
||||
_, teacher_hidden_representation = teacher(inputs)
|
||||
|
||||
# Forward pass with the student model
|
||||
student_logits, student_hidden_representation = student(inputs)
|
||||
|
||||
# Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
|
||||
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))
|
||||
|
||||
# Calculate the true label loss
|
||||
label_loss = ce_loss(student_logits, labels)
|
||||
|
||||
# Weighted sum of the two losses
|
||||
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
|
||||
|
||||
def test_multiple_outputs(model, test_loader, device):
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, labels in test_loader:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
outputs, _ = model(inputs) # Disregard the second tensor of the tuple
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
print(f"Test Accuracy: {accuracy:.2f}%")
|
||||
return accuracy
|
||||
|
||||
# Train and test the lightweight network with cross entropy loss
|
||||
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
|
||||
|
||||
|
||||
# Pass the sample input only from the convolutional feature extractor
|
||||
convolutional_fe_output_student = nn_light.features(sample_input)
|
||||
convolutional_fe_output_teacher = nn_deep.features(sample_input)
|
||||
|
||||
# Print their shapes
|
||||
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
|
||||
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)
|
||||
|
||||
class ModifiedDeepNNRegressor(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(ModifiedDeepNNRegressor, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(2048, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(512, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
conv_feature_map = x
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x, conv_feature_map
|
||||
|
||||
class ModifiedLightNNRegressor(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(ModifiedLightNNRegressor, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
# Include an extra regressor (in our case linear)
|
||||
self.regressor = nn.Sequential(
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(1024, 256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
regressor_output = self.regressor(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x, regressor_output
|
||||
|
||||
|
||||
def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
mse_loss = nn.MSELoss()
|
||||
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
|
||||
|
||||
teacher.to(device)
|
||||
student.to(device)
|
||||
teacher.eval() # Teacher set to evaluation mode
|
||||
student.train() # Student to train mode
|
||||
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for inputs, labels in train_loader:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Again ignore teacher logits
|
||||
with torch.no_grad():
|
||||
_, teacher_feature_map = teacher(inputs)
|
||||
|
||||
# Forward pass with the student model
|
||||
student_logits, regressor_feature_map = student(inputs)
|
||||
|
||||
# Calculate the loss
|
||||
hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
|
||||
|
||||
# Calculate the true label loss
|
||||
label_loss = ce_loss(student_logits, labels)
|
||||
|
||||
# Weighted sum of the two losses
|
||||
loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
|
||||
|
||||
|
||||
# Initialize a ModifiedLightNNRegressor
|
||||
torch.manual_seed(42)
|
||||
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)
|
||||
|
||||
# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
|
||||
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
|
||||
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())
|
||||
|
||||
# Train and test once again
|
||||
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
|
||||
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
|
||||
|
||||
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
|
||||
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
|
||||
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
|
||||
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
|
||||
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")
|
||||
|
||||
#
|
||||
# For more information, see:
|
||||
#
|
||||
# - [Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a
|
||||
# neural network. In: Neural Information Processing System Deep
|
||||
# Learning Workshop (2015)](https://arxiv.org/abs/1503.02531)
|
||||
# - [Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C.,
|
||||
# Bengio, Y.: Fitnets: Hints for thin deep nets. In: Proceedings of
|
||||
# the International Conference on Learning
|
||||
# Representations (2015)](https://arxiv.org/abs/1412.6550)
|
||||
#
|
|
@ -8,19 +8,14 @@ by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and F
|
|||
|
||||
### INSTALLING
|
||||
|
||||
You will need to install fairly standard dependencies and python 3.11 minimum.
|
||||
You will need to install fairly standard dependencies
|
||||
|
||||
```
|
||||
pip install scipy scikit-learn numpy matplotlib tensorflow tensorflow_datasets
|
||||
`pip install scipy, sklearn, numpy, matplotlib`
|
||||
|
||||
# This needs to be separate
|
||||
pip install objax
|
||||
|
||||
RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
|
||||
JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'`
|
||||
pip uninstall -y jaxlib
|
||||
pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION
|
||||
```
|
||||
and also some machine learning framework to train models. We train our models
|
||||
with JAX + ObJAX so you will need to follow build instructions for that
|
||||
https://github.com/google/objax
|
||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||
|
||||
### RUNNING THE CODE
|
||||
|
||||
|
|
|
@ -11,41 +11,20 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
mkdir -p logs
|
||||
|
||||
SECONDS=0
|
||||
|
||||
echo '======== 1 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 2>&1 | tee logs/log_0
|
||||
echo '======== 2 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 2>&1 | tee logs/log_1
|
||||
echo '======== 3 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 2>&1 | tee logs/log_2
|
||||
echo '======== 4 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 2>&1 | tee logs/log_3
|
||||
echo '======== 5 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 2>&1 | tee logs/log_4
|
||||
echo '======== 6 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 2>&1 | tee logs/log_5
|
||||
echo '======== 7 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 2>&1 | tee logs/log_6
|
||||
echo '======== 8 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 2>&1 | tee logs/log_7
|
||||
echo '======== 9 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 2>&1 | tee logs/log_8
|
||||
echo '======== 10 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 2>&1 | tee logs/log_9
|
||||
echo '======== 11 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 2>&1 | tee logs/log_10
|
||||
echo '======== 12 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 2>&1 | tee logs/log_11
|
||||
echo '======== 13 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 2>&1 | tee logs/log_12
|
||||
echo '======== 14 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 2>&1 | tee logs/log_13
|
||||
echo '======== 15 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 2>&1 | tee logs/log_14
|
||||
echo '======== 16 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 2>&1 | tee logs/log_15
|
||||
|
||||
echo "COMPLETE: Took ${SECONDS} seconds"
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
||||
|
|
|
@ -11,25 +11,22 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
mkdir -p logs
|
||||
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
||||
|
|
|
@ -66,9 +66,7 @@ class TrainLoop(objax.Module):
|
|||
for k, v in kv.items():
|
||||
if jn.isnan(v):
|
||||
raise ValueError('NaN, try reducing learning rate', k)
|
||||
if summary is not None and v.ndim == 1:
|
||||
summary.scalar(k, float(v[0]))
|
||||
elif summary is not None:
|
||||
if summary is not None:
|
||||
summary.scalar(k, float(v))
|
||||
|
||||
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
|
||||
|
|
Loading…
Reference in a new issue