docstrings

This commit is contained in:
npapernot 2019-03-18 16:58:06 +00:00
parent e55a832d54
commit a209988d87

View file

@ -34,11 +34,7 @@ FLAGS = tf.flags.FLAGS
def create_dir_if_needed(dest_directory): def create_dir_if_needed(dest_directory):
""" """Create directory if doesn't exist."""
Create directory if doesn't exist
:param dest_directory:
:return: True if everything went well
"""
if not tf.gfile.IsDirectory(dest_directory): if not tf.gfile.IsDirectory(dest_directory):
tf.gfile.MakeDirs(dest_directory) tf.gfile.MakeDirs(dest_directory)
@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory):
def maybe_download(file_urls, directory): def maybe_download(file_urls, directory):
""" """Download a set of files in temporary local folder."""
Download a set of files in temporary local folder
:param directory: the directory where to download
:return: a tuple of filepaths corresponding to the files given as input
"""
# Create the directory if it doesn't exist # Create the directory if it doesn't exist
assert create_dir_if_needed(directory) assert create_dir_if_needed(directory)
@ -91,8 +84,6 @@ def image_whitening(data):
""" """
Subtracts mean of image and divides by adjusted standard variance (for Subtracts mean of image and divides by adjusted standard variance (for
stability). Operations are per image but performed for the entire array. stability). Operations are per image but performed for the entire array.
:param image: 4D array (ID, Height, Width, Channel)
:return: 4D array (ID, Height, Width, Channel)
""" """
assert len(np.shape(data)) == 4 assert len(np.shape(data)) == 4
@ -119,11 +110,7 @@ def image_whitening(data):
def extract_svhn(local_url): def extract_svhn(local_url):
""" """Extract a MATLAB matrix into two numpy arrays with data and labels"""
Extract a MATLAB matrix into two numpy arrays with data and labels
:param local_url:
:return:
"""
with tf.gfile.Open(local_url, mode='r') as file_obj: with tf.gfile.Open(local_url, mode='r') as file_obj:
# Load MATLAB matrix using scipy IO # Load MATLAB matrix using scipy IO
@ -149,11 +136,7 @@ def extract_svhn(local_url):
def unpickle_cifar_dic(file_path): def unpickle_cifar_dic(file_path):
""" """Helper function: unpickles a dictionary (used for loading CIFAR)"""
Helper function: unpickles a dictionary (used for loading CIFAR)
:param file_path: filename of the pickle
:return: tuple of (images, labels)
"""
file_obj = open(file_path, 'rb') file_obj = open(file_path, 'rb')
data_dict = pickle.load(file_obj) data_dict = pickle.load(file_obj)
file_obj.close() file_obj.close()
@ -161,12 +144,8 @@ def unpickle_cifar_dic(file_path):
def extract_cifar10(local_url, data_dir): def extract_cifar10(local_url, data_dir):
""" """Extracts CIFAR-10 and return numpy arrays with the different sets"""
Extracts the CIFAR-10 dataset and return numpy arrays with the different sets
:param local_url: where the tar.gz archive is located locally
:param data_dir: where to extract the archive's file
:return: a tuple (train data, train labels, test data, test labels)
"""
# These numpy dumps can be reloaded to avoid performing the pre-processing # These numpy dumps can be reloaded to avoid performing the pre-processing
# if they exist in the working directory. # if they exist in the working directory.
# Changing the order of this list will ruin the indices below. # Changing the order of this list will ruin the indices below.
@ -203,8 +182,8 @@ def extract_cifar10(local_url, data_dir):
# Check if all files have already been extracted # Check if all files have already been extracted
need_to_unpack = False need_to_unpack = False
for file in cifar10_files: for file_name in cifar10_files:
if not tf.gfile.Exists(file): if not tf.gfile.Exists(file_name):
need_to_unpack = True need_to_unpack = True
break break
@ -292,9 +271,10 @@ def extract_mnist_labels(filename, num_images):
def ld_svhn(extended=False, test_only=False): def ld_svhn(extended=False, test_only=False):
""" """
Load the original SVHN data Load the original SVHN data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up Args:
:return: tuple of arrays which depend on the parameters extended: include extended training data in the returned array
test_only: disables loading of both train and extra -> large speed up
""" """
# Define files to be downloaded # Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below) # WARNING: changing the order of this list will break indices (cf. below)
@ -334,12 +314,8 @@ def ld_svhn(extended=False, test_only=False):
def ld_cifar10(test_only=False): def ld_cifar10(test_only=False):
""" """Load the original CIFAR10 data"""
Load the original CIFAR10 data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
# Define files to be downloaded # Define files to be downloaded
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'] file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
@ -363,12 +339,7 @@ def ld_cifar10(test_only=False):
def ld_mnist(test_only=False): def ld_mnist(test_only=False):
""" """Load the MNIST dataset."""
Load the MNIST dataset
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
# Define files to be downloaded # Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below) # WARNING: changing the order of this list will break indices (cf. below)
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
@ -396,12 +367,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id):
""" """
Simple partitioning algorithm that returns the right portion of the data Simple partitioning algorithm that returns the right portion of the data
needed by a given teacher out of a certain nb of teachers needed by a given teacher out of a certain nb of teachers
:param data: input data to be partitioned
:param labels: output data to be partitioned Args:
:param nb_teachers: number of teachers in the ensemble (affects size of each data: input data to be partitioned
labels: output data to be partitioned
nb_teachers: number of teachers in the ensemble (affects size of each
partition) partition)
:param teacher_id: id of partition to retrieve teacher_id: id of partition to retrieve
:return:
""" """
# Sanity check # Sanity check