PyTorch version of LiRA

parent 3467c25882
commit e95444fa74

12 changed files with 912 additions and 0 deletions
lira-pytorch/.gitignore (new file, vendored, 7 lines)

@@ -0,0 +1,7 @@
__pycache__
exp
logs
slurm
gpu.sh
*.out
lira-pytorch/LICENSE (new file, 202 lines)

@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
lira-pytorch/README.md (new file, 42 lines)

@@ -0,0 +1,42 @@
# Likelihood Ratio Attack (LiRA) in PyTorch

Implementation of the original [LiRA](https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021) using PyTorch. To run the code, first create an environment from the `env.yml` file (a sketch of this step follows the command below). Then run the following command to train the shadow models and run the LiRA attack:

```
./run.sh
```
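
For the environment step, a minimal sketch using conda, following the comment at the top of `env.yml`; the environment name `lira` is illustrative, not fixed by the repository:

```
# Create the conda environment from env.yml and activate it.
# The name "lira" is an assumed placeholder; any name works.
conda env create -n lira --file env.yml
conda activate lira
```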

The run generates a log-scale FPR-TPR curve, saved as `./fprtpr.png`, and reports the TPR@0.1%FPR in the output log.

## Results on CIFAR10

Using 16 shadow models trained with `ResNet18 and 2 augmented queries`:

![roc](figures/fprtpr_resnet18.png)
```
Attack Ours (online)
AUC 0.6548, Accuracy 0.6015, TPR@0.1%FPR of 0.0068
Attack Ours (online, fixed variance)
AUC 0.6700, Accuracy 0.6042, TPR@0.1%FPR of 0.0464
Attack Ours (offline)
AUC 0.5250, Accuracy 0.5353, TPR@0.1%FPR of 0.0041
Attack Ours (offline, fixed variance)
AUC 0.5270, Accuracy 0.5380, TPR@0.1%FPR of 0.0192
Attack Global threshold
AUC 0.5948, Accuracy 0.5869, TPR@0.1%FPR of 0.0006
```

Using 16 shadow models trained with `WideResNet28-10 and 2 augmented queries`:

![roc](figures/fprtpr_wideresnet.png)
```
Attack Ours (online)
AUC 0.6834, Accuracy 0.6152, TPR@0.1%FPR of 0.0240
Attack Ours (online, fixed variance)
AUC 0.7017, Accuracy 0.6240, TPR@0.1%FPR of 0.0704
Attack Ours (offline)
AUC 0.5621, Accuracy 0.5649, TPR@0.1%FPR of 0.0140
Attack Ours (offline, fixed variance)
AUC 0.5698, Accuracy 0.5628, TPR@0.1%FPR of 0.0370
Attack Global threshold
AUC 0.6016, Accuracy 0.5977, TPR@0.1%FPR of 0.0013
```
lira-pytorch/env.yml (new file, 35 lines)

@@ -0,0 +1,35 @@
# Minimal environment for starting a project using conda/mamba:
# conda env create -n ENVNAME --file ENV.yml

name: template
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.8.6
  - pip
  - pytest
  - numpy
  - scipy
  - scikit-learn
  - matplotlib
  - pandas
  - tqdm
  - wandb
  - jupyterlab
  - jupyter
  - ipykernel
  - pytorch
  - torchvision
  - torchaudio
  - pytorch-cuda=12.1
  - tqdm
  - pytorch-lightning
  - lightning-bolts
  - torchmetrics

# Install packages with pip
# - pip:
#   - ray[tune]
lira-pytorch/figures/fprtpr_resnet18.png (new binary file, 38 KiB; not shown)

lira-pytorch/figures/fprtpr_wideresnet.png (new binary file, 37 KiB; not shown)
lira-pytorch/inference.py (new file, 75 lines)

@@ -0,0 +1,75 @@
# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from wide_resnet import WideResNet

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


@torch.no_grad()
def run():
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    elif torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    else:
        DEVICE = torch.device("cpu")

    # Dataset: the random augmentations make each query a different view.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple augmented queries per shadow model
    for path in os.listdir(args.savedir):
        if args.model == "wresnet28-2":
            m = WideResNet(28, 2, 0.0, 10)
        elif args.model == "wresnet28-10":
            m = WideResNet(28, 10, 0.3, 10)
        elif args.model == "resnet18":
            m = models.resnet18(weights=None, num_classes=10)
            m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            m.maxpool = nn.Identity()
        else:
            raise NotImplementedError
        m.load_state_dict(torch.load(os.path.join(args.savedir, path, "model.pt")))
        m.to(DEVICE)
        m.eval()

        logits_n = []
        for _ in range(args.n_queries):
            logits = []
            for x, _ in tqdm(train_dl):
                x = x.to(DEVICE)
                outputs = m(x)
                logits.append(outputs.cpu().numpy())
            logits_n.append(np.concatenate(logits))
        logits_n = np.stack(logits_n, axis=1)  # [n_examples, n_queries, n_classes]
        print(logits_n.shape)

        np.save(os.path.join(args.savedir, path, "logits.npy"), logits_n)


if __name__ == "__main__":
    run()
lira-pytorch/plot.py (new file, 205 lines)

@@ -0,0 +1,205 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified copy by Chenxiang Zhang (orientino) of the original:
# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021


import argparse
import functools
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from sklearn.metrics import auc, roc_curve

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


def sweep(score, x):
    """
    Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
    """
    fpr, tpr, _ = roc_curve(x, -score)
    acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    return fpr, tpr, auc(fpr, tpr), acc


def load_data():
    """
    Load our saved scores and then put them into a big matrix.
    """
    global scores, keep
    scores = []
    keep = []

    for path in os.listdir(args.savedir):
        scores.append(np.load(os.path.join(args.savedir, path, "scores.npy")))
        keep.append(np.load(os.path.join(args.savedir, path, "keep.npy")))
    scores = np.array(scores)
    keep = np.array(keep)

    return scores, keep


def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, fix_variance=False):
    """
    Fit two predictive models using keep and scores in order to predict
    if the examples in check_scores were training data or not, using the
    ground truth answer from check_keep.
    """
    dat_in = []
    dat_out = []

    for j in range(scores.shape[1]):
        dat_in.append(scores[keep[:, j], j, :])
        dat_out.append(scores[~keep[:, j], j, :])

    in_size = min(min(map(len, dat_in)), in_size)
    out_size = min(min(map(len, dat_out)), out_size)

    dat_in = np.array([x[:in_size] for x in dat_in])
    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_in = np.median(dat_in, 1)
    mean_out = np.median(dat_out, 1)

    if fix_variance:
        # Use a single global standard deviation (estimated from the IN
        # samples) for both distributions.
        std_in = np.std(dat_in)
        std_out = np.std(dat_in)
    else:
        std_in = np.std(dat_in, 1)
        std_out = np.std(dat_out, 1)

    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in + 1e-30)
        pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
        score = pr_in - pr_out

        prediction.extend(score.mean(1))
        answers.extend(ans)

    return prediction, answers


def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, fix_variance=False):
    """
    Fit a single predictive model using keep and scores in order to predict
    if the examples in check_scores were training data or not, using the
    ground truth answer from check_keep.
    """
    dat_in = []
    dat_out = []

    for j in range(scores.shape[1]):
        dat_in.append(scores[keep[:, j], j, :])
        dat_out.append(scores[~keep[:, j], j, :])

    out_size = min(min(map(len, dat_out)), out_size)

    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_out = np.median(dat_out, 1)

    if fix_variance:
        std_out = np.std(dat_out)
    else:
        std_out = np.std(dat_out, 1)

    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        score = scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)

        prediction.extend(score.mean(1))
        answers.extend(ans)
    return prediction, answers


def generate_global(keep, scores, check_keep, check_scores):
    """
    Use a simple global threshold sweep to predict if the examples in
    check_scores were training data or not, using the ground truth answer from
    check_keep.
    """
    prediction = []
    answers = []
    for ans, sc in zip(check_keep, check_scores):
        prediction.extend(-sc.mean(1))
        answers.extend(ans)

    return prediction, answers


def do_plot(fn, keep, scores, ntest, legend="", metric="auc", sweep_fn=sweep, **plot_kwargs):
    """
    Generate the ROC curves by using ntest models as test models and the rest to train.
    """
    prediction, answers = fn(keep[:-ntest], scores[:-ntest], keep[-ntest:], scores[-ntest:])

    fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))

    low = tpr[np.where(fpr < 0.001)[0][-1]]

    print("Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f" % (legend, auc, acc, low))

    metric_text = ""
    if metric == "auc":
        metric_text = "auc=%.3f" % auc
    elif metric == "acc":
        metric_text = "acc=%.3f" % acc

    plt.plot(fpr, tpr, label=legend + metric_text, **plot_kwargs)
    return (acc, auc)


def fig_fpr_tpr():
    plt.figure(figsize=(4, 3))

    do_plot(generate_ours, keep, scores, 1, "Ours (online)\n", metric="auc")
    do_plot(functools.partial(generate_ours, fix_variance=True), keep, scores, 1, "Ours (online, fixed variance)\n", metric="auc")
    do_plot(functools.partial(generate_ours_offline), keep, scores, 1, "Ours (offline)\n", metric="auc")
    do_plot(functools.partial(generate_ours_offline, fix_variance=True), keep, scores, 1, "Ours (offline, fixed variance)\n", metric="auc")
    do_plot(generate_global, keep, scores, 1, "Global threshold\n", metric="auc")

    plt.semilogx()
    plt.semilogy()
    plt.xlim(1e-5, 1)
    plt.ylim(1e-5, 1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.plot([0, 1], [0, 1], ls="--", color="gray")
    plt.subplots_adjust(bottom=0.18, left=0.18, top=0.96, right=0.96)
    plt.legend(fontsize=8)
    plt.savefig("fprtpr.png")
    plt.show()


if __name__ == "__main__":
    load_data()
    fig_fpr_tpr()
lira-pytorch/run.sh (new executable file, 21 lines)

@@ -0,0 +1,21 @@
python3 train.py --epochs 100 --shadow_id 0 --debug
python3 train.py --epochs 100 --shadow_id 1 --debug
python3 train.py --epochs 100 --shadow_id 2 --debug
python3 train.py --epochs 100 --shadow_id 3 --debug
python3 train.py --epochs 100 --shadow_id 4 --debug
python3 train.py --epochs 100 --shadow_id 5 --debug
python3 train.py --epochs 100 --shadow_id 6 --debug
python3 train.py --epochs 100 --shadow_id 7 --debug
python3 train.py --epochs 100 --shadow_id 8 --debug
python3 train.py --epochs 100 --shadow_id 9 --debug
python3 train.py --epochs 100 --shadow_id 10 --debug
python3 train.py --epochs 100 --shadow_id 11 --debug
python3 train.py --epochs 100 --shadow_id 12 --debug
python3 train.py --epochs 100 --shadow_id 13 --debug
python3 train.py --epochs 100 --shadow_id 14 --debug
python3 train.py --epochs 100 --shadow_id 15 --debug

python3 inference.py --savedir exp/cifar10
python3 score.py --savedir exp/cifar10
python3 plot.py --savedir exp/cifar10
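
The sixteen training commands above differ only in `--shadow_id`; an equivalent loop form of the same pipeline (a sketch, assuming the same defaults as the script above) would be:

```
# Train the 16 shadow models, then run inference, scoring, and plotting.
for i in $(seq 0 15); do
    python3 train.py --epochs 100 --shadow_id $i --debug
done
python3 inference.py --savedir exp/cifar10
python3 score.py --savedir exp/cifar10
python3 plot.py --savedir exp/cifar10
```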
lira-pytorch/score.py (new file, 70 lines)

@@ -0,0 +1,70 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified copy by Chenxiang Zhang (orientino) of the original:
# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021


import argparse
import multiprocessing as mp
import os
from pathlib import Path

import numpy as np
from torchvision.datasets import CIFAR10

parser = argparse.ArgumentParser()
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


def load_one(path):
    """
    Load the logits for one shadow model and convert them to scored predictions.
    """
    opredictions = np.load(os.path.join(path, "logits.npy"))  # [n_examples, n_augs, n_classes]

    # Be exceptionally careful.
    # Numerically stable everything, as described in the paper.
    predictions = opredictions - np.max(opredictions, axis=-1, keepdims=True)
    predictions = np.array(np.exp(predictions), dtype=np.float64)
    predictions = predictions / np.sum(predictions, axis=-1, keepdims=True)

    labels = get_labels()  # TODO generalize this

    COUNT = predictions.shape[0]
    y_true = predictions[np.arange(COUNT), :, labels[:COUNT]]

    print("mean acc", np.mean(predictions[:, 0, :].argmax(1) == labels[:COUNT]))

    # Zero out the true-class probability so the remaining sum is the total
    # probability of the wrong classes.
    predictions[np.arange(COUNT), :, labels[:COUNT]] = 0
    y_wrong = np.sum(predictions, axis=-1)

    # Logit-scaled confidence: log(p_true) - log(1 - p_true).
    logit = np.log(y_true + 1e-45) - np.log(y_wrong + 1e-45)
    np.save(os.path.join(path, "scores.npy"), logit)


def get_labels():
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True)
    return np.array(train_ds.targets)


def load_stats():
    with mp.Pool(8) as p:
        p.map(load_one, [os.path.join(args.savedir, x) for x in os.listdir(args.savedir)])


if __name__ == "__main__":
    load_stats()
lira-pytorch/train.py (new file, 180 lines)

@@ -0,0 +1,180 @@
# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/train.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
import time
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator

from wide_resnet import WideResNet

parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--n_shadows", default=16, type=int)
parser.add_argument("--shadow_id", default=1, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--pkeep", default=0.5, type=float)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")


def run():
    seed = np.random.randint(0, 1000000000)
    seed ^= int(time.time())
    pl.seed_everything(seed)

    wandb.init(project="lira", mode="disabled" if args.debug else "online")
    wandb.config.update(args)

    # Dataset
    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        ]
    )
    datadir = Path().home() / "opt/data/cifar"
    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)

    # Compute the IN / OUT subset:
    # If we run each experiment independently then even after a lot of trials
    # there will still probably be some examples that were always included
    # or always excluded. So instead, with experiment IDs, we guarantee that
    # after `args.n_shadows` are done, each example is seen exactly half
    # of the time in train, and half of the time not in train.
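    # Illustrative note (added): with the defaults n_shadows=16 and pkeep=0.5,
    # the fixed seed below makes every run derive its mask from the same random
    # matrix, so each example lands IN the training set of exactly
    # int(0.5 * 16) = 8 of the 16 shadow models.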

    size = len(train_ds)
    np.random.seed(seed)
    if args.n_shadows is not None:
        np.random.seed(0)
        keep = np.random.uniform(0, 1, size=(args.n_shadows, size))
        order = keep.argsort(0)
        keep = order < int(args.pkeep * args.n_shadows)
        keep = np.array(keep[args.shadow_id], dtype=bool)
        keep = keep.nonzero()[0]
    else:
        keep = np.random.choice(size, size=int(args.pkeep * size), replace=False)
        keep.sort()
    keep_bool = np.full((size), False)
    keep_bool[keep] = True

    train_ds = torch.utils.data.Subset(train_ds, keep)
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    # Model
    if args.model == "wresnet28-2":
        m = WideResNet(28, 2, 0.0, 10)
    elif args.model == "wresnet28-10":
        m = WideResNet(28, 10, 0.3, 10)
    elif args.model == "resnet18":
        m = models.resnet18(weights=None, num_classes=10)
        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        m.maxpool = nn.Identity()
    else:
        raise NotImplementedError
    m = m.to(DEVICE)

    # Replace layers that are incompatible with DP-SGD (e.g. BatchNorm).
    m = ModuleValidator.fix(m)
    ModuleValidator.validate(m, strict=True)

    optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

    privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=True)
    m, optim, train_dl = privacy_engine.make_private_with_epsilon(
        module=m,
        optimizer=optim,
        data_loader=train_dl,
        epochs=args.epochs,
        target_epsilon=1,
        target_delta=1e-4,
        max_grad_norm=1.0,
    )

    print(f"Device: {DEVICE}")

    # Train
    with BatchMemoryManager(
        data_loader=train_dl,
        max_physical_batch_size=1000,
        optimizer=optim,
    ) as memory_safe_data_loader:
        for i in tqdm(range(args.epochs)):
            m.train()
            loss_total = 0
            pbar = tqdm(memory_safe_data_loader, leave=False)
            for x, y in pbar:
                x, y = x.to(DEVICE), y.to(DEVICE)

                loss = F.cross_entropy(m(x), y)
                loss_total += loss.item()

                pbar.set_postfix_str(f"loss: {loss:.2f}")
                optim.zero_grad()
                loss.backward()
                optim.step()
            sched.step()

            wandb.log({"loss": loss_total / len(train_dl)})

        print(f"[test] acc_test: {get_acc(m, test_dl):.4f}")
        wandb.log({"acc_test": get_acc(m, test_dl)})

        savedir = os.path.join(args.savedir, str(args.shadow_id))
        os.makedirs(savedir, exist_ok=True)
        np.save(savedir + "/keep.npy", keep_bool)
        torch.save(m.state_dict(), savedir + "/model.pt")


@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()


if __name__ == "__main__":
    run()
lira-pytorch/wide_resnet.py (new file, 75 lines)

@@ -0,0 +1,75 @@
import torch.nn as nn
import torch.nn.functional as F


class wide_basic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.dropout = nn.Dropout(p=dropout_rate)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
            )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.dropout(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, n_classes):
        super(WideResNet, self).__init__()
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, "Wide-ResNet depth should be 6n+4"
        n = (depth - 4) // 6
        k = widen_factor
        stages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = nn.Conv2d(3, stages[0], kernel_size=3, stride=1, padding=1)
        self.layer1 = self._wide_layer(wide_basic, stages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, stages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, stages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(stages[3], momentum=0.9)
        self.linear = nn.Linear(stages[3], n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)

    def _wide_layer(self, block, planes, n_blocks, dropout_rate, stride):
        strides = [stride] + [1] * (int(n_blocks) - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out