PyTorch version of LiRA
parent 3467c25882
commit e95444fa74
12 changed files with 912 additions and 0 deletions
7 lira-pytorch/.gitignore (vendored) Normal file
@@ -0,0 +1,7 @@
__pycache__
exp
logs
slurm
gpu.sh
*.out
202 lira-pytorch/LICENSE Normal file
@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
42 lira-pytorch/README.md Normal file
@@ -0,0 +1,42 @@
# Likelihood Ratio Attack (LiRA) in PyTorch

Implementation of the original [LiRA](https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021) using PyTorch. To run the code, first create an environment from the `env.yml` file (e.g., `conda env create -n lira --file env.yml`). Then run the following command to train the shadow models and run the LiRA attack:

```
./run.sh
```

The run generates and stores a log-scale FPR-TPR curve as `./fprtpr.png`, with the TPR@0.1%FPR reported in the output log.
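
The TPR@0.1%FPR number in the log is simply the ROC curve evaluated at the largest threshold whose FPR stays below 0.1%. A minimal sketch of how it is read off, mirroring the computation in `plot.py` (the `prediction` scores and boolean `answers` membership labels are placeholders here):

```
import numpy as np
from sklearn.metrics import roc_curve

# prediction: per-example attack scores; answers: ground-truth membership bits
# (placeholder arrays for illustration)
fpr, tpr, _ = roc_curve(np.array(answers, dtype=bool), -np.array(prediction))
tpr_at_low_fpr = tpr[np.where(fpr < 0.001)[0][-1]]  # TPR at FPR < 0.1%
print(f"TPR@0.1%FPR = {tpr_at_low_fpr:.4f}")
```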

## Results on CIFAR10

Using 16 shadow models trained with `ResNet18 and 2 augmented queries`:

![roc](figures/fprtpr_resnet18.png)
```
Attack Ours (online)
AUC 0.6548, Accuracy 0.6015, TPR@0.1%FPR of 0.0068
Attack Ours (online, fixed variance)
AUC 0.6700, Accuracy 0.6042, TPR@0.1%FPR of 0.0464
Attack Ours (offline)
AUC 0.5250, Accuracy 0.5353, TPR@0.1%FPR of 0.0041
Attack Ours (offline, fixed variance)
AUC 0.5270, Accuracy 0.5380, TPR@0.1%FPR of 0.0192
Attack Global threshold
AUC 0.5948, Accuracy 0.5869, TPR@0.1%FPR of 0.0006
```

Using 16 shadow models trained with `WideResNet28-10 and 2 augmented queries`:

![roc](figures/fprtpr_wideresnet.png)
```
Attack Ours (online)
AUC 0.6834, Accuracy 0.6152, TPR@0.1%FPR of 0.0240
Attack Ours (online, fixed variance)
AUC 0.7017, Accuracy 0.6240, TPR@0.1%FPR of 0.0704
Attack Ours (offline)
AUC 0.5621, Accuracy 0.5649, TPR@0.1%FPR of 0.0140
Attack Ours (offline, fixed variance)
AUC 0.5698, Accuracy 0.5628, TPR@0.1%FPR of 0.0370
Attack Global threshold
AUC 0.6016, Accuracy 0.5977, TPR@0.1%FPR of 0.0013
```
35 lira-pytorch/env.yml Normal file
@@ -0,0 +1,35 @@
# Minimal environment for starting a project using conda/mamba:
# conda env create -n ENVNAME --file ENV.yml

name: template
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.8.6
  - pip
  - pytest
  - numpy
  - scipy
  - scikit-learn
  - matplotlib
  - pandas
  - tqdm
  - wandb
  - jupyterlab
  - jupyter
  - ipykernel
  - pytorch
  - torchvision
  - torchaudio
  - pytorch-cuda=12.1
  - pytorch-lightning
  - lightning-bolts
  - torchmetrics

# Install packages with pip
# - pip:
#   - ray[tune]
BIN lira-pytorch/figures/fprtpr_resnet18.png Normal file (binary file not shown; 38 KiB)
BIN lira-pytorch/figures/fprtpr_wideresnet.png Normal file (binary file not shown; 37 KiB)
75 lira-pytorch/inference.py Normal file
@@ -0,0 +1,75 @@
# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from wide_resnet import WideResNet

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


@torch.no_grad()
def run():
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

    # Dataset: train-time augmentations are kept on, so each query below
    # draws a fresh augmented view of every example.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple augmented queries per shadow model
    for path in os.listdir(args.savedir):
        if args.model == "wresnet28-2":
            m = WideResNet(28, 2, 0.0, 10)
        elif args.model == "wresnet28-10":
            m = WideResNet(28, 10, 0.3, 10)
        elif args.model == "resnet18":
            m = models.resnet18(weights=None, num_classes=10)
            m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            m.maxpool = nn.Identity()
        else:
            raise NotImplementedError
        m.load_state_dict(torch.load(os.path.join(args.savedir, path, "model.pt")))
        m.to(DEVICE)
        m.eval()

        logits_n = []
        for i in range(args.n_queries):
            logits = []
            for x, _ in tqdm(train_dl):
                x = x.to(DEVICE)
                outputs = m(x)
                logits.append(outputs.cpu().numpy())
            logits_n.append(np.concatenate(logits))
        logits_n = np.stack(logits_n, axis=1)  # [n_examples, n_queries, n_classes]
        print(logits_n.shape)

        np.save(os.path.join(args.savedir, path, "logits.npy"), logits_n)


if __name__ == "__main__":
    run()
205 lira-pytorch/plot.py Normal file
@@ -0,0 +1,205 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified copy by Chenxiang Zhang (orientino) of the original:
# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021


import argparse
import functools
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from sklearn.metrics import auc, roc_curve

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


def sweep(score, x):
    """
    Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
    """
    fpr, tpr, _ = roc_curve(x, -score)
    acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    return fpr, tpr, auc(fpr, tpr), acc


def load_data():
    """
    Load our saved scores and then put them into a big matrix.
    """
    global scores, keep
    scores = []
    keep = []

    for path in os.listdir(args.savedir):
        scores.append(np.load(os.path.join(args.savedir, path, "scores.npy")))
        keep.append(np.load(os.path.join(args.savedir, path, "keep.npy")))
    scores = np.array(scores)
    keep = np.array(keep)

    return scores, keep


def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, fix_variance=False):
    """
    Fit two predictive models using keep and scores in order to predict
    if the examples in check_scores were training data or not, using the
    ground truth answer from check_keep.
    """
    dat_in = []
    dat_out = []

    for j in range(scores.shape[1]):
        dat_in.append(scores[keep[:, j], j, :])
        dat_out.append(scores[~keep[:, j], j, :])

    in_size = min(min(map(len, dat_in)), in_size)
    out_size = min(min(map(len, dat_out)), out_size)

    dat_in = np.array([x[:in_size] for x in dat_in])
    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_in = np.median(dat_in, 1)
    mean_out = np.median(dat_out, 1)

    if fix_variance:
        # Fixed variance: a single global std, estimated from the IN scores,
        # is shared by both Gaussians (as in the original implementation).
        std_in = np.std(dat_in)
        std_out = np.std(dat_in)
    else:
        std_in = np.std(dat_in, 1)
        std_out = np.std(dat_out, 1)

    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in + 1e-30)
        pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
        score = pr_in - pr_out

        prediction.extend(score.mean(1))
        answers.extend(ans)

    return prediction, answers


def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, fix_variance=False):
    """
    Fit a single predictive model using keep and scores in order to predict
    if the examples in check_scores were training data or not, using the
    ground truth answer from check_keep.
    """
    dat_in = []
    dat_out = []

    for j in range(scores.shape[1]):
        dat_in.append(scores[keep[:, j], j, :])
        dat_out.append(scores[~keep[:, j], j, :])

    out_size = min(min(map(len, dat_out)), out_size)

    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_out = np.median(dat_out, 1)

    if fix_variance:
        std_out = np.std(dat_out)
    else:
        std_out = np.std(dat_out, 1)

    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        score = scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)

        prediction.extend(score.mean(1))
        answers.extend(ans)
    return prediction, answers


def generate_global(keep, scores, check_keep, check_scores):
    """
    Use a simple global threshold sweep to predict if the examples in
    check_scores were training data or not, using the ground truth answer from
    check_keep.
    """
    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        prediction.extend(-sc.mean(1))
        answers.extend(ans)

    return prediction, answers


def do_plot(fn, keep, scores, ntest, legend="", metric="auc", sweep_fn=sweep, **plot_kwargs):
    """
    Generate the ROC curves by using ntest models as test models and the rest to train.
    """
    prediction, answers = fn(keep[:-ntest], scores[:-ntest], keep[-ntest:], scores[-ntest:])

    fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))

    low = tpr[np.where(fpr < 0.001)[0][-1]]

    print("Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f" % (legend, auc, acc, low))

    metric_text = ""
    if metric == "auc":
        metric_text = "auc=%.3f" % auc
    elif metric == "acc":
        metric_text = "acc=%.3f" % acc

    plt.plot(fpr, tpr, label=legend + metric_text, **plot_kwargs)
    return (acc, auc)


def fig_fpr_tpr():
    plt.figure(figsize=(4, 3))

    do_plot(generate_ours, keep, scores, 1, "Ours (online)\n", metric="auc")
    do_plot(functools.partial(generate_ours, fix_variance=True), keep, scores, 1, "Ours (online, fixed variance)\n", metric="auc")
    do_plot(functools.partial(generate_ours_offline), keep, scores, 1, "Ours (offline)\n", metric="auc")
    do_plot(functools.partial(generate_ours_offline, fix_variance=True), keep, scores, 1, "Ours (offline, fixed variance)\n", metric="auc")
    do_plot(generate_global, keep, scores, 1, "Global threshold\n", metric="auc")

    plt.semilogx()
    plt.semilogy()
    plt.xlim(1e-5, 1)
    plt.ylim(1e-5, 1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.plot([0, 1], [0, 1], ls="--", color="gray")
    plt.subplots_adjust(bottom=0.18, left=0.18, top=0.96, right=0.96)
    plt.legend(fontsize=8)
    plt.savefig("fprtpr.png")
    plt.show()


if __name__ == "__main__":
    load_data()
    fig_fpr_tpr()
21 lira-pytorch/run.sh Executable file
@@ -0,0 +1,21 @@
python3 train.py --epochs 100 --shadow_id 0 --debug
python3 train.py --epochs 100 --shadow_id 1 --debug
python3 train.py --epochs 100 --shadow_id 2 --debug
python3 train.py --epochs 100 --shadow_id 3 --debug
python3 train.py --epochs 100 --shadow_id 4 --debug
python3 train.py --epochs 100 --shadow_id 5 --debug
python3 train.py --epochs 100 --shadow_id 6 --debug
python3 train.py --epochs 100 --shadow_id 7 --debug
python3 train.py --epochs 100 --shadow_id 8 --debug
python3 train.py --epochs 100 --shadow_id 9 --debug
python3 train.py --epochs 100 --shadow_id 10 --debug
python3 train.py --epochs 100 --shadow_id 11 --debug
python3 train.py --epochs 100 --shadow_id 12 --debug
python3 train.py --epochs 100 --shadow_id 13 --debug
python3 train.py --epochs 100 --shadow_id 14 --debug
python3 train.py --epochs 100 --shadow_id 15 --debug

python3 inference.py --savedir exp/cifar10
python3 score.py --savedir exp/cifar10
python3 plot.py --savedir exp/cifar10
70 lira-pytorch/score.py Normal file
@@ -0,0 +1,70 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified copy by Chenxiang Zhang (orientino) of the original:
# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021


import argparse
import multiprocessing as mp
import os
from pathlib import Path

import numpy as np
from torchvision.datasets import CIFAR10

parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


def load_one(path):
    """
    Load the saved logits for one shadow model and convert them to scored predictions.
    """
    opredictions = np.load(os.path.join(path, "logits.npy"))  # [n_examples, n_augs, n_classes]

    # Be exceptionally careful.
    # Numerically stable everything, as described in the paper.
    predictions = opredictions - np.max(opredictions, axis=-1, keepdims=True)
    predictions = np.array(np.exp(predictions), dtype=np.float64)
    predictions = predictions / np.sum(predictions, axis=-1, keepdims=True)

    labels = get_labels()  # TODO generalize this

    COUNT = predictions.shape[0]
    y_true = predictions[np.arange(COUNT), :, labels[:COUNT]]

    print("mean acc", np.mean(predictions[:, 0, :].argmax(1) == labels[:COUNT]))

    predictions[np.arange(COUNT), :, labels[:COUNT]] = 0
    y_wrong = np.sum(predictions, axis=-1)

    # Stable logit scaling of the confidence: log(p_true / (1 - p_true))
    logit = np.log(y_true + 1e-45) - np.log(y_wrong + 1e-45)
    np.save(os.path.join(path, "scores.npy"), logit)


def get_labels():
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True)
    return np.array(train_ds.targets)


def load_stats():
    with mp.Pool(8) as p:
        p.map(load_one, [os.path.join(args.savedir, x) for x in os.listdir(args.savedir)])


if __name__ == "__main__":
    load_stats()
180 lira-pytorch/train.py Normal file
@@ -0,0 +1,180 @@
# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/train.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
import time
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from opacus.validators import ModuleValidator
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

from wide_resnet import WideResNet

parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--n_shadows", default=16, type=int)
parser.add_argument("--shadow_id", default=1, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--pkeep", default=0.5, type=float)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")
EPOCHS = args.epochs


def run():
    seed = np.random.randint(0, 1000000000)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    wandb.init(project="lira", mode="disabled" if args.debug else "online")
    wandb.config.update(args)

    # Dataset
    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    # Compute the IN / OUT subset:
    # If we run each experiment independently then even after a lot of trials
    # there will still probably be some examples that were always included
    # or always excluded. So instead, with experiment IDs, we guarantee that
    # after `args.n_shadows` are done, each example is seen exactly half
    # of the time in train, and half of the time not in train.
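    # For illustration (not in the source): with n_shadows=16 and pkeep=0.5,
    # each column of `order` below is an independent permutation of 0..15, so
    # `order < 8` keeps any given example in exactly 8 of the 16 shadow runs.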

    size = len(train_ds)
    np.random.seed(seed)
    if args.n_shadows is not None:
        np.random.seed(0)  # fixed seed so every shadow run derives the same keep matrix
        keep = np.random.uniform(0, 1, size=(args.n_shadows, size))
        order = keep.argsort(0)
        keep = order < int(args.pkeep * args.n_shadows)
        keep = np.array(keep[args.shadow_id], dtype=bool)
        keep = keep.nonzero()[0]
    else:
        keep = np.random.choice(size, size=int(args.pkeep * size), replace=False)
        keep.sort()
    keep_bool = np.full(size, False)
    keep_bool[keep] = True

    train_ds = torch.utils.data.Subset(train_ds, keep)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    # Model
    if args.model == "wresnet28-2":
        m = WideResNet(28, 2, 0.0, 10)
    elif args.model == "wresnet28-10":
        m = WideResNet(28, 10, 0.3, 10)
    elif args.model == "resnet18":
        m = models.resnet18(weights=None, num_classes=10)
        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        m.maxpool = nn.Identity()
    else:
        raise NotImplementedError

    # Replace modules incompatible with DP-SGD (e.g. BatchNorm), then move to device
    m = ModuleValidator.fix(m)
    ModuleValidator.validate(m, strict=True)
    m = m.to(DEVICE)

    optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

    privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=True)
    m, optim, train_dl = privacy_engine.make_private_with_epsilon(
        module=m,
        optimizer=optim,
        data_loader=train_dl,
        epochs=args.epochs,
        target_epsilon=1,
        target_delta=1e-4,
        max_grad_norm=1.0,
    )

    print(f"Device: {DEVICE}")

    # Train
    with BatchMemoryManager(
        data_loader=train_dl,
        max_physical_batch_size=1000,
        optimizer=optim,
    ) as memory_safe_data_loader:
        for i in tqdm(range(args.epochs)):
            m.train()
            loss_total = 0
            pbar = tqdm(memory_safe_data_loader, leave=False)
            for itr, (x, y) in enumerate(pbar):
                x, y = x.to(DEVICE), y.to(DEVICE)

                loss = F.cross_entropy(m(x), y)
                loss_total += loss.item()

                pbar.set_postfix_str(f"loss: {loss:.2f}")
                optim.zero_grad()
                loss.backward()
                optim.step()
            sched.step()

            wandb.log({"loss": loss_total / len(train_dl)})

        acc_test = get_acc(m, test_dl)
        print(f"[test] acc_test: {acc_test:.4f}")
        wandb.log({"acc_test": acc_test})

    savedir = os.path.join(args.savedir, str(args.shadow_id))
    os.makedirs(savedir, exist_ok=True)
    np.save(savedir + "/keep.npy", keep_bool)
    torch.save(m.state_dict(), savedir + "/model.pt")


@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()


if __name__ == "__main__":
    run()
75 lira-pytorch/wide_resnet.py Normal file
@@ -0,0 +1,75 @@
import torch.nn as nn
import torch.nn.functional as F


class wide_basic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.dropout = nn.Dropout(p=dropout_rate)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
            )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.dropout(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, n_classes):
        super(WideResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, "Wide-ResNet depth should be 6n+4"
        n = (depth - 4) // 6
        k = widen_factor
        stages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = nn.Conv2d(3, stages[0], kernel_size=3, stride=1, padding=1)
        self.layer1 = self._wide_layer(wide_basic, stages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, stages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, stages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(stages[3], momentum=0.9)
        self.linear = nn.Linear(stages[3], n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)

    def _wide_layer(self, block, planes, n_blocks, dropout_rate, stride):
        strides = [stride] + [1] * (int(n_blocks) - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)  # 32x32 input -> 8x8 feature map -> global pool
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out