forked from 626_privacy/tensorflow_privacy
docstrings
This commit is contained in:
parent
e55a832d54
commit
a209988d87
1 changed files with 22 additions and 50 deletions
|
@ -34,11 +34,7 @@ FLAGS = tf.flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
def create_dir_if_needed(dest_directory):
|
def create_dir_if_needed(dest_directory):
|
||||||
"""
|
"""Create directory if doesn't exist."""
|
||||||
Create directory if doesn't exist
|
|
||||||
:param dest_directory:
|
|
||||||
:return: True if everything went well
|
|
||||||
"""
|
|
||||||
if not tf.gfile.IsDirectory(dest_directory):
|
if not tf.gfile.IsDirectory(dest_directory):
|
||||||
tf.gfile.MakeDirs(dest_directory)
|
tf.gfile.MakeDirs(dest_directory)
|
||||||
|
|
||||||
|
@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory):
|
||||||
|
|
||||||
|
|
||||||
def maybe_download(file_urls, directory):
|
def maybe_download(file_urls, directory):
|
||||||
"""
|
"""Download a set of files in temporary local folder."""
|
||||||
Download a set of files in temporary local folder
|
|
||||||
:param directory: the directory where to download
|
|
||||||
:return: a tuple of filepaths corresponding to the files given as input
|
|
||||||
"""
|
|
||||||
# Create directory if doesn't exist
|
# Create directory if doesn't exist
|
||||||
assert create_dir_if_needed(directory)
|
assert create_dir_if_needed(directory)
|
||||||
|
|
||||||
|
@ -91,8 +84,6 @@ def image_whitening(data):
|
||||||
"""
|
"""
|
||||||
Subtracts mean of image and divides by adjusted standard variance (for
|
Subtracts mean of image and divides by adjusted standard variance (for
|
||||||
stability). Operations are per image but performed for the entire array.
|
stability). Operations are per image but performed for the entire array.
|
||||||
:param image: 4D array (ID, Height, Weight, Channel)
|
|
||||||
:return: 4D array (ID, Height, Weight, Channel)
|
|
||||||
"""
|
"""
|
||||||
assert len(np.shape(data)) == 4
|
assert len(np.shape(data)) == 4
|
||||||
|
|
||||||
|
@ -119,11 +110,7 @@ def image_whitening(data):
|
||||||
|
|
||||||
|
|
||||||
def extract_svhn(local_url):
|
def extract_svhn(local_url):
|
||||||
"""
|
"""Extract a MATLAB matrix into two numpy arrays with data and labels"""
|
||||||
Extract a MATLAB matrix into two numpy arrays with data and labels
|
|
||||||
:param local_url:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
||||||
# Load MATLAB matrix using scipy IO
|
# Load MATLAB matrix using scipy IO
|
||||||
|
@ -149,11 +136,7 @@ def extract_svhn(local_url):
|
||||||
|
|
||||||
|
|
||||||
def unpickle_cifar_dic(file_path):
|
def unpickle_cifar_dic(file_path):
|
||||||
"""
|
"""Helper function: unpickles a dictionary (used for loading CIFAR)"""
|
||||||
Helper function: unpickles a dictionary (used for loading CIFAR)
|
|
||||||
:param file_path: filename of the pickle
|
|
||||||
:return: tuple of (images, labels)
|
|
||||||
"""
|
|
||||||
file_obj = open(file_path, 'rb')
|
file_obj = open(file_path, 'rb')
|
||||||
data_dict = pickle.load(file_obj)
|
data_dict = pickle.load(file_obj)
|
||||||
file_obj.close()
|
file_obj.close()
|
||||||
|
@ -161,12 +144,8 @@ def unpickle_cifar_dic(file_path):
|
||||||
|
|
||||||
|
|
||||||
def extract_cifar10(local_url, data_dir):
|
def extract_cifar10(local_url, data_dir):
|
||||||
"""
|
"""Extracts CIFAR-10 and return numpy arrays with the different sets"""
|
||||||
Extracts the CIFAR-10 dataset and return numpy arrays with the different sets
|
|
||||||
:param local_url: where the tar.gz archive is located locally
|
|
||||||
:param data_dir: where to extract the archive's file
|
|
||||||
:return: a tuple (train data, train labels, test data, test labels)
|
|
||||||
"""
|
|
||||||
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
||||||
# if they exist in the working directory.
|
# if they exist in the working directory.
|
||||||
# Changing the order of this list will ruin the indices below.
|
# Changing the order of this list will ruin the indices below.
|
||||||
|
@ -203,8 +182,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
|
|
||||||
# Check if all files have already been extracted
|
# Check if all files have already been extracted
|
||||||
need_to_unpack = False
|
need_to_unpack = False
|
||||||
for file in cifar10_files:
|
for file_name in cifar10_files:
|
||||||
if not tf.gfile.Exists(file):
|
if not tf.gfile.Exists(file_name):
|
||||||
need_to_unpack = True
|
need_to_unpack = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -292,9 +271,10 @@ def extract_mnist_labels(filename, num_images):
|
||||||
def ld_svhn(extended=False, test_only=False):
|
def ld_svhn(extended=False, test_only=False):
|
||||||
"""
|
"""
|
||||||
Load the original SVHN data
|
Load the original SVHN data
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
Args:
|
||||||
:return: tuple of arrays which depend on the parameters
|
extended: include extended training data in the returned array
|
||||||
|
test_only: disables loading of both train and extra -> large speed up
|
||||||
"""
|
"""
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
# WARNING: changing the order of this list will break indices (cf. below)
|
# WARNING: changing the order of this list will break indices (cf. below)
|
||||||
|
@ -334,12 +314,8 @@ def ld_svhn(extended=False, test_only=False):
|
||||||
|
|
||||||
|
|
||||||
def ld_cifar10(test_only=False):
|
def ld_cifar10(test_only=False):
|
||||||
"""
|
"""Load the original CIFAR10 data"""
|
||||||
Load the original CIFAR10 data
|
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
|
||||||
:return: tuple of arrays which depend on the parameters
|
|
||||||
"""
|
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
||||||
|
|
||||||
|
@ -363,12 +339,7 @@ def ld_cifar10(test_only=False):
|
||||||
|
|
||||||
|
|
||||||
def ld_mnist(test_only=False):
|
def ld_mnist(test_only=False):
|
||||||
"""
|
"""Load the MNIST dataset."""
|
||||||
Load the MNIST dataset
|
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
|
||||||
:return: tuple of arrays which depend on the parameters
|
|
||||||
"""
|
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
# WARNING: changing the order of this list will break indices (cf. below)
|
# WARNING: changing the order of this list will break indices (cf. below)
|
||||||
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
||||||
|
@ -396,12 +367,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id):
|
||||||
"""
|
"""
|
||||||
Simple partitioning algorithm that returns the right portion of the data
|
Simple partitioning algorithm that returns the right portion of the data
|
||||||
needed by a given teacher out of a certain nb of teachers
|
needed by a given teacher out of a certain nb of teachers
|
||||||
:param data: input data to be partitioned
|
|
||||||
:param labels: output data to be partitioned
|
Args:
|
||||||
:param nb_teachers: number of teachers in the ensemble (affects size of each
|
data: input data to be partitioned
|
||||||
|
labels: output data to be partitioned
|
||||||
|
nb_teachers: number of teachers in the ensemble (affects size of each
|
||||||
partition)
|
partition)
|
||||||
:param teacher_id: id of partition to retrieve
|
teacher_id: id of partition to retrieve
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
|
|
Loading…
Reference in a new issue