mia_on_model_distillation/wresnet-pytorch
2024-12-01 15:33:03 -07:00
..
src changed to same data loaders as train.py and added saving student model 2024-12-01 15:33:03 -07:00
README.md Torchlira: add wresnet-16 2024-11-30 13:35:38 -07:00

Wide Residual Networks in PyTorch

Implementation of Wide Residual Networks (WRNs) in PyTorch.

How to train WRNs

At the moment the CIFAR10 and SVHN datasets are fully supported, with specific augmentations for CIFAR10 drawn from related literature and mean/std normalization for SVHN, and multistep learning rate scheduling for both cases. Training is executed through JSON configuration files, which you can modify or extend to support other configurations of WRNs and/or extend datasets etc.

Example Runs

Train a WideResNet-16-1 on CIFAR10:

python train.py --config configs/WRN-16-1-scratch-CIFAR10.json

Train a WideResNet-40-2 on SVHN:

python train.py --config configs/WRN-40-2-scratch-SVHN.json

Results

This work has been tested with 4 variants of WRNs. When setting the seed generator equal to 0, you should expect a test-set accuracy performance close to the following values:

Model CIFAR10 SVHN
WRN-16-1 90.97% 95.52%
WRN-16-2 94.21% 96.17%
WRN-40-1 93.52% 96.07%
WRN-40-2 95.14% 96.14%

Notes

The motivation for originally implementing WRNs in PyTorch was this NeurIPS reproducibility project, where WRNs were used as the main framework for few-shot and zero-shot knowledge transfer.