mia_on_model_distillation/lira-pytorch/inference.py

# PyTorch implementation of
# https://github.com/tensorflow/privacy/blob/master/research/mi_lira_2021/inference.py
#
# author: Chenxiang Zhang (orientino)
import argparse
import os
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import student_model
from utils import json_file_to_pyobj, get_loaders

parser = argparse.ArgumentParser()
parser.add_argument("--n_queries", default=2, type=int)
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument("--savedir", default="exp/cifar10", type=str)
args = parser.parse_args()


@torch.no_grad()
def run():
    # Use CUDA if available, then Apple MPS, and fall back to CPU otherwise
    DEVICE = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    # Dataset
    dataset = "cifar10"
    train_dl, test_dl = get_loaders(dataset, 4096)

    # Infer the logits with multiple queries
    for path in os.listdir(args.savedir):
        m = student_model.Model(num_classes=10)
        m.load_state_dict(torch.load(os.path.join(args.savedir, path, "model.pt")))
        m.to(DEVICE)
        m.eval()

        # Query each trained model n_queries times over the training loader,
        # collecting the raw logits from every pass
        logits_n = []
        for i in range(args.n_queries):
            logits = []
            for x, _ in tqdm(train_dl):
                x = x.to(DEVICE)
                outputs = m(x)
                logits.append(outputs.cpu().numpy())
            logits_n.append(np.concatenate(logits))
        # Shape: (n_examples, n_queries, n_classes)
        logits_n = np.stack(logits_n, axis=1)
        print(logits_n.shape)

        np.save(os.path.join(args.savedir, path, "logits.npy"), logits_n)

if __name__ == "__main__":
    run()
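

# ---------------------------------------------------------------------------
# Downstream sketch (an illustrative assumption, not part of this script): the
# per-model "logits.npy" arrays saved above are raw, pre-softmax outputs of
# shape (n_examples, n_queries, n_classes). In the LiRA pipeline they are
# typically reduced to scaled-logit confidences log(p_y / (1 - p_y)) before
# the likelihood-ratio test. The helper name, signature, and the source of the
# integer labels below are hypothetical.
def scaled_logit_scores(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Map raw logits (N, n_queries, n_classes) and integer labels (N,) to
    LiRA scaled-logit scores of shape (N, n_queries)."""
    # Numerically stable softmax over the class dimension
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    # Probability assigned to the true class, per example and per query
    p_y = np.take_along_axis(p, labels[:, None, None], axis=-1).squeeze(-1)
    eps = 1e-45
    return np.log(p_y + eps) - np.log(1.0 - p_y + eps)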