64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
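# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
# adapted to compute logits of the student models on the student split of a
# teacher/student (distillation) dataset.
#
# author: Chenxiang Zhang (orientino)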
|
|
|
|
import argparse
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import models, transforms
|
|
from torchvision.datasets import CIFAR10
|
|
from tqdm import tqdm
|
|
|
|
import student_model
|
|
from utils import json_file_to_pyobj, get_loaders
|
|
from distillation_utils import get_teacherstudent_trainset
|
|
|
|
# Command-line interface. Arguments are parsed at import time into the
# module-level ``args`` that ``run()`` reads.
parser = argparse.ArgumentParser(
    description="Infer logits of every model checkpoint stored under --savedir."
)
parser.add_argument(
    "--n_queries",
    default=2,
    type=int,
    help="number of inference passes over the data per model (default: %(default)s)",
)
parser.add_argument(
    "--model",
    default="resnet18",
    type=str,
    help="model architecture name; currently not used by this script (default: %(default)s)",
)
parser.add_argument(
    "--savedir",
    default="exp/cifar10",
    type=str,
    help=(
        "directory whose entries each hold a model.pt checkpoint; logits.npy "
        "is written next to each checkpoint (default: %(default)s)"
    ),
)
args = parser.parse_args()

# Seed forwarded to get_teacherstudent_trainset (presumably fixes the
# teacher/student data split -- confirm against distillation_utils).
SEED = 42
|
|
|
|
def _select_device() -> torch.device:
    """Return the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Check MPS explicitly so machines with neither CUDA nor MPS still work.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@torch.no_grad()
def _query_logits(
    model: nn.Module, loader: DataLoader, n_queries: int, device: torch.device
) -> np.ndarray:
    """Run ``model`` over ``loader`` ``n_queries`` times and stack the outputs.

    Each pass concatenates the per-batch outputs along axis 0, and the passes
    are stacked along axis 1. Assuming the model returns ``(batch, num_classes)``
    logits (TODO confirm for student_model.Model), the result has shape
    ``(num_samples, n_queries, num_classes)``. Row ``j`` belongs to the ``j``-th
    sample yielded by ``loader``, so ``loader`` must iterate in a fixed order
    for rows to line up across passes.
    """
    passes = []
    for _query in range(n_queries):
        batch_outputs = []
        for x, _ in tqdm(loader):
            outputs = model(x.to(device))
            batch_outputs.append(outputs.cpu().numpy())
        passes.append(np.concatenate(batch_outputs))
    return np.stack(passes, axis=1)


@torch.no_grad()
def run():
    """Compute and save logits for every checkpoint under ``args.savedir``.

    For each entry ``<savedir>/<name>/`` that contains ``model.pt``, this loads
    the state dict into a ``student_model.Model(num_classes=10)``. It queries the
    model ``args.n_queries`` times over the student split returned by
    ``get_teacherstudent_trainset`` and saves the stacked logits to
    ``<savedir>/<name>/logits.npy``. Entries without a ``model.pt`` are skipped.
    ``args.model`` is not consulted.
    """
    device = _select_device()

    # Dataset: only the student split is queried; the teacher and test splits
    # returned alongside it are not used here.
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    _teacherset, studentset, _testset = get_teacherstudent_trainset(
        training_configurations.batch_size, 10, SEED
    )
    # shuffle=False is required: the rows of logits.npy must correspond to
    # dataset indices, and the n_queries passes stacked on axis 1 must refer
    # to the same samples. Shuffling would scramble both alignments.
    train_dl = DataLoader(studentset, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple queries, one checkpoint directory at a time
    # (sorted for a deterministic processing order).
    for name in sorted(os.listdir(args.savedir)):
        ckpt_dir = os.path.join(args.savedir, name)
        ckpt_path = os.path.join(ckpt_dir, "model.pt")
        if not os.path.isfile(ckpt_path):
            print(f"Skipping {ckpt_dir}: no model.pt found")
            continue

        m = student_model.Model(num_classes=10)
        # map_location lets a checkpoint saved on one device (e.g. CUDA) load
        # on whichever device was selected here.
        m.load_state_dict(torch.load(ckpt_path, map_location=device))
        m.to(device)
        m.eval()

        logits_n = _query_logits(m, train_dl, args.n_queries, device)
        print(logits_n.shape)

        np.save(os.path.join(ckpt_dir, "logits.npy"), logits_n)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. CLI arguments are already parsed into the
    # module-level ``args`` at import time.
    run()
|