# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import student_model
from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()

SEED = 42


def _select_device():
    """Return CUDA if available, else MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # Previously this fell back to "mps" unconditionally, which fails on
    # machines that have neither CUDA nor MPS.
    return torch.device("cpu")


@torch.no_grad()
def run():
    """Compute and save logits of every saved student model on the student train set.

    For each entry ``<savedir>/<name>`` that contains a ``model.pt`` state dict,
    a ``student_model.Model(num_classes=10)`` is loaded and run over the student
    training set ``args.n_queries`` times. The per-query outputs are stacked and
    written to ``<savedir>/<name>/logits.npy`` with shape
    ``(n_samples, n_queries, *model_output_shape)``; row ``i`` corresponds to
    ``studentset[i]``.
    """
    device = _select_device()

    # Dataset
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    _teacherset, studentset, _testset = get_teacherstudent_trainset(
        training_configurations.batch_size, 10, SEED
    )

    # shuffle=False is required: logits are saved positionally, so row i must
    # correspond to studentset[i], and every query must visit samples in the
    # same order for np.stack(axis=1) to pair outputs of the same sample.
    train_dl = DataLoader(studentset, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple queries
    for name in sorted(os.listdir(args.savedir)):
        model_dir = os.path.join(args.savedir, name)
        model_path = os.path.join(model_dir, "model.pt")
        # Skip stray files (e.g. .DS_Store) and directories without a checkpoint.
        if not os.path.isfile(model_path):
            print(f"Skipping {model_dir}: no model.pt found")
            continue

        m = student_model.Model(num_classes=10)
        # map_location lets checkpoints saved on CUDA load on MPS/CPU hosts.
        m.load_state_dict(torch.load(model_path, map_location=device))
        m.to(device)
        m.eval()

        logits_n = []
        for _query in range(args.n_queries):
            # NOTE(review): queries only differ if studentset applies random
            # augmentation in its transform; otherwise they are identical — confirm.
            logits = []
            for x, _y in tqdm(train_dl):
                x = x.to(device)
                outputs = m(x)
                logits.append(outputs.cpu().numpy())
            logits_n.append(np.concatenate(logits))
        logits_n = np.stack(logits_n, axis=1)
        print(logits_n.shape)

        np.save(os.path.join(model_dir, "logits.npy"), logits_n)


if __name__ == "__main__":
    run()