diff --git a/research/pate_2017/input.py b/research/pate_2017/input.py index 14c8023..86f302c 100644 --- a/research/pate_2017/input.py +++ b/research/pate_2017/input.py @@ -34,11 +34,7 @@ FLAGS = tf.flags.FLAGS def create_dir_if_needed(dest_directory): - """ - Create directory if doesn't exist - :param dest_directory: - :return: True if everything went well - """ + """Create directory if doesn't exist.""" if not tf.gfile.IsDirectory(dest_directory): tf.gfile.MakeDirs(dest_directory) @@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory): def maybe_download(file_urls, directory): - """ - Download a set of files in temporary local folder - :param directory: the directory where to download - :return: a tuple of filepaths corresponding to the files given as input - """ + """Download a set of files in temporary local folder.""" + # Create directory if doesn't exist assert create_dir_if_needed(directory) @@ -91,8 +84,6 @@ def image_whitening(data): """ Subtracts mean of image and divides by adjusted standard variance (for stability). Operations are per image but performed for the entire array. - :param image: 4D array (ID, Height, Weight, Channel) - :return: 4D array (ID, Height, Weight, Channel) """ assert len(np.shape(data)) == 4 @@ -119,11 +110,7 @@ def image_whitening(data): def extract_svhn(local_url): - """ - Extract a MATLAB matrix into two numpy arrays with data and labels - :param local_url: - :return: - """ + """Extract a MATLAB matrix into two numpy arrays with data and labels""" with tf.gfile.Open(local_url, mode='r') as file_obj: # Load MATLAB matrix using scipy IO @@ -149,11 +136,7 @@ def extract_svhn(local_url): def unpickle_cifar_dic(file_path): - """ - Helper function: unpickles a dictionary (used for loading CIFAR) - :param file_path: filename of the pickle - :return: tuple of (images, labels) - """ + """Helper function: unpickles a dictionary (used for loading CIFAR)""" file_obj = open(file_path, 'rb') data_dict = pickle.load(file_obj) file_obj.close() @@ -161,12 +144,8 @@ def unpickle_cifar_dic(file_path): def extract_cifar10(local_url, data_dir): - """ - Extracts the CIFAR-10 dataset and return numpy arrays with the different sets - :param local_url: where the tar.gz archive is located locally - :param data_dir: where to extract the archive's file - :return: a tuple (train data, train labels, test data, test labels) - """ + """Extracts CIFAR-10 and return numpy arrays with the different sets""" + # These numpy dumps can be reloaded to avoid performing the pre-processing # if they exist in the working directory. # Changing the order of this list will ruin the indices below. @@ -203,8 +182,8 @@ def extract_cifar10(local_url, data_dir): # Check if all files have already been extracted need_to_unpack = False - for file in cifar10_files: - if not tf.gfile.Exists(file): + for file_name in cifar10_files: + if not tf.gfile.Exists(file_name): need_to_unpack = True break @@ -292,9 +271,10 @@ def extract_mnist_labels(filename, num_images): def ld_svhn(extended=False, test_only=False): """ Load the original SVHN data - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters + + Args: + extended: include extended training data in the returned array + test_only: disables loading of both train and extra -> large speed up """ # Define files to be downloaded # WARNING: changing the order of this list will break indices (cf. below) @@ -334,12 +314,8 @@ def ld_svhn(extended=False, test_only=False): def ld_cifar10(test_only=False): - """ - Load the original CIFAR10 data - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters - """ + """Load the original CIFAR10 data""" + # Define files to be downloaded file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'] @@ -363,12 +339,7 @@ def ld_cifar10(test_only=False): def ld_mnist(test_only=False): - """ - Load the MNIST dataset - :param extended: include extended training data in the returned array - :param test_only: disables loading of both train and extra -> large speed up - :return: tuple of arrays which depend on the parameters - """ + """Load the MNIST dataset.""" # Define files to be downloaded # WARNING: changing the order of this list will break indices (cf. below) file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', @@ -396,12 +367,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id): """ Simple partitioning algorithm that returns the right portion of the data needed by a given teacher out of a certain nb of teachers - :param data: input data to be partitioned - :param labels: output data to be partitioned - :param nb_teachers: number of teachers in the ensemble (affects size of each + + Args: + data: input data to be partitioned + labels: output data to be partitioned + nb_teachers: number of teachers in the ensemble (affects size of each partition) - :param teacher_id: id of partition to retrieve - :return: + teacher_id: id of partition to retrieve """ # Sanity check