diff --git a/research/pate_2017/input.py b/research/pate_2017/input.py index 0a1d89f..d838806 100644 --- a/research/pate_2017/input.py +++ b/research/pate_2017/input.py @@ -17,28 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import cPickle import gzip import math -import numpy as np import os -from scipy.io import loadmat as loadmat -from six.moves import urllib -from six.moves import xrange import sys import tarfile +import numpy as np +from scipy.io import loadmat as loadmat +from six.moves import cPickle as pickle +from six.moves import urllib +from six.moves import xrange import tensorflow as tf FLAGS = tf.flags.FLAGS def create_dir_if_needed(dest_directory): - """ - Create directory if doesn't exist - :param dest_directory: - :return: True if everything went well - """ + """Create directory if doesn't exist.""" if not tf.gfile.IsDirectory(dest_directory): tf.gfile.MakeDirs(dest_directory) @@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory): def maybe_download(file_urls, directory): - """ - Download a set of files in temporary local folder - :param directory: the directory where to download - :return: a tuple of filepaths corresponding to the files given as input - """ + """Download a set of files in temporary local folder.""" + # Create directory if doesn't exist assert create_dir_if_needed(directory) @@ -91,8 +84,6 @@ def image_whitening(data): """ Subtracts mean of image and divides by adjusted standard variance (for stability). Operations are per image but performed for the entire array. - :param image: 4D array (ID, Height, Weight, Channel) - :return: 4D array (ID, Height, Weight, Channel) """ assert len(np.shape(data)) == 4 @@ -100,14 +91,14 @@ def image_whitening(data): nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3] # Subtract mean - mean = np.mean(data, axis=(1,2,3)) + mean = np.mean(data, axis=(1, 2, 3)) ones = np.ones(np.shape(data)[1:4], dtype=np.float32) for i in xrange(len(data)): data[i, :, :, :] -= mean[i] * ones # Compute adjusted standard variance - adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line) + adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long # Divide image for i in xrange(len(data)): @@ -119,18 +110,14 @@ def image_whitening(data): def extract_svhn(local_url): - """ - Extract a MATLAB matrix into two numpy arrays with data and labels - :param local_url: - :return: - """ + """Extract a MATLAB matrix into two numpy arrays with data and labels.""" with tf.gfile.Open(local_url, mode='r') as file_obj: # Load MATLAB matrix using scipy IO - dict = loadmat(file_obj) + data_dict = loadmat(file_obj) # Extract each dictionary (one for data, one for labels) - data, labels = dict["X"], dict["y"] + data, labels = data_dict['X'], data_dict['y'] # Set np type data = np.asarray(data, dtype=np.float32) @@ -148,25 +135,17 @@ def extract_svhn(local_url): return data, labels -def unpickle_cifar_dic(file): - """ - Helper function: unpickles a dictionary (used for loading CIFAR) - :param file: filename of the pickle - :return: tuple of (images, labels) - """ - fo = open(file, 'rb') - dict = cPickle.load(fo) - fo.close() - return dict['data'], dict['labels'] +def unpickle_cifar_dic(file_path): + """Helper function: unpickles a dictionary (used for loading CIFAR).""" + file_obj = open(file_path, 'rb') + data_dict = pickle.load(file_obj) + file_obj.close() + return data_dict['data'], data_dict['labels'] def extract_cifar10(local_url, data_dir): - """ - Extracts the CIFAR-10 dataset and return numpy arrays with the different sets - :param local_url: where the tar.gz archive is located locally - :param data_dir: where to extract the archive's file - :return: a tuple (train data, train labels, test data, test labels) - """ + """Extracts CIFAR-10 and return numpy arrays with the different sets.""" + # These numpy dumps can be reloaded to avoid performing the pre-processing # if they exist in the working directory. # Changing the order of this list will ruin the indices below. @@ -176,8 +155,8 @@ def extract_cifar10(local_url, data_dir): '/cifar10_test_labels.npy'] all_preprocessed = True - for file in preprocessed_files: - if not tf.gfile.Exists(data_dir + file): + for file_name in preprocessed_files: + if not tf.gfile.Exists(data_dir + file_name): all_preprocessed = False break @@ -197,14 +176,14 @@ def extract_cifar10(local_url, data_dir): else: # Do everything from scratch # Define lists of all files we should extract - train_files = ["data_batch_" + str(i) for i in xrange(1,6)] - test_file = ["test_batch"] + train_files = ['data_batch_' + str(i) for i in xrange(1, 6)] + test_file = ['test_batch'] cifar10_files = train_files + test_file # Check if all files have already been extracted need_to_unpack = False - for file in cifar10_files: - if not tf.gfile.Exists(file): + for file_name in cifar10_files: + if not tf.gfile.Exists(file_name): need_to_unpack = True break @@ -215,9 +194,9 @@ def extract_cifar10(local_url, data_dir): # Load training images and labels images = [] labels = [] - for file in train_files: + for train_file in train_files: # Construct filename - filename = data_dir + "/cifar-10-batches-py/" + file + filename = data_dir + '/cifar-10-batches-py/' + train_file # Unpickle dictionary and extract images and labels images_tmp, labels_tmp = unpickle_cifar_dic(filename) @@ -227,7 +206,8 @@ def extract_cifar10(local_url, data_dir): labels.append(labels_tmp) # Convert to numpy arrays and reshape in the expected format - train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32)) + train_data = np.asarray(images, dtype=np.float32) + train_data = train_data.reshape((50000, 3, 32, 32)) train_data = np.swapaxes(train_data, 1, 3) train_labels = np.asarray(labels, dtype=np.int32).reshape(50000) @@ -236,13 +216,14 @@ def extract_cifar10(local_url, data_dir): np.save(data_dir + preprocessed_files[1], train_labels) # Construct filename for test file - filename = data_dir + "/cifar-10-batches-py/" + test_file[0] + filename = data_dir + '/cifar-10-batches-py/' + test_file[0] # Load test images and labels test_data, test_images = unpickle_cifar_dic(filename) # Convert to numpy arrays and reshape in the expected format - test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32)) + test_data = np.asarray(test_data, dtype=np.float32) + test_data = test_data.reshape((10000, 3, 32, 32)) test_data = np.swapaxes(test_data, 1, 3) test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000) @@ -259,8 +240,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth): Values are rescaled from [0, 255] down to [-0.5, 0.5]. """ - # if not os.path.exists(file): - if not tf.gfile.Exists(filename+".npy"): + if not tf.gfile.Exists(filename+'.npy'): with gzip.open(filename) as bytestream: bytestream.read(16) buf = bytestream.read(image_size * image_size * num_images) @@ -270,7 +250,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth): np.save(filename, data) return data else: - with tf.gfile.Open(filename+".npy", mode='r') as file_obj: + with tf.gfile.Open(filename+'.npy', mode='r') as file_obj: return np.load(file_obj) @@ -278,8 +258,7 @@ def extract_mnist_labels(filename, num_images): """ Extract the labels into a vector of int64 label IDs. """ - # if not os.path.exists(file): - if not tf.gfile.Exists(filename+".npy"): + if not tf.gfile.Exists(filename+'.npy'): with gzip.open(filename) as bytestream: bytestream.read(8) buf = bytestream.read(1 * num_images) @@ -287,16 +266,17 @@ def extract_mnist_labels(filename, num_images): np.save(filename, labels) return labels else: - with tf.gfile.Open(filename+".npy", mode='r') as file_obj: + with tf.gfile.Open(filename+'.npy', mode='r') as file_obj: return np.load(file_obj) def ld_svhn(extended=False, test_only=False): """ Load the original SVHN data - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters + + Args: + extended: include extended training data in the returned array + test_only: disables loading of both train and extra -> large speed up """ # Define files to be downloaded # WARNING: changing the order of this list will break indices (cf. below) @@ -332,16 +312,12 @@ def ld_svhn(extended=False, test_only=False): return train_data, train_labels, test_data, test_labels else: # Return training and extended training data separately - return train_data,train_labels, test_data,test_labels, ext_data,ext_labels + return train_data, train_labels, test_data, test_labels, ext_data, ext_labels def ld_cifar10(test_only=False): - """ - Load the original CIFAR10 data - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters - """ + """Load the original CIFAR10 data.""" + # Define files to be downloaded file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'] @@ -365,19 +341,14 @@ def ld_cifar10(test_only=False): def ld_mnist(test_only=False): - """ - Load the MNIST dataset - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters - """ + """Load the MNIST dataset.""" # Define files to be downloaded # WARNING: changing the order of this list will break indices (cf. below) file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', - ] + ] # Maybe download data and retrieve local storage urls local_urls = maybe_download(file_urls, FLAGS.data_dir) @@ -398,12 +369,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id): """ Simple partitioning algorithm that returns the right portion of the data needed by a given teacher out of a certain nb of teachers - :param data: input data to be partitioned - :param labels: output data to be partitioned - :param nb_teachers: number of teachers in the ensemble (affects size of each + + Args: + data: input data to be partitioned + labels: output data to be partitioned + nb_teachers: number of teachers in the ensemble (affects size of each partition) - :param teacher_id: id of partition to retrieve - :return: + teacher_id: id of partition to retrieve """ # Sanity check