mia_on_model_distillation/lira-pytorch/inference.py
2024-12-02 17:58:01 -07:00

58 lines
1.6 KiB
Python

# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)
import argparse
import os
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm
import student_model
from utils import json_file_to_pyobj, get_loaders
# Command-line interface. Parsing happens at import time and the result is
# kept in the module-level ``args`` that ``run()`` reads.
parser = argparse.ArgumentParser()
# How many times each model is evaluated over the training loader.
parser.add_argument("--n_queries", type=int, default=2)
# Architecture name; accepted for CLI compatibility, not read by the code
# visible in this file.
parser.add_argument("--model", type=str, default="resnet18")
# Directory whose entries each hold a ``model.pt``; ``logits.npy`` is written
# next to it.
parser.add_argument("--savedir", type=str, default="exp/cifar10")
args = parser.parse_args()
@torch.no_grad()
def run():
    """Evaluate every saved model on the training loader and store its logits.

    For each entry ``<name>`` in ``args.savedir`` that contains a ``model.pt``
    checkpoint, the weights are loaded into ``student_model.Model(num_classes=10)``
    and the model is evaluated ``args.n_queries`` times over the training
    loader returned by ``get_loaders("cifar10", 4096)``. The outputs of each
    pass are concatenated over batches, the passes are stacked along axis 1,
    and the result is saved to ``<savedir>/<name>/logits.npy``.

    Entries without a ``model.pt`` (stray files, unfinished runs) are reported
    and skipped instead of aborting the whole job. Gradient tracking is
    disabled for the entire function by the ``torch.no_grad`` decorator.
    """
    # Prefer CUDA, then Apple MPS, and finally CPU. Falling back to "mps"
    # unconditionally fails on hosts that have neither accelerator.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    dataset = "cifar10"

    # Dataset: only the training loader is queried here.
    train_dl, _ = get_loaders(dataset, 4096)

    # Infer the logits with multiple queries. Sorting only makes the
    # processing order deterministic; each model is handled independently.
    for name in sorted(os.listdir(args.savedir)):
        model_dir = os.path.join(args.savedir, name)
        checkpoint = os.path.join(model_dir, "model.pt")
        if not os.path.isfile(checkpoint):
            print(f"skipping {model_dir}: no model.pt found")
            continue

        m = student_model.Model(num_classes=10)
        # Load onto CPU first so checkpoints saved on another device type
        # (e.g. CUDA) can still be restored, then move to the chosen device.
        m.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        m.to(device)
        m.eval()

        logits_n = []
        for _ in range(args.n_queries):
            logits = []
            for x, _ in tqdm(train_dl):
                outputs = m(x.to(device))
                logits.append(outputs.cpu().numpy())
            # One array per query pass, covering the full loader.
            logits_n.append(np.concatenate(logits))

        # Stack the passes on a new axis 1: (samples, n_queries, ...).
        logits_n = np.stack(logits_n, axis=1)
        print(logits_n.shape)

        np.save(os.path.join(model_dir, "logits.npy"), logits_n)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    run()