# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import student_model
from utils import json_file_to_pyobj, get_loaders

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
# NOTE(review): --model is parsed but never read below; the network is always
# built from student_model.Model. Confirm whether the flag is still needed.
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


def _select_device():
    """Return the best available torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Probe MPS instead of assuming it: on hosts with neither CUDA nor MPS the
    # script must still run, so fall back to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@torch.no_grad()
def run():
    """Run every saved model over the training loader and store its logits.

    For each entry ``<savedir>/<name>`` that contains a ``model.pt`` file, the
    state dict is loaded into a fresh ``student_model.Model(num_classes=10)``,
    the model is evaluated ``args.n_queries`` times over the training loader,
    and the per-query outputs are stacked along axis 1 and written to
    ``<savedir>/<name>/logits.npy``.

    Entries of ``savedir`` that do not contain a ``model.pt`` file are skipped.
    """
    device = _select_device()
    dataset = "cifar10"

    # Dataset. Only the first returned loader is used here; the second
    # positional argument is presumably the batch size -- TODO confirm against
    # utils.get_loaders.
    train_dl, _ = get_loaders(dataset, 4096)

    # Infer the logits with multiple queries.
    # Sorted so the processing order does not depend on the filesystem.
    for name in sorted(os.listdir(args.savedir)):
        run_dir = os.path.join(args.savedir, name)
        weights_path = os.path.join(run_dir, "model.pt")
        if not os.path.isfile(weights_path):
            # Stray files or runs without a checkpoint: nothing to evaluate.
            continue

        m = student_model.Model(num_classes=10)
        # Deserialize onto the CPU first so checkpoints saved from a different
        # accelerator still load, then move the model to the chosen device.
        m.load_state_dict(torch.load(weights_path, map_location="cpu"))
        m.to(device)
        m.eval()

        # NOTE(review): the queries are only aligned row-by-row if train_dl
        # yields samples in a fixed order (no shuffling) -- verify in
        # utils.get_loaders.
        logits_n = []
        for _ in range(args.n_queries):
            logits = []
            for x, _labels in tqdm(train_dl):
                outputs = m(x.to(device))
                logits.append(outputs.cpu().numpy())
            logits_n.append(np.concatenate(logits))

        # Queries become axis 1: the result is (samples, n_queries, ...).
        logits_n = np.stack(logits_n, axis=1)
        print(logits_n.shape)

        np.save(os.path.join(run_dir, "logits.npy"), logits_n)


if __name__ == "__main__":
    run()