151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
|
# Copyright 2021 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import functools
|
||
|
import os
|
||
|
from typing import Callable
|
||
|
import json
|
||
|
|
||
|
import re
|
||
|
import jax
|
||
|
import jax.numpy as jn
|
||
|
import numpy as np
|
||
|
import tensorflow as tf # For data augmentation.
|
||
|
import tensorflow_datasets as tfds
|
||
|
from absl import app, flags
|
||
|
from tqdm import tqdm, trange
|
||
|
import pickle
|
||
|
from functools import partial
|
||
|
|
||
|
import objax
|
||
|
from objax.jaxboard import SummaryWriter, Summary
|
||
|
from objax.util import EasyDict
|
||
|
from objax.zoo import convnet, wide_resnet
|
||
|
|
||
|
from dataset import DataSet
|
||
|
|
||
|
from train import MemModule, network
|
||
|
|
||
|
from collections import defaultdict
|
||
|
# Handle to absl's global flag container. The flags read through it in this
# file are registered in the `if __name__ == '__main__':` block at the bottom.
FLAGS = flags.FLAGS
|
||
|
|
||
|
|
||
|
def main(argv):
    """
    Perform inference of the saved model in order to generate the
    output logits, using a particular set of augmentations.

    For every entry of FLAGS.logdir whose name matches FLAGS.regex, the
    checkpoints of the selected epochs are restored one by one and the model
    outputs for the first FLAGS.dataset_size rows of x_train.npy are saved to
    <logdir>/<entry>/logits/<epoch as 10 zero-padded digits>.npy.
    Epochs whose output file already exists are skipped.

    Args:
        argv: Positional command-line arguments handed over by absl; unused.
    """
    del argv
    # Hide all GPUs from TensorFlow; in this function TF is only used for
    # padding the input batches.
    tf.config.experimental.set_visible_devices([], "GPU")

    def load(arch):
        """Construct a MemModule for architecture `arch`.

        The weights are filled in afterwards from a checkpoint. The
        lr/batch/epochs/weight_decay values look like placeholders because no
        training happens here -- NOTE(review): confirm against MemModule.
        """
        return MemModule(network(arch),
                         nclass=100 if FLAGS.dataset == 'cifar100' else 10,
                         mnist=FLAGS.dataset == 'mnist',
                         arch=arch,
                         lr=.1,
                         batch=0,
                         epochs=0,
                         weight_decay=0)

    # One model object per architecture, shared by all experiment directories.
    # Previously a new cache was created on every call, so nothing was ever
    # reused. Sharing relies on checkpoint.restore() overwriting the model
    # variables before each inference pass (see the epoch loop below).
    # Assumes `arch` is hashable (a string in hparams.json) -- TODO confirm.
    models = {}

    def cache_load(arch):
        """Return the model for `arch`, building it only on first request."""
        if arch not in models:
            models[arch] = load(arch)
        return models[arch]

    xs_all = np.load(os.path.join(FLAGS.logdir, "x_train.npy"))[:FLAGS.dataset_size]
    ys_all = np.load(os.path.join(FLAGS.logdir, "y_train.npy"))[:FLAGS.dataset_size]

    def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
        """Run the model on every augmented view of one batch.

        Despite the name, this returns raw model outputs, not a loss.

        Args:
            model: Object exposing `model.model(x, training=...)`.
            xbatch: 4-D input batch. Axis 2 is flipped for the mirrored view
                and axes are reordered with (0, 3, 1, 2) before the forward
                pass -- assumes NHWC layout, TODO confirm.
            ybatch: Labels for the batch; accepted but not used.
            shift: Number of pixels of reflect-padding added on each side of
                axes 1 and 2; windows are taken at offsets 0..2*shift.
            reflect: If true, the mirrored batch is evaluated as well.
            stride: Step between successive window offsets.

        Returns:
            Array of model outputs with the view axis moved to position 1
            (batch, views, outputs).
        """
        # Window size follows the input instead of a hard-coded 32, so padded
        # windows always have the same spatial size as the original images.
        height, width = xbatch.shape[1:3]

        # `reflect` is used as 0/1 to keep either one or both views.
        views = [xbatch, xbatch[:, :, ::-1, :]][:reflect + 1]

        outs = []
        for aug in views:
            aug_pad = tf.pad(aug,
                             [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2],
                             mode='REFLECT').numpy()
            for dx in range(0, 2 * shift + 1, stride):
                for dy in range(0, 2 * shift + 1, stride):
                    this_x = aug_pad[:, dx:dx + height, dy:dy + width, :].transpose((0, 3, 1, 2))

                    # NOTE(review): the forward pass runs with training=True
                    # even though this is inference; kept unchanged because
                    # switching it would alter the produced logits.
                    logits = model.model(this_x, training=True)
                    outs.append(logits)

        outs = np.array(outs)
        print(outs.shape)
        return outs.transpose((1, 0, 2))

    # Number of examples sent through the model per call.
    N = 5000

    def features(model, xbatch, ybatch):
        """Model outputs for the augmentation set used here: no shift, with mirroring."""
        return get_loss(model, xbatch, ybatch,
                        shift=0, reflect=True, stride=1)

    for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
        if re.search(FLAGS.regex, path) is None:
            print("Skipping from regex")
            continue

        logdir = os.path.join(FLAGS.logdir, path)

        # Context manager so the file handle is closed deterministically.
        with open(os.path.join(logdir, "hparams.json")) as hparams_file:
            hparams = json.load(hparams_file)
        arch = hparams['arch']
        model = cache_load(arch)

        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
        # Restoring without an index yields the newest available epoch.
        max_epoch, _ = checkpoint.restore(model.vars())
        if max_epoch == 0:
            continue

        logits_dir = os.path.join(logdir, "logits")
        os.makedirs(logits_dir, exist_ok=True)

        if FLAGS.from_epoch is not None:
            first = FLAGS.from_epoch
        else:
            first = max_epoch - 1

        for epoch in range(first, max_epoch + 1):
            if not os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % epoch)):
                # no checkpoint saved here
                continue

            if os.path.exists(os.path.join(logits_dir, "%010d.npy" % epoch)):
                print("Skipping already generated file", epoch)
                continue

            try:
                checkpoint.restore(model.vars(), epoch)
            except Exception as err:
                # Best effort: report and move on to the next epoch, but let
                # KeyboardInterrupt / SystemExit propagate.
                print("Fail to load", epoch, repr(err))
                continue

            # Each chunk has shape (N, augs, nclass).
            stats = [features(model, xs_all[i:i + N], ys_all[i:i + N])
                     for i in range(0, len(xs_all), N)]

            # np.save appends the ".npy" suffix; a singleton axis is inserted
            # at position 1 of the saved array.
            np.save(os.path.join(logits_dir, "%010d" % epoch),
                    np.concatenate(stats)[:, None, :, :])
|
||
|
|
||
|
if __name__ == '__main__':
    # Register the command-line flags, then let absl parse argv and call main().

    # Flags read by main() in this file.
    flags.DEFINE_string(
        name='dataset',
        default='cifar10',
        help='Dataset.',
    )
    flags.DEFINE_string(
        name='logdir',
        default='experiments/',
        help='Directory where to save checkpoints and tensorboard data.',
    )
    flags.DEFINE_string(
        name='regex',
        default='.*experiment.*',
        help='keep files when matching',
    )

    # Not read by main() in this file; presumably consumed by imported
    # modules or kept for command-line compatibility -- TODO confirm.
    flags.DEFINE_bool(
        name='random_labels',
        default=False,
        help='use random labels.',
    )

    flags.DEFINE_integer(
        name='dataset_size',
        default=50000,
        help='size of dataset.',
    )
    flags.DEFINE_integer(
        name='from_epoch',
        default=None,
        help='which epoch to load from.',
    )

    # Not read by main() in this file -- see the note on random_labels.
    flags.DEFINE_integer(
        name='seed_mod',
        default=None,
        help='keep mod seed.',
    )
    flags.DEFINE_integer(
        name='modulus',
        default=8,
        help='modulus.',
    )

    app.run(main)
|