tensorflow_privacy/research/mi_lira_2021/dataset.py

96 lines
3.9 KiB
Python
Raw Normal View History

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, List
import numpy as np
import tensorflow as tf
def record_parse(serialized_example: str, image_shape: Optional[Tuple[int, int, int]] = None):
    """Parse one serialized example into a dict with a normalized image and its label.

    Args:
        serialized_example: serialized record holding an encoded ``image``
            (bytes) feature and an integer ``label`` feature.
        image_shape: static shape to attach to the decoded image. When None,
            the decoded tensor keeps whatever static shape the decoder inferred.

    Returns:
        dict with keys ``image`` (float32, pixel values mapped through
        ``x * 2 / 255 - 1``) and ``label`` (int64).
    """
    features = tf.io.parse_single_example(
        serialized_example,
        features={'image': tf.io.FixedLenFeature([], tf.string),
                  'label': tf.io.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_image(features['image'])
    if image_shape is not None:
        # set_shape() refines the static shape in place and returns None, so it
        # must be called as a statement rather than chained into the assignment.
        image.set_shape(image_shape)
    image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=image, label=features['label'])
class DataSet:
    """Wrapper for tf.data.Dataset to permit extensions.

    Any attribute not defined on this class is forwarded to the wrapped
    ``tf.data.Dataset``. Forwarded method calls that return a
    ``tf.data.Dataset`` are re-wrapped in this class, so chained calls such as
    ``ds.parse().map(...).one_hot(10)`` keep the extension methods available.

    Attributes are documented here rather than inline:
        data: the wrapped ``tf.data.Dataset``.
        image_shape: shape passed to ``parse_fn`` by ``parse()``.
        augment_fn: optional per-element function applied by ``augment()``.
        parse_fn: optional per-element function applied by ``parse()``.
    """

    def __init__(self, data: tf.data.Dataset,
                 image_shape: Tuple[int, int, int],
                 augment_fn: Optional[Callable] = None,
                 parse_fn: Optional[Callable] = record_parse):
        """Store the wrapped dataset together with its parse/augment configuration."""
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
        """Build a dataset from in-memory arrays.

        The image shape is taken from ``images.shape[1:]`` and no parse
        function is attached, since the elements are already decoded.
        """
        return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
                   augment_fn=augment_fn, parse_fn=None)

    @classmethod
    def from_files(cls, filenames: List[str],
                   image_shape: Tuple[int, int, int],
                   augment_fn: Optional[Callable] = None,
                   parse_fn: Optional[Callable] = record_parse):
        """Build a dataset from TFRecord files matching the given glob patterns.

        Args:
            filenames: glob patterns (a single pattern string is also accepted).
            image_shape: shape handed to ``parse_fn`` when ``parse()`` is called.
            augment_fn: optional function applied by ``augment()``.
            parse_fn: function applied by ``parse()`` to each serialized record.

        Raises:
            ValueError: if no file matches any of the patterns.
        """
        if isinstance(filenames, str):
            # A bare string would otherwise be iterated character by character.
            filenames = [filenames]
        filenames_in = filenames
        # Flatten the per-pattern matches without the quadratic sum(lists, []).
        filenames = sorted(match for pattern in filenames_in for match in tf.io.gfile.glob(pattern))
        if not filenames:
            raise ValueError(f'Empty dataset, files not found: {filenames_in}')
        return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)

    @classmethod
    def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
                  augment_fn: Optional[Callable] = None):
        """Wrap a dataset whose elements are dicts with ``image`` and ``label`` keys.

        Images are cast to float32 and mapped through ``x / 127.5 - 1``; no
        parse function is attached.
        """
        return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
                   image_shape, augment_fn=augment_fn, parse_fn=None)

    def __iter__(self):
        """Iterate over the elements of the wrapped dataset."""
        return iter(self.data)

    def __getattr__(self, item):
        """Forward unknown attributes to the wrapped dataset.

        Callable attributes are returned as a proxy that re-wraps any
        ``tf.data.Dataset`` result in this class; non-callable attributes are
        returned as-is.

        Raises:
            AttributeError: if the wrapped dataset is not set yet or does not
                have the requested attribute.
        """
        # __getattr__ only runs after normal lookup failed, so instance
        # attributes never reach this point. Go through __dict__ for 'data' to
        # avoid recursing when the instance is not initialised yet (for example
        # while copy/pickle probe a freshly created object for protocol hooks).
        try:
            data = self.__dict__['data']
        except KeyError:
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from None

        # Resolve eagerly so missing names raise AttributeError at access time
        # (keeping hasattr() truthful) instead of when the proxy is called.
        attr = getattr(data, item)
        if not callable(attr):
            return attr

        def call_and_update(*args, **kwargs):
            v = attr(*args, **kwargs)
            if isinstance(v, tf.data.Dataset):
                return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
            return v

        return call_and_update

    def augment(self, para_augment: int = 4):
        """Map ``augment_fn`` over the dataset, or return self unchanged if none is set.

        ``para_augment`` is passed as the second positional argument of ``map``.
        """
        if self.augment_fn:
            return self.map(self.augment_fn, para_augment)
        return self

    def nchw(self):
        """Transpose each ``image`` with permutation ``[0, 3, 1, 2]``.

        NOTE(review): the 4-axis permutation presumably expects batched NHWC
        images, i.e. it should be applied after batching; confirm with callers.
        """
        return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))

    def one_hot(self, nclass: int):
        """Replace each ``label`` with its one-hot encoding of depth ``nclass``."""
        return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))

    def parse(self, para_parse: int = 2):
        """Map ``parse_fn`` over the dataset, or return self unchanged if none is set.

        When ``image_shape`` is truthy it is passed to ``parse_fn`` as a second
        argument; otherwise ``parse_fn`` is called with the element only.
        ``para_parse`` is passed as the second positional argument of ``map``.
        """
        if not self.parse_fn:
            return self
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
        return self.map(self.parse_fn, para_parse)