# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, List

import numpy as np
import tensorflow as tf


def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
    features = tf.io.parse_single_example(serialized_example,
                                          features={'image': tf.io.FixedLenFeature([], tf.string),
                                                    'label': tf.io.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_image(features['image']).set_shape(image_shape)
    image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=image, label=features['label'])


class DataSet:
    """Wrapper for tf.data.Dataset to permit extensions."""

    def __init__(self, data: tf.data.Dataset,
                 image_shape: Tuple[int, int, int],
                 augment_fn: Optional[Callable] = None,
                 parse_fn: Optional[Callable] = record_parse):
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
        return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
                   augment_fn=augment_fn, parse_fn=None)

    @classmethod
    def from_files(cls, filenames: List[str],
                   image_shape: Tuple[int, int, int],
                   augment_fn: Optional[Callable],
                   parse_fn: Optional[Callable] = record_parse):
        filenames_in = filenames
        filenames = sorted(sum([tf.io.gfile.glob(x) for x in filenames], []))
        if not filenames:
            raise ValueError('Empty dataset, files not found:', filenames_in)
        return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)

    @classmethod
    def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
                  augment_fn: Optional[Callable] = None):
        return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
                   image_shape, augment_fn=augment_fn, parse_fn=None)

    def __iter__(self):
        return iter(self.data)

    def __getattr__(self, item):
        if item in self.__dict__:
            return self.__dict__[item]

        def call_and_update(*args, **kwargs):
            v = getattr(self.__dict__['data'], item)(*args, **kwargs)
            if isinstance(v, tf.data.Dataset):
                return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
            return v

        return call_and_update

    def augment(self, para_augment: int = 4):
        if self.augment_fn:
            return self.map(self.augment_fn, para_augment)
        return self

    def nchw(self):
        return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))

    def one_hot(self, nclass: int):
        return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))

    def parse(self, para_parse: int = 2):
        if not self.parse_fn:
            return self
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
        return self.map(self.parse_fn, para_parse)