2024-11-29 17:16:09 -07:00
|
|
|
# PyTorch implementation of
|
|
|
|
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
|
|
|
|
#
|
|
|
|
# author: Chenxiang Zhang (orientino)
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torchvision import models, transforms
|
|
|
|
from torchvision.datasets import CIFAR10
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2024-12-02 17:54:29 -07:00
|
|
|
import student_model
|
|
|
|
from utils import json_file_to_pyobj, get_loaders
|
2024-11-29 17:16:09 -07:00
|
|
|
|
|
|
|
# Command-line options for this script.
# NOTE: arguments are parsed at import time (run() reads the module-level
# ``args``), so importing this module from another program consumes that
# program's sys.argv.
parser = argparse.ArgumentParser(
    description=(
        "Compute training-set logits for every saved model under --savedir "
        "and write them to logits.npy next to each model.pt."
    ),
)
parser.add_argument(
    "--n_queries",
    default=2,
    type=int,
    help="number of inference passes over the training loader per model (>= 1)",
)
parser.add_argument(
    "--model",
    default="resnet18",
    type=str,
    help="currently unused: run() always instantiates student_model.Model",
)
parser.add_argument(
    "--savedir",
    default="exp/cifar10",
    type=str,
    help="directory whose subdirectories each contain a model.pt checkpoint",
)
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def _select_device() -> torch.device:
    """Return the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Only choose MPS when it actually exists; an unconditional "mps" fallback
    # makes ``model.to(device)`` fail on machines with neither CUDA nor MPS.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@torch.no_grad()
def _infer_logits(model: nn.Module, loader, n_queries: int, device: torch.device) -> np.ndarray:
    """Run ``model`` over every batch of ``loader`` ``n_queries`` times.

    Each pass concatenates the per-batch outputs along axis 0; the passes are
    then stacked along axis 1, giving an array of shape
    ``(num_samples, n_queries, *per_sample_output_shape)``.
    (Presumably the per-sample output is ``(num_classes,)`` — confirm against
    the model.)
    """
    per_query = []
    for _query in range(n_queries):
        batch_outputs = []
        for x, _labels in tqdm(loader):
            batch_outputs.append(model(x.to(device)).cpu().numpy())
        per_query.append(np.concatenate(batch_outputs))
    return np.stack(per_query, axis=1)


@torch.no_grad()
def run():
    """Compute and save training-set logits for every checkpoint under ``args.savedir``.

    For each subdirectory of ``args.savedir`` that contains ``model.pt``, the
    checkpoint is loaded into a fresh ``student_model.Model(num_classes=10)``.
    That model is evaluated ``args.n_queries`` times over the training loader
    returned by ``get_loaders("cifar10", 4096)``. The stacked logits are printed
    (shape only) and saved as ``logits.npy`` in the same subdirectory. Entries
    without a ``model.pt`` file are skipped.

    Raises:
        ValueError: if ``args.n_queries`` is smaller than 1.
        FileNotFoundError: if ``args.savedir`` does not exist.
    """
    # Fail fast instead of loading a model and then crashing inside np.stack.
    if args.n_queries < 1:
        raise ValueError(f"--n_queries must be >= 1, got {args.n_queries}")

    device = _select_device()
    dataset = "cifar10"

    # Dataset.
    # NOTE(review): LiRA compares logits sample-by-sample across models, so this
    # assumes get_loaders returns an *unshuffled* train loader — confirm. The
    # repeated queries only differ from each other if that loader applies
    # random augmentation — confirm that too.
    train_dl, _ = get_loaders(dataset, 4096)

    # Infer the logits with multiple queries for each saved model.
    for run_dir in sorted(Path(args.savedir).iterdir()):
        ckpt_path = run_dir / "model.pt"
        if not ckpt_path.is_file():
            # Stray files or incomplete run directories used to crash the loop.
            print(f"Skipping {run_dir}: no model.pt found")
            continue

        m = student_model.Model(num_classes=10)
        # Map to CPU so checkpoints saved on CUDA also load on MPS/CPU hosts;
        # load_state_dict copies the tensors into the model's own parameters.
        # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
        m.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        m.to(device)
        m.eval()

        logits_n = _infer_logits(m, train_dl, args.n_queries, device)
        print(logits_n.shape)

        np.save(run_dir / "logits.npy", logits_n)
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: perform inference only when executed directly,
# never as a side effect of importing this module.
if __name__ == "__main__":
    run()
|