import os
import time
import random

import numpy as np
import torch
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn as nn
from tqdm import tqdm

import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager

from utils import json_file_to_pyobj, get_loaders
from WideResNet import WideResNet
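

# Make runs reproducible: seed every RNG the training path touches;
# cudnn.deterministic trades some speed for deterministic convolutions.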
def set_seed(seed=42):
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
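

# Core training loop, shared by the DP and non-DP paths: the Opacus-wrapped
# model, optimizer, and loader expose the same interface as the plain ones.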
def train_no_cap(net, epochs, data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile):
    best_test_set_accuracy = 0

    for epoch in range(epochs):
        net.train()
        #for i, data in tqdm(enumerate(train_loader, 0), leave=False):
        for i, data in enumerate(data_loader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # The network returns a tuple; the first element holds the logits.
            wrn_outputs = net(inputs)
            outputs = wrn_outputs[0]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Step once per epoch: the MultiStepLR milestones are fractions of total epochs.
        scheduler.step()

        # Evaluate every 10 epochs and on the final epoch.
        if epoch % 10 == 0 or epoch == epochs - 1:
            with torch.no_grad():
                correct = 0
                total = 0
                net.eval()

                for data in test_loader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    wrn_outputs = net(images)
                    outputs = wrn_outputs[0]
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            epoch_accuracy = correct / total
            epoch_accuracy = round(100 * epoch_accuracy, 2)

            if log:
                print('Accuracy at epoch {} is {}%'.format(epoch + 1, epoch_accuracy))
                with open(logfile, 'a') as temp:
                    temp.write('Accuracy at epoch {} is {}%\n'.format(epoch + 1, epoch_accuracy))

            # Checkpoint whenever the test accuracy improves.
            if epoch_accuracy > best_test_set_accuracy:
                best_test_set_accuracy = epoch_accuracy
                torch.save(net.state_dict(), checkpointFile)

    return best_test_set_accuracy
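

# Train one seed end to end. dp_epsilon is hard-coded below (None = no DP);
# norm is the DP-SGD per-sample gradient clipping bound C, used only when
# DP is enabled.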
def _train_seed(net, loaders, device, dataset, log=False, logfile='', epochs=200, norm=1.0):
    train_loader, test_loader = loaders

    # dp_epsilon = None disables differential privacy entirely.
    dp_epsilon = None
    dp_delta = 1e-5

    checkpointFile = 'wrn-{}-{}e-{}d-{}n-dict.pt'.format(int(time.time()), dp_epsilon, dp_delta, norm)

    if dp_epsilon is not None:
        print(f"DP epsilon = {dp_epsilon}, delta = {dp_delta}")
        # Swap out modules Opacus cannot handle (e.g. BatchNorm) for DP-compatible ones.
        #net = ModuleValidator.fix(net, replace_bn_with_in=True)
        net = ModuleValidator.fix(net)
        ModuleValidator.validate(net, strict=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(elem * epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)

    if dp_epsilon is not None:
        # Let Opacus solve for the noise multiplier that meets the
        # (epsilon, delta) budget over the planned number of epochs.
        privacy_engine = opacus.PrivacyEngine()
        net, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=net,
            optimizer=optimizer,
            data_loader=train_loader,
            epochs=epochs,
            target_epsilon=dp_epsilon,
            target_delta=dp_delta,
            max_grad_norm=norm,
        )
        print(f"Using sigma={optimizer.noise_multiplier} and C={norm}")
    else:
        print("Training without differential privacy")

    print(f"Training with {epochs} epochs")

    if dp_epsilon is not None:
        # Opacus' sampling can produce logical batches too large for the GPU;
        # BatchMemoryManager transparently splits them into physical sub-batches.
        with BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=1000,  # Roughly 12GB VRAM, uses 9.4
            optimizer=optimizer
        ) as memory_safe_data_loader:
            best_test_set_accuracy = train_no_cap(net, epochs, memory_safe_data_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile)
    else:
        best_test_set_accuracy = train_no_cap(net, epochs, train_loader, device, optimizer, criterion, scheduler, test_loader, log, logfile, checkpointFile)

    return best_test_set_accuracy
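

# Top-level driver: parse the JSON config, set up logging and the device,
# then train and evaluate each seed.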
def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()

    # Draw one fresh random seed per run instead of the fixed list from the config.
    #seeds = [int(seed) for seed in training_configurations.seeds]
    seeds = [int.from_bytes(os.urandom(4), byteorder='big')]

    log = training_configurations.log.lower() == 'true'

    if log:
        logfile = 'WideResNet-{}-{}-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset, training_configurations.batch_size, training_configurations.epochs)
        with open(logfile, 'w') as temp:
            temp.write('WideResNet-{}-{} on {} {} batch for {} epochs\n'.format(wrn_depth, wrn_width, training_configurations.dataset, training_configurations.batch_size, training_configurations.epochs))
    else:
        logfile = ''

    checkpoint = training_configurations.checkpoint.lower() == 'true'  # read but currently unused

    loaders = get_loaders(dataset, training_configurations.batch_size)

    # Compare against None explicitly so GPU index 0 is not treated as falsy.
    if torch.cuda.is_available() and args.cuda is not None:
        device = torch.device(f'cuda:{args.cuda}')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
        net = net.to(device)

        epochs = training_configurations.epochs
        best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, logfile, epochs, args.norm)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))


if __name__ == '__main__':
    import argparse

    # Pin device enumeration order and visibility before any CUDA context is created.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    parser = argparse.ArgumentParser(description='WideResNet')
    parser.add_argument('-config', '--config', help='Training Configurations', required=True)
    parser.add_argument('--norm', type=float, help='dpsgd norm clip factor', required=True)
    parser.add_argument('--cuda', type=int, help='gpu index', required=False)
    args = parser.parse_args()

    train(args)
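
# Example invocation (script name assumed; the JSON config needs a "training"
# object with wrn_depth, wrn_width, dataset, batch_size, epochs, log, and
# checkpoint fields):
#   python train.py --config config.json --norm 1.0 --cuda 0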