# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pyformat: disable

import json
import os
import re

import numpy as np
import objax
import tensorflow as tf  # For data augmentation.
from absl import app
from absl import flags

from train import MemModule
from train import network

FLAGS = flags.FLAGS


def main(argv):
    """
    Perform inference of the saved model in order to generate the
    output logits, using a particular set of augmentations.
    """
    del argv
    tf.config.experimental.set_visible_devices([], "GPU")

    def load(arch):
        return MemModule(network(arch), nclass=100 if FLAGS.dataset == 'cifar100' else 10,
                         mnist=FLAGS.dataset == 'mnist',
                         arch=arch,
                         lr=.1,
                         batch=0,
                         epochs=0,
                         weight_decay=0)

    def cache_load(arch):
        thing = []
        def fn():
            if len(thing) == 0:
                thing.append(load(arch))
            return thing[0]
        return fn

    xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
    ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]


    def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):

        outs = []
        for aug in [xbatch, xbatch[:,:,::-1,:]][:reflect+1]:
            aug_pad = tf.pad(aug, [[0] * 2, [shift] * 2, [shift] * 2, [0] * 2], mode='REFLECT').numpy()
            for dx in range(0, 2*shift+1, stride):
                for dy in range(0, 2*shift+1, stride):
                    this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))

                    logits = model.model(this_x, training=True)
                    outs.append(logits)

        print(np.array(outs).shape)
        return np.array(outs).transpose((1, 0, 2))

    N = 5000

    def features(model, xbatch, ybatch):
        return get_loss(model, xbatch, ybatch,
                        shift=0, reflect=True, stride=1)

    for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
        if re.search(FLAGS.regex, path) is None:
            print("Skipping from regex")
            continue

        hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
        arch = hparams['arch']
        model = cache_load(arch)()

        logdir = os.path.join(FLAGS.logdir, path)

        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
        max_epoch, last_ckpt = checkpoint.restore(model.vars())
        if max_epoch == 0: continue

        if not os.path.exists(os.path.join(FLAGS.logdir, path, "logits")):
            os.mkdir(os.path.join(FLAGS.logdir, path, "logits"))
        if FLAGS.from_epoch is not None:
            first = FLAGS.from_epoch
        else:
            first = max_epoch-1

        for epoch in range(first,max_epoch+1):
            if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
                # no checkpoint saved here
                continue

            if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
                print("Skipping already generated file", epoch)
                continue

            try:
                start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
            except:
                print("Fail to load", epoch)
                continue

            stats = []

            for i in range(0,len(xs_all),N):
                stats.extend(features(model, xs_all[i:i+N],
                                      ys_all[i:i+N]))
            # This will be shape N, augs, nclass

            np.save(os.path.join(FLAGS.logdir, path, "logits", "%010d"%epoch),
                    np.array(stats)[:,None,:,:])

if __name__ == '__main__':
    flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
    flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
    flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
    flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
    flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
    app.run(main)