diff --git a/lira-pytorch/.gitignore b/lira-pytorch/.gitignore new file mode 100644 index 0000000..a06f1b4 --- /dev/null +++ b/lira-pytorch/.gitignore @@ -0,0 +1,7 @@ +__pycache__ +exp +logs +slurm +gpu.sh +*.out + diff --git a/lira-pytorch/LICENSE b/lira-pytorch/LICENSE new file mode 100644 index 0000000..57bc88a --- /dev/null +++ b/lira-pytorch/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
diff --git a/lira-pytorch/README.md b/lira-pytorch/README.md
new file mode 100644
index 0000000..ab2d606
--- /dev/null
+++ b/lira-pytorch/README.md
@@ -0,0 +1,42 @@
+# Likelihood Ratio Attack (LiRA) in PyTorch
+Implementation of the original [LiRA](https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021) using PyTorch. To run the code, first create an environment with the `env.yml` file. Then run the following command to train the shadow models and run the LiRA attack:
+
+```
+./run.sh
+```
+
+The script stores a log-scale FPR-TPR curve as `./fprtpr.png` and reports the TPR@0.1%FPR in the output log.
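+
+The TPR@0.1%FPR numbers below are read off the ROC curve as the largest TPR whose FPR is still below 1e-3, mirroring the computation in `plot.py`. A minimal, self-contained sketch of that readout (dummy scores, for illustration only):
+
+```python
+import numpy as np
+from sklearn.metrics import roc_curve
+
+# Dummy attack scores and membership labels, for illustration only
+rng = np.random.default_rng(0)
+answers = rng.integers(0, 2, size=1000).astype(bool)  # membership ground truth
+prediction = rng.normal(size=1000) - answers          # members get slightly lower scores
+fpr, tpr, _ = roc_curve(answers, -prediction)
+print("TPR@0.1%FPR:", tpr[np.where(fpr < 0.001)[0][-1]])
+```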
+
+## Results on CIFAR10
+
+Using 16 shadow models trained with `ResNet18` and 2 augmented queries:
+
+![roc](figures/fprtpr_resnet18.png)
+```
+Attack Ours (online)
+   AUC 0.6548, Accuracy 0.6015, TPR@0.1%FPR of 0.0068
+Attack Ours (online, fixed variance)
+   AUC 0.6700, Accuracy 0.6042, TPR@0.1%FPR of 0.0464
+Attack Ours (offline)
+   AUC 0.5250, Accuracy 0.5353, TPR@0.1%FPR of 0.0041
+Attack Ours (offline, fixed variance)
+   AUC 0.5270, Accuracy 0.5380, TPR@0.1%FPR of 0.0192
+Attack Global threshold
+   AUC 0.5948, Accuracy 0.5869, TPR@0.1%FPR of 0.0006
+```
+
+Using 16 shadow models trained with `WideResNet28-10` and 2 augmented queries:
+
+![roc](figures/fprtpr_wideresnet.png)
+```
+Attack Ours (online)
+   AUC 0.6834, Accuracy 0.6152, TPR@0.1%FPR of 0.0240
+Attack Ours (online, fixed variance)
+   AUC 0.7017, Accuracy 0.6240, TPR@0.1%FPR of 0.0704
+Attack Ours (offline)
+   AUC 0.5621, Accuracy 0.5649, TPR@0.1%FPR of 0.0140
+Attack Ours (offline, fixed variance)
+   AUC 0.5698, Accuracy 0.5628, TPR@0.1%FPR of 0.0370
+Attack Global threshold
+   AUC 0.6016, Accuracy 0.5977, TPR@0.1%FPR of 0.0013
+```
\ No newline at end of file
diff --git a/lira-pytorch/env.yml b/lira-pytorch/env.yml
new file mode 100644
index 0000000..faa5b07
--- /dev/null
+++ b/lira-pytorch/env.yml
@@ -0,0 +1,35 @@
+# Minimal environment for starting a project using conda/mamba:
+#   conda env create -n ENVNAME --file ENV.yml
+
+name: template
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+  - defaults
+dependencies:
+  - python=3.8.6
+  - pip
+  - pytest
+  - numpy
+  - scipy
+  - scikit-learn
+  - matplotlib
+  - pandas
+  - tqdm
+  - wandb
+  - jupyterlab
+  - jupyter
+  - ipykernel
+  - pytorch
+  - torchvision
+  - torchaudio
+  - pytorch-cuda=12.1
+  - pytorch-lightning
+  - lightning-bolts
+  - torchmetrics
+
+  # Install packages with pip
+  # - pip:
+  #   - ray[tune]
diff --git a/lira-pytorch/figures/fprtpr_resnet18.png b/lira-pytorch/figures/fprtpr_resnet18.png
new file mode 100644
index 0000000..00a672a
Binary files /dev/null and b/lira-pytorch/figures/fprtpr_resnet18.png differ
diff --git a/lira-pytorch/figures/fprtpr_wideresnet.png b/lira-pytorch/figures/fprtpr_wideresnet.png
new file mode 100644
index 0000000..2defb08
Binary files /dev/null and b/lira-pytorch/figures/fprtpr_wideresnet.png differ
diff --git a/lira-pytorch/inference.py b/lira-pytorch/inference.py
new file mode 100644
index 0000000..0577eb7
--- /dev/null
+++ b/lira-pytorch/inference.py
@@ -0,0 +1,75 @@
+# PyTorch implementation of
+# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
+#
+# author: Chenxiang Zhang (orientino)
+
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+from opacus.validators import ModuleValidator
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision import models, transforms
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+
+from wide_resnet import WideResNet
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--n_queries", default=2, type=int)
+parser.add_argument("--model", default="resnet18", type=str)
+parser.add_argument("--savedir", default="exp/cifar10", type=str)
+args = parser.parse_args()
+
+
+@torch.no_grad()
+def run():
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+    # Dataset: each query sees a fresh random augmentation of the training set
+    transform = transforms.Compose(
+        [
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(32, padding=4),
+            transforms.ToTensor(),
+            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
+        ]
+    )
+    datadir = Path().home() / "opt/data/cifar"
+    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=transform)
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=4)
+
+    # Infer the logits with multiple queries
+    for path in os.listdir(args.savedir):
+        if args.model == "wresnet28-2":
+            m = WideResNet(28, 2, 0.0, 10)
+        elif args.model == "wresnet28-10":
+            m = WideResNet(28, 10, 0.3, 10)
+        elif args.model == "resnet18":
+            m = models.resnet18(weights=None, num_classes=10)
+            m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            m.maxpool = nn.Identity()
+        else:
+            raise NotImplementedError
+        # train.py trains through Opacus: ModuleValidator.fix swaps out layers that
+        # are incompatible with DP-SGD, and the saved state_dict keys carry the
+        # GradSampleModule "_module." prefix, so mirror the fix and strip the prefix
+        m = ModuleValidator.fix(m)
+        state_dict = torch.load(os.path.join(args.savedir, path, "model.pt"), map_location="cpu")
+        state_dict = {(k[len("_module."):] if k.startswith("_module.") else k): v for k, v in state_dict.items()}
+        m.load_state_dict(state_dict)
+        m.to(DEVICE)
+        m.eval()
+
+        logits_n = []
+        for i in range(args.n_queries):
+            logits = []
+            for x, _ in tqdm(train_dl):
+                x = x.to(DEVICE)
+                outputs = m(x)
+                logits.append(outputs.cpu().numpy())
+            logits_n.append(np.concatenate(logits))
+        logits_n = np.stack(logits_n, axis=1)  # [n_examples, n_queries, n_classes]
+        print(logits_n.shape)
+
+        np.save(os.path.join(args.savedir, path, "logits.npy"), logits_n)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/lira-pytorch/plot.py b/lira-pytorch/plot.py
new file mode 100644
index 0000000..f38d0c7
--- /dev/null
+++ b/lira-pytorch/plot.py
@@ -0,0 +1,205 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified copy by Chenxiang Zhang (orientino) of the original:
+# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021
+
+
+import argparse
+import functools
+import os
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.stats
+from sklearn.metrics import auc, roc_curve
+
+matplotlib.rcParams["pdf.fonttype"] = 42
+matplotlib.rcParams["ps.fonttype"] = 42
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--savedir", default="exp/cifar10", type=str)
+args = parser.parse_args()
+
+
+def sweep(score, x):
+    """
+    Compute a ROC curve and then return the FPR, TPR, AUC, and ACC.
+    """
+    fpr, tpr, _ = roc_curve(x, -score)
+    acc = np.max(1 - (fpr + (1 - tpr)) / 2)
+    return fpr, tpr, auc(fpr, tpr), acc
+
+
+def load_data():
+    """
+    Load our saved scores and then put them into a big matrix.
+    """
+    global scores, keep
+    scores = []
+    keep = []
+
+    for path in os.listdir(args.savedir):
+        scores.append(np.load(os.path.join(args.savedir, path, "scores.npy")))
+        keep.append(np.load(os.path.join(args.savedir, path, "keep.npy")))
+    scores = np.array(scores)
+    keep = np.array(keep)
+
+    return scores, keep
+
+
+def generate_ours(keep, scores, check_keep, check_scores, in_size=100000, out_size=100000, fix_variance=False):
+    """
+    Fit two predictive models using keep and scores in order to predict
+    if the examples in check_scores were training data or not, using the
+    ground truth answer from check_keep.
+    """
+    dat_in = []
+    dat_out = []
+
+    for j in range(scores.shape[1]):
+        dat_in.append(scores[keep[:, j], j, :])
+        dat_out.append(scores[~keep[:, j], j, :])
+
+    in_size = min(min(map(len, dat_in)), in_size)
+    out_size = min(min(map(len, dat_out)), out_size)
+
+    dat_in = np.array([x[:in_size] for x in dat_in])
+    dat_out = np.array([x[:out_size] for x in dat_out])
+
+    # The median is used as a robust per-example location estimate
+    mean_in = np.median(dat_in, 1)
+    mean_out = np.median(dat_out, 1)
+
+    if fix_variance:
+        # Share a single global variance estimate (computed from the IN scores)
+        # between the IN and OUT Gaussians
+        std_in = np.std(dat_in)
+        std_out = np.std(dat_in)
+    else:
+        std_in = np.std(dat_in, 1)
+        std_out = np.std(dat_out, 1)
+
+    prediction = []
+    answers = []
+    for ans, sc in zip(check_keep, check_scores):
+        pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in + 1e-30)
+        pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
+        score = pr_in - pr_out
+
+        prediction.extend(score.mean(1))
+        answers.extend(ans)
+
+    return prediction, answers
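+
+# A note on the sign conventions above: the per-example score is
+# pr_in - pr_out = log p_out(x) - log p_in(x), so *lower* scores indicate
+# membership, which is why sweep() hands -score to roc_curve. A toy
+# illustration with hypothetical numbers:
+#
+#   >>> mean_in, std_in, mean_out, std_out = 5.0, 1.0, 2.0, 1.5
+#   >>> obs = 4.5  # observed logit-scaled confidence of the target model
+#   >>> pr_in = -scipy.stats.norm.logpdf(obs, mean_in, std_in)
+#   >>> pr_out = -scipy.stats.norm.logpdf(obs, mean_out, std_out)
+#   >>> pr_in - pr_out < 0  # negative score: looks like a member
+#   True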
+ """ + dat_in = [] + dat_out = [] + + for j in range(scores.shape[1]): + dat_in.append(scores[keep[:, j], j, :]) + dat_out.append(scores[~keep[:, j], j, :]) + + out_size = min(min(map(len, dat_out)), out_size) + + dat_out = np.array([x[:out_size] for x in dat_out]) + + mean_out = np.median(dat_out, 1) + + if fix_variance: + std_out = np.std(dat_out) + else: + std_out = np.std(dat_out, 1) + + prediction = [] + answers = [] + for ans, sc in zip(check_keep, check_scores): + score = scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30) + + prediction.extend(score.mean(1)) + answers.extend(ans) + return prediction, answers + + +def generate_global(keep, scores, check_keep, check_scores): + """ + Use a simple global threshold sweep to predict if the examples in + check_scores were training data or not, using the ground truth answer from + check_keep. + """ + prediction = [] + answers = [] + for ans, sc in zip(check_keep, check_scores): + prediction.extend(-sc.mean(1)) + answers.extend(ans) + + return prediction, answers + + +def do_plot(fn, keep, scores, ntest, legend="", metric="auc", sweep_fn=sweep, **plot_kwargs): + """ + Generate the ROC curves by using ntest models as test models and the rest to train. + """ + + prediction, answers = fn(keep[:-ntest], scores[:-ntest], keep[-ntest:], scores[-ntest:]) + + fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool)) + + low = tpr[np.where(fpr < 0.001)[0][-1]] + + print("Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f" % (legend, auc, acc, low)) + + metric_text = "" + if metric == "auc": + metric_text = "auc=%.3f" % auc + elif metric == "acc": + metric_text = "acc=%.3f" % acc + + plt.plot(fpr, tpr, label=legend + metric_text, **plot_kwargs) + return (acc, auc) + + +def fig_fpr_tpr(): + plt.figure(figsize=(4, 3)) + + do_plot(generate_ours, keep, scores, 1, "Ours (online)\n", metric="auc") + + do_plot(functools.partial(generate_ours, fix_variance=True), keep, scores, 1, "Ours (online, fixed variance)\n", metric="auc") + + do_plot(functools.partial(generate_ours_offline), keep, scores, 1, "Ours (offline)\n", metric="auc") + + do_plot(functools.partial(generate_ours_offline, fix_variance=True), keep, scores, 1, "Ours (offline, fixed variance)\n", metric="auc") + + do_plot(generate_global, keep, scores, 1, "Global threshold\n", metric="auc") + + plt.semilogx() + plt.semilogy() + plt.xlim(1e-5, 1) + plt.ylim(1e-5, 1) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.plot([0, 1], [0, 1], ls="--", color="gray") + plt.subplots_adjust(bottom=0.18, left=0.18, top=0.96, right=0.96) + plt.legend(fontsize=8) + plt.savefig("fprtpr.png") + plt.show() + + +if __name__ == "__main__": + load_data() + fig_fpr_tpr() diff --git a/lira-pytorch/run.sh b/lira-pytorch/run.sh new file mode 100755 index 0000000..670d490 --- /dev/null +++ b/lira-pytorch/run.sh @@ -0,0 +1,21 @@ +python3 train.py --epochs 100 --shadow_id 0 --debug +python3 train.py --epochs 100 --shadow_id 1 --debug +python3 train.py --epochs 100 --shadow_id 2 --debug +python3 train.py --epochs 100 --shadow_id 3 --debug +python3 train.py --epochs 100 --shadow_id 4 --debug +python3 train.py --epochs 100 --shadow_id 5 --debug +python3 train.py --epochs 100 --shadow_id 6 --debug +python3 train.py --epochs 100 --shadow_id 7 --debug +python3 train.py --epochs 100 --shadow_id 8 --debug +python3 train.py --epochs 100 --shadow_id 9 --debug +python3 train.py --epochs 100 --shadow_id 10 --debug +python3 train.py --epochs 100 --shadow_id 11 --debug +python3 
train.py --epochs 100 --shadow_id 12 --debug +python3 train.py --epochs 100 --shadow_id 13 --debug +python3 train.py --epochs 100 --shadow_id 14 --debug +python3 train.py --epochs 100 --shadow_id 15 --debug + +python3 inference.py --savedir exp/cifar10 +python3 score.py --savedir exp/cifar10 +python3 plot.py --savedir exp/cifar10 + diff --git a/lira-pytorch/score.py b/lira-pytorch/score.py new file mode 100644 index 0000000..68933fa --- /dev/null +++ b/lira-pytorch/score.py @@ -0,0 +1,70 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified copy by Chenxiang Zhang (orientino) of the original: +# https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021 + + +import argparse +import multiprocessing as mp +import os +from pathlib import Path + +import numpy as np +from torchvision.datasets import CIFAR10 + +parser = argparse.ArgumentParser() +parser.add_argument("--savedir", default="exp/cifar10", type=str) +args = parser.parse_args() + + +def load_one(path): + """ + This loads a logits and converts it to a scored prediction. + """ + opredictions = np.load(os.path.join(path, "logits.npy")) # [n_examples, n_augs, n_classes] + + # Be exceptionally careful. + # Numerically stable everything, as described in the paper. 
+    predictions = opredictions - np.max(opredictions, axis=-1, keepdims=True)
+    predictions = np.array(np.exp(predictions), dtype=np.float64)
+    predictions = predictions / np.sum(predictions, axis=-1, keepdims=True)
+
+    labels = get_labels()  # TODO generalize this
+
+    COUNT = predictions.shape[0]
+    y_true = predictions[np.arange(COUNT), :, labels[:COUNT]]
+
+    print("mean acc", np.mean(predictions[:, 0, :].argmax(1) == labels[:COUNT]))
+
+    predictions[np.arange(COUNT), :, labels[:COUNT]] = 0
+    y_wrong = np.sum(predictions, axis=-1)
+
+    logit = np.log(y_true + 1e-45) - np.log(y_wrong + 1e-45)
+    np.save(os.path.join(path, "scores.npy"), logit)
+
+
+def get_labels():
+    datadir = Path().home() / "opt/data/cifar"
+    train_ds = CIFAR10(root=datadir, train=True, download=True)
+    return np.array(train_ds.targets)
+
+
+def load_stats():
+    with mp.Pool(8) as p:
+        p.map(load_one, [os.path.join(args.savedir, x) for x in os.listdir(args.savedir)])
+
+
+if __name__ == "__main__":
+    load_stats()
diff --git a/lira-pytorch/train.py b/lira-pytorch/train.py
new file mode 100644
index 0000000..5c6a538
--- /dev/null
+++ b/lira-pytorch/train.py
@@ -0,0 +1,180 @@
+# PyTorch implementation of
+# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/train.py
+#
+# author: Chenxiang Zhang (orientino)
+
+import argparse
+import os
+import time
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import wandb
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision import models, transforms
+from torchvision.datasets import CIFAR10
+from tqdm import tqdm
+from opacus.validators import ModuleValidator
+from opacus import PrivacyEngine
+from opacus.utils.batch_memory_manager import BatchMemoryManager
+
+from wide_resnet import WideResNet
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--lr", default=0.1, type=float)
+parser.add_argument("--epochs", default=1, type=int)
+parser.add_argument("--n_shadows", default=16, type=int)
+parser.add_argument("--shadow_id", default=1, type=int)
+parser.add_argument("--model", default="resnet18", type=str)
+parser.add_argument("--pkeep", default=0.5, type=float)
+parser.add_argument("--savedir", default="exp/cifar10", type=str)
+parser.add_argument("--debug", action="store_true")
+args = parser.parse_args()
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+
+def run():
+    seed = np.random.randint(0, 1000000000)
+    seed ^= int(time.time())
+    pl.seed_everything(seed)
+
+    wandb.init(project="lira", mode="disabled" if args.debug else "online")
+    wandb.config.update(args)
+
+    # Dataset
+    train_transform = transforms.Compose(
+        [
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop(32, padding=4),
+            transforms.ToTensor(),
+            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
+        ]
+    )
+    test_transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
+        ]
+    )
+    datadir = Path().home() / "opt/data/cifar"
+    train_ds = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
+    test_ds = CIFAR10(root=datadir, train=False, download=True, transform=test_transform)
+
+    # Compute the IN / OUT subset:
+    # If we run each experiment independently then even after a lot of trials
+    # there will still probably be some examples that were always included
+    # or always excluded.
+    # So instead, with experiment IDs, we guarantee that after
+    # `args.n_shadows` runs are done, each example is seen exactly half
+    # of the time in train, and half of the time not in train.
+
+    size = len(train_ds)
+    np.random.seed(seed)
+    if args.n_shadows is not None:
+        # Seed 0 makes the keep matrix identical across all shadow runs,
+        # so each run just selects its own row
+        np.random.seed(0)
+        keep = np.random.uniform(0, 1, size=(args.n_shadows, size))
+        order = keep.argsort(0)
+        keep = order < int(args.pkeep * args.n_shadows)
+        keep = np.array(keep[args.shadow_id], dtype=bool)
+        keep = keep.nonzero()[0]
+    else:
+        keep = np.random.choice(size, size=int(args.pkeep * size), replace=False)
+        keep.sort()
+    keep_bool = np.full((size), False)
+    keep_bool[keep] = True
+
+    train_ds = torch.utils.data.Subset(train_ds, keep)
+    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
+
+    # Model
+    if args.model == "wresnet28-2":
+        m = WideResNet(28, 2, 0.0, 10)
+    elif args.model == "wresnet28-10":
+        m = WideResNet(28, 10, 0.3, 10)
+    elif args.model == "resnet18":
+        m = models.resnet18(weights=None, num_classes=10)
+        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        m.maxpool = nn.Identity()
+    else:
+        raise NotImplementedError
+
+    # Replace layers that are incompatible with DP-SGD (e.g. BatchNorm),
+    # then move the fixed model to the device
+    m = ModuleValidator.fix(m)
+    ModuleValidator.validate(m, strict=True)
+    m = m.to(DEVICE)
+
+    optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)
+
+    # secure_mode=True requires the torchcsprng package for secure randomness
+    privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=True)
+    m, optim, train_dl = privacy_engine.make_private_with_epsilon(
+        module=m,
+        optimizer=optim,
+        data_loader=train_dl,
+        epochs=args.epochs,
+        target_epsilon=1,
+        target_delta=1e-4,
+        max_grad_norm=1.0,
+    )
+
+    print(f"Device: {DEVICE}")
+
+    # Train
+    with BatchMemoryManager(
+        data_loader=train_dl,
+        max_physical_batch_size=1000,
+        optimizer=optim,
+    ) as memory_safe_data_loader:
+        for i in tqdm(range(args.epochs)):
+            m.train()
+            loss_total = 0
+            pbar = tqdm(memory_safe_data_loader, leave=False)
+            for x, y in pbar:
+                x, y = x.to(DEVICE), y.to(DEVICE)
+
+                loss = F.cross_entropy(m(x), y)
+                loss_total += loss.item()  # detach from the graph so it is not retained
+
+                pbar.set_postfix_str(f"loss: {loss:.2f}")
+                optim.zero_grad()
+                loss.backward()
+                optim.step()
+            sched.step()
+
+            wandb.log({"loss": loss_total / len(train_dl)})
+
+    acc_test = get_acc(m, test_dl)
+    print(f"[test] acc_test: {acc_test:.4f}")
+    wandb.log({"acc_test": acc_test})
+
+    savedir = os.path.join(args.savedir, str(args.shadow_id))
+    os.makedirs(savedir, exist_ok=True)
+    np.save(savedir + "/keep.npy", keep_bool)
+    torch.save(m.state_dict(), savedir + "/model.pt")
+
+
+@torch.no_grad()
+def get_acc(model, dl):
+    model.eval()
+    acc = []
+    for x, y in dl:
+        x, y = x.to(DEVICE), y.to(DEVICE)
+        acc.append(torch.argmax(model(x), dim=1) == y)
+    acc = torch.cat(acc)
+    acc = torch.sum(acc) / len(acc)
+
+    return acc.item()
+
+
+if __name__ == "__main__":
+    run()
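A note on the IN/OUT split above: because every shadow run derives its `keep` mask from the same seed-0 random matrix, `order < int(pkeep * n_shadows)` holds for exactly `pkeep * n_shadows` rows of every column, so with `pkeep=0.5` each example is a member for exactly half of the shadow models. A standalone sketch of that property (toy sizes, not part of the diff, mirroring the code in `train.py`):

```python
import numpy as np

n_shadows, size, pkeep = 16, 10, 0.5
np.random.seed(0)
keep = np.random.uniform(0, 1, size=(n_shadows, size))
order = keep.argsort(0)
keep = order < int(pkeep * n_shadows)

# Each column of `order` is a permutation of 0..15, so every example
# is IN for exactly 8 of the 16 shadow models
print(keep.sum(0))  # [8 8 8 8 8 8 8 8 8 8]
```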
diff --git a/lira-pytorch/wide_resnet.py b/lira-pytorch/wide_resnet.py
new file mode 100644
index 0000000..3d9fc87
--- /dev/null
+++ b/lira-pytorch/wide_resnet.py
@@ -0,0 +1,75 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class wide_basic(nn.Module):
+    def __init__(self, in_planes, planes, dropout_rate, stride=1):
+        super().__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        # 1x1 convolution to match shapes when the width or spatial size changes
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
+            )
+
+    def forward(self, x):
+        # Pre-activation residual block: BN -> ReLU -> conv
+        out = self.conv1(F.relu(self.bn1(x)))
+        out = self.dropout(out)
+        out = self.conv2(F.relu(self.bn2(out)))
+        out += self.shortcut(x)
+
+        return out
+
+
+class WideResNet(nn.Module):
+    def __init__(self, depth, widen_factor, dropout_rate, n_classes):
+        super().__init__()
+        self.in_planes = 16
+
+        assert (depth - 4) % 6 == 0, "Wide-ResNet depth should be 6n+4"
+        n = (depth - 4) // 6
+        k = widen_factor
+        stages = [16, 16 * k, 32 * k, 64 * k]
+
+        self.conv1 = nn.Conv2d(3, stages[0], kernel_size=3, stride=1, padding=1)
+        self.layer1 = self._wide_layer(wide_basic, stages[1], n, dropout_rate, stride=1)
+        self.layer2 = self._wide_layer(wide_basic, stages[2], n, dropout_rate, stride=2)
+        self.layer3 = self._wide_layer(wide_basic, stages[3], n, dropout_rate, stride=2)
+        self.bn1 = nn.BatchNorm2d(stages[3], momentum=0.9)
+        self.linear = nn.Linear(stages[3], n_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+    def _wide_layer(self, block, planes, n_blocks, dropout_rate, stride):
+        # The first block of a stage downsamples (if stride > 1); the rest keep the size
+        strides = [stride] + [1] * (int(n_blocks) - 1)
+        layers = []
+
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, dropout_rate, stride))
+            self.in_planes = planes
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = F.relu(self.bn1(out))
+        out = F.avg_pool2d(out, 8)  # 32x32 input -> 8x8 feature map -> global average
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+
+        return out
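As a quick sanity check of the `WideResNet` above, a hypothetical snippet (not part of the diff) that instantiates the 28-2 variant and confirms the output shape on a CIFAR-sized batch:

```python
import torch

from wide_resnet import WideResNet

# Depth 28, widen factor 2, no dropout, 10 classes (the "wresnet28-2" option in train.py)
model = WideResNet(28, 2, 0.0, 10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 32, 32))  # dummy CIFAR-shaped batch
print(logits.shape)  # torch.Size([4, 10])
```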