forked from 626_privacy/tensorflow_privacy
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
|
# Copyright 2020 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
|
||
|
|
||
|
|
||
|
def record_parse(serialized_example: str,
                 image_shape: Optional[Tuple[int, int, int]] = None):
    """Parse one serialized ``tf.train.Example`` into an image/label dict.

    The example must contain an encoded image under ``'image'`` (string
    feature) and an integer label under ``'label'`` (int64 feature).

    Args:
      serialized_example: Scalar string tensor holding one serialized example.
      image_shape: Static shape to attach to the decoded image. If ``None``,
        the shape is left as inferred by ``tf.image.decode_image``.

    Returns:
      ``dict(image=..., label=...)`` where ``image`` is float32, rescaled from
      pixel range [0, 255] to [-1, 1], and ``label`` is the int64 label.
    """
    features = tf.io.parse_single_example(
        serialized_example,
        features={'image': tf.io.FixedLenFeature([], tf.string),
                  'label': tf.io.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_image(features['image'])
    # Tensor.set_shape mutates the tensor's static shape in place and returns
    # None, so it must not be chained into the assignment above.
    if image_shape is not None:
        image.set_shape(image_shape)
    # Map pixel values [0, 255] onto [-1, 1].
    image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=image, label=features['label'])
|
||
|
|
||
|
|
||
|
class DataSet:
    """Wrapper for tf.data.Dataset to permit extensions.

    Attributes not defined on this class are forwarded to the wrapped
    ``data`` dataset. Forwarded methods whose result is a ``tf.data.Dataset``
    are re-wrapped in a new ``DataSet`` carrying the same ``image_shape``,
    ``augment_fn`` and ``parse_fn``, so ``tf.data`` calls (``map``,
    ``batch``, ...) can be chained with the extension methods defined here.
    """

    def __init__(self, data: tf.data.Dataset,
                 image_shape: Tuple[int, int, int],
                 augment_fn: Optional[Callable] = None,
                 parse_fn: Optional[Callable] = record_parse):
        """Wrap ``data``.

        Args:
          data: The underlying dataset.
          image_shape: Image shape; passed as second argument to ``parse_fn``
            by :meth:`parse` when it is truthy.
          augment_fn: Function mapped over elements by :meth:`augment`, or
            ``None`` to make :meth:`augment` a no-op.
          parse_fn: Function mapped over elements by :meth:`parse`, or
            ``None`` to make :meth:`parse` a no-op.
        """
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_arrays(cls, images: np.ndarray, labels: np.ndarray,
                    augment_fn: Optional[Callable] = None):
        """Build an already-parsed dataset from in-memory arrays.

        ``images`` and ``labels`` are sliced along their first axis; the
        remaining axes of ``images`` become ``image_shape``. No ``parse_fn``
        is set because the elements are already decoded.
        """
        return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)),
                   images.shape[1:], augment_fn=augment_fn, parse_fn=None)

    @classmethod
    def from_files(cls, filenames: Union[str, List[str]],
                   image_shape: Tuple[int, int, int],
                   augment_fn: Optional[Callable] = None,
                   parse_fn: Optional[Callable] = record_parse):
        """Build a dataset of raw TFRecord entries from glob patterns.

        Args:
          filenames: A glob pattern, or a list of glob patterns. All matches
            are read in sorted order.
          image_shape: Passed to ``parse_fn`` by :meth:`parse`.
          augment_fn: Optional function applied by :meth:`augment`.
          parse_fn: Function applied by :meth:`parse` to each raw record.

        Raises:
          ValueError: If no file matches any of the patterns.
        """
        # A bare string would otherwise be iterated character by character.
        patterns = [filenames] if isinstance(filenames, str) else list(filenames)
        matches = sorted(path for pattern in patterns for path in tf.io.gfile.glob(pattern))
        if not matches:
            raise ValueError(f'Empty dataset, files not found: {filenames}')
        return cls(tf.data.TFRecordDataset(matches), image_shape,
                   augment_fn=augment_fn, parse_fn=parse_fn)

    @classmethod
    def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
                  augment_fn: Optional[Callable] = None):
        """Wrap a dataset of ``{'image', 'label'}`` dicts, rescaling images.

        Images are cast to float32 and mapped from [0, 255] to [-1, 1]
        (assumes 8-bit pixel values -- confirm for the source dataset).
        """
        return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1,
                                              label=x['label'])),
                   image_shape, augment_fn=augment_fn, parse_fn=None)

    def __iter__(self):
        """Iterate over the wrapped dataset."""
        return iter(self.data)

    def __getattr__(self, item):
        """Forward unknown attributes to the wrapped ``tf.data.Dataset``.

        Only called when normal attribute lookup fails. Non-callable
        attributes (e.g. ``element_spec``) are returned as-is; callables are
        wrapped so that a ``tf.data.Dataset`` result is re-wrapped in this
        class with the same configuration.

        Raises:
          AttributeError: For dunder names, when ``data`` is not set yet
            (e.g. during copy/unpickling), or when the wrapped dataset has no
            such attribute.
        """
        # Protocol probes (copy, pickle, numpy, ...) must see a plain
        # AttributeError rather than a forwarding closure.
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)
        try:
            data = self.__dict__['data']
        except KeyError:
            raise AttributeError(item) from None
        attr = getattr(data, item)
        if not callable(attr):
            return attr

        def call_and_update(*args, **kwargs):
            v = attr(*args, **kwargs)
            if isinstance(v, tf.data.Dataset):
                return self.__class__(v, self.image_shape, augment_fn=self.augment_fn,
                                      parse_fn=self.parse_fn)
            return v

        return call_and_update

    def augment(self, para_augment: int = 4):
        """Map ``augment_fn`` over the dataset, or return ``self`` if unset.

        Args:
          para_augment: Passed to ``map`` as its ``num_parallel_calls``.
        """
        if self.augment_fn:
            return self.map(self.augment_fn, para_augment)
        return self

    def nchw(self):
        """Transpose ``image`` from NHWC to NCHW.

        The permutation ``[0, 3, 1, 2]`` requires rank-4 images, so this is
        meant to be applied after batching.
        """
        return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))

    def one_hot(self, nclass: int):
        """Replace integer ``label`` with a one-hot vector of length ``nclass``."""
        return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))

    def parse(self, para_parse: int = 2):
        """Map ``parse_fn`` over the dataset, or return ``self`` if unset.

        When ``image_shape`` is truthy it is passed as the second argument to
        ``parse_fn``; otherwise ``parse_fn`` is called with the element only.

        Args:
          para_parse: Passed to ``map`` as its ``num_parallel_calls``.
        """
        if not self.parse_fn:
            return self
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
        return self.map(self.parse_fn, para_parse)
|