# File: mia_on_model_distillation/lira-pytorch/inference.py
# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)
import argparse
import os
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

# Project-local modules
import student_model
from utils import json_file_to_pyobj, get_loaders
from distillation_utils import get_teacherstudent_trainset

# Command-line options, parsed at import time (module-level ``args`` is read by run()).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--n_queries",
    default=2,
    type=int,
    help="number of full forward passes over the training set per model (>= 1)",
)
parser.add_argument(
    "--model",
    default="resnet18",
    type=str,
    help="architecture name (currently ignored: run() always builds student_model.Model)",
)
parser.add_argument(
    "--savedir",
    default="exp/cifar10",
    type=str,
    help="directory with one sub-directory per trained model, each holding model.pt",
)
args = parser.parse_args()

# Seed forwarded to get_teacherstudent_trainset when building the data splits.
SEED = 42

@torch.no_grad()
def _query_logits(model, loader, n_queries, device):
    """Return ``model``'s outputs for every sample of ``loader``, repeated ``n_queries`` times.

    Each query is one full pass over ``loader``. The per-batch outputs are
    concatenated along axis 0, and the queries are stacked along axis 1. So
    ``result[i, q]`` is the output for the i-th sample (in loader order) on
    query ``q``. Repeated queries only differ if the dataset applies random
    transforms. NOTE(review): confirm the student set's transform.

    Raises:
        ValueError: if ``n_queries`` is less than 1 (``np.stack`` needs at
            least one array).
    """
    if n_queries < 1:
        raise ValueError(f"n_queries must be >= 1, got {n_queries}")
    queries = []
    for _ in range(n_queries):
        batches = []
        for x, _labels in tqdm(loader):
            batches.append(model(x.to(device)).cpu().numpy())
        queries.append(np.concatenate(batches))
    return np.stack(queries, axis=1)


@torch.no_grad()
def run():
    """Compute and save training-set logits for every trained model in ``args.savedir``.

    Each sub-directory of ``args.savedir`` must contain a ``model.pt`` state
    dict for ``student_model.Model(num_classes=10)``. That model is queried
    ``args.n_queries`` times over the student training split, and the
    stacked outputs are written to ``<subdir>/logits.npy``.
    """
    # Prefer CUDA, then Apple MPS, and fall back to CPU instead of failing.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Dataset: only the student training split is queried here.
    json_options = json_file_to_pyobj("wresnet16-audit-cifar10.json")
    training_configurations = json_options.training
    _teacherset, studentset, _testset = get_teacherstudent_trainset(
        training_configurations.batch_size, 10, SEED
    )
    # shuffle=False is required. Each query concatenates outputs in loader
    # order, so a shuffled loader would put a different sample in row i on
    # every query and misalign the stacked logits with the dataset indices.
    train_dl = DataLoader(studentset, batch_size=128, shuffle=False, num_workers=4)

    # Infer the logits with multiple queries, one model directory at a time.
    for model_dir in sorted(Path(args.savedir).iterdir()):
        if not model_dir.is_dir():
            continue  # skip plain files; checkpoints are read from <dir>/model.pt
        m = student_model.Model(num_classes=10)
        # map_location lets checkpoints saved on a different device load here.
        m.load_state_dict(torch.load(model_dir / "model.pt", map_location=device))
        m.to(device)
        m.eval()

        logits_n = _query_logits(m, train_dl, args.n_queries, device)
        print(logits_n.shape)
        np.save(model_dir / "logits.npy", logits_n)


if __name__ == "__main__":
    run()