Merge pull request #31 from an1006634493:patch-1

PiperOrigin-RevId: 239029264
This commit is contained in:
A. Unique TensorFlower 2019-03-18 11:44:36 -07:00
commit 8fc35f9ca3

View file

@ -17,28 +17,24 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import cPickle
import gzip import gzip
import math import math
import numpy as np
import os import os
from scipy.io import loadmat as loadmat
from six.moves import urllib
from six.moves import xrange
import sys import sys
import tarfile import tarfile
import numpy as np
from scipy.io import loadmat as loadmat
from six.moves import cPickle as pickle
from six.moves import urllib
from six.moves import xrange
import tensorflow as tf import tensorflow as tf
FLAGS = tf.flags.FLAGS FLAGS = tf.flags.FLAGS
def create_dir_if_needed(dest_directory): def create_dir_if_needed(dest_directory):
""" """Create directory if doesn't exist."""
Create directory if doesn't exist
:param dest_directory:
:return: True if everything went well
"""
if not tf.gfile.IsDirectory(dest_directory): if not tf.gfile.IsDirectory(dest_directory):
tf.gfile.MakeDirs(dest_directory) tf.gfile.MakeDirs(dest_directory)
@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory):
def maybe_download(file_urls, directory): def maybe_download(file_urls, directory):
""" """Download a set of files in temporary local folder."""
Download a set of files in temporary local folder
:param directory: the directory where to download
:return: a tuple of filepaths corresponding to the files given as input
"""
# Create directory if doesn't exist # Create directory if doesn't exist
assert create_dir_if_needed(directory) assert create_dir_if_needed(directory)
@ -91,8 +84,6 @@ def image_whitening(data):
""" """
Subtracts mean of image and divides by adjusted standard variance (for Subtracts mean of image and divides by adjusted standard variance (for
stability). Operations are per image but performed for the entire array. stability). Operations are per image but performed for the entire array.
:param image: 4D array (ID, Height, Weight, Channel)
:return: 4D array (ID, Height, Weight, Channel)
""" """
assert len(np.shape(data)) == 4 assert len(np.shape(data)) == 4
@ -100,14 +91,14 @@ def image_whitening(data):
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3] nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
# Subtract mean # Subtract mean
mean = np.mean(data, axis=(1,2,3)) mean = np.mean(data, axis=(1, 2, 3))
ones = np.ones(np.shape(data)[1:4], dtype=np.float32) ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
for i in xrange(len(data)): for i in xrange(len(data)):
data[i, :, :, :] -= mean[i] * ones data[i, :, :, :] -= mean[i] * ones
# Compute adjusted standard variance # Compute adjusted standard variance
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line) adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long
# Divide image # Divide image
for i in xrange(len(data)): for i in xrange(len(data)):
@ -119,18 +110,14 @@ def image_whitening(data):
def extract_svhn(local_url): def extract_svhn(local_url):
""" """Extract a MATLAB matrix into two numpy arrays with data and labels."""
Extract a MATLAB matrix into two numpy arrays with data and labels
:param local_url:
:return:
"""
with tf.gfile.Open(local_url, mode='r') as file_obj: with tf.gfile.Open(local_url, mode='r') as file_obj:
# Load MATLAB matrix using scipy IO # Load MATLAB matrix using scipy IO
dict = loadmat(file_obj) data_dict = loadmat(file_obj)
# Extract each dictionary (one for data, one for labels) # Extract each dictionary (one for data, one for labels)
data, labels = dict["X"], dict["y"] data, labels = data_dict['X'], data_dict['y']
# Set np type # Set np type
data = np.asarray(data, dtype=np.float32) data = np.asarray(data, dtype=np.float32)
@ -148,25 +135,17 @@ def extract_svhn(local_url):
return data, labels return data, labels
def unpickle_cifar_dic(file): def unpickle_cifar_dic(file_path):
""" """Helper function: unpickles a dictionary (used for loading CIFAR)."""
Helper function: unpickles a dictionary (used for loading CIFAR) file_obj = open(file_path, 'rb')
:param file: filename of the pickle data_dict = pickle.load(file_obj)
:return: tuple of (images, labels) file_obj.close()
""" return data_dict['data'], data_dict['labels']
fo = open(file, 'rb')
dict = cPickle.load(fo)
fo.close()
return dict['data'], dict['labels']
def extract_cifar10(local_url, data_dir): def extract_cifar10(local_url, data_dir):
""" """Extracts CIFAR-10 and return numpy arrays with the different sets."""
Extracts the CIFAR-10 dataset and return numpy arrays with the different sets
:param local_url: where the tar.gz archive is located locally
:param data_dir: where to extract the archive's file
:return: a tuple (train data, train labels, test data, test labels)
"""
# These numpy dumps can be reloaded to avoid performing the pre-processing # These numpy dumps can be reloaded to avoid performing the pre-processing
# if they exist in the working directory. # if they exist in the working directory.
# Changing the order of this list will ruin the indices below. # Changing the order of this list will ruin the indices below.
@ -176,8 +155,8 @@ def extract_cifar10(local_url, data_dir):
'/cifar10_test_labels.npy'] '/cifar10_test_labels.npy']
all_preprocessed = True all_preprocessed = True
for file in preprocessed_files: for file_name in preprocessed_files:
if not tf.gfile.Exists(data_dir + file): if not tf.gfile.Exists(data_dir + file_name):
all_preprocessed = False all_preprocessed = False
break break
@ -197,14 +176,14 @@ def extract_cifar10(local_url, data_dir):
else: else:
# Do everything from scratch # Do everything from scratch
# Define lists of all files we should extract # Define lists of all files we should extract
train_files = ["data_batch_" + str(i) for i in xrange(1,6)] train_files = ['data_batch_' + str(i) for i in xrange(1, 6)]
test_file = ["test_batch"] test_file = ['test_batch']
cifar10_files = train_files + test_file cifar10_files = train_files + test_file
# Check if all files have already been extracted # Check if all files have already been extracted
need_to_unpack = False need_to_unpack = False
for file in cifar10_files: for file_name in cifar10_files:
if not tf.gfile.Exists(file): if not tf.gfile.Exists(file_name):
need_to_unpack = True need_to_unpack = True
break break
@ -215,9 +194,9 @@ def extract_cifar10(local_url, data_dir):
# Load training images and labels # Load training images and labels
images = [] images = []
labels = [] labels = []
for file in train_files: for train_file in train_files:
# Construct filename # Construct filename
filename = data_dir + "/cifar-10-batches-py/" + file filename = data_dir + '/cifar-10-batches-py/' + train_file
# Unpickle dictionary and extract images and labels # Unpickle dictionary and extract images and labels
images_tmp, labels_tmp = unpickle_cifar_dic(filename) images_tmp, labels_tmp = unpickle_cifar_dic(filename)
@ -227,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
labels.append(labels_tmp) labels.append(labels_tmp)
# Convert to numpy arrays and reshape in the expected format # Convert to numpy arrays and reshape in the expected format
train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32)) train_data = np.asarray(images, dtype=np.float32)
train_data = train_data.reshape((50000, 3, 32, 32))
train_data = np.swapaxes(train_data, 1, 3) train_data = np.swapaxes(train_data, 1, 3)
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000) train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
@ -236,13 +216,14 @@ def extract_cifar10(local_url, data_dir):
np.save(data_dir + preprocessed_files[1], train_labels) np.save(data_dir + preprocessed_files[1], train_labels)
# Construct filename for test file # Construct filename for test file
filename = data_dir + "/cifar-10-batches-py/" + test_file[0] filename = data_dir + '/cifar-10-batches-py/' + test_file[0]
# Load test images and labels # Load test images and labels
test_data, test_images = unpickle_cifar_dic(filename) test_data, test_images = unpickle_cifar_dic(filename)
# Convert to numpy arrays and reshape in the expected format # Convert to numpy arrays and reshape in the expected format
test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32)) test_data = np.asarray(test_data, dtype=np.float32)
test_data = test_data.reshape((10000, 3, 32, 32))
test_data = np.swapaxes(test_data, 1, 3) test_data = np.swapaxes(test_data, 1, 3)
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000) test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
@ -259,8 +240,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
Values are rescaled from [0, 255] down to [-0.5, 0.5]. Values are rescaled from [0, 255] down to [-0.5, 0.5].
""" """
# if not os.path.exists(file): if not tf.gfile.Exists(filename+'.npy'):
if not tf.gfile.Exists(filename+".npy"):
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
bytestream.read(16) bytestream.read(16)
buf = bytestream.read(image_size * image_size * num_images) buf = bytestream.read(image_size * image_size * num_images)
@ -270,7 +250,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
np.save(filename, data) np.save(filename, data)
return data return data
else: else:
with tf.gfile.Open(filename+".npy", mode='r') as file_obj: with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
return np.load(file_obj) return np.load(file_obj)
@ -278,8 +258,7 @@ def extract_mnist_labels(filename, num_images):
""" """
Extract the labels into a vector of int64 label IDs. Extract the labels into a vector of int64 label IDs.
""" """
# if not os.path.exists(file): if not tf.gfile.Exists(filename+'.npy'):
if not tf.gfile.Exists(filename+".npy"):
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
bytestream.read(8) bytestream.read(8)
buf = bytestream.read(1 * num_images) buf = bytestream.read(1 * num_images)
@ -287,16 +266,17 @@ def extract_mnist_labels(filename, num_images):
np.save(filename, labels) np.save(filename, labels)
return labels return labels
else: else:
with tf.gfile.Open(filename+".npy", mode='r') as file_obj: with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
return np.load(file_obj) return np.load(file_obj)
def ld_svhn(extended=False, test_only=False): def ld_svhn(extended=False, test_only=False):
""" """
Load the original SVHN data Load the original SVHN data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up Args:
:return: tuple of arrays which depend on the parameters extended: include extended training data in the returned array
test_only: disables loading of both train and extra -> large speed up
""" """
# Define files to be downloaded # Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below) # WARNING: changing the order of this list will break indices (cf. below)
@ -332,16 +312,12 @@ def ld_svhn(extended=False, test_only=False):
return train_data, train_labels, test_data, test_labels return train_data, train_labels, test_data, test_labels
else: else:
# Return training and extended training data separately # Return training and extended training data separately
return train_data,train_labels, test_data,test_labels, ext_data,ext_labels return train_data, train_labels, test_data, test_labels, ext_data, ext_labels
def ld_cifar10(test_only=False): def ld_cifar10(test_only=False):
""" """Load the original CIFAR10 data."""
Load the original CIFAR10 data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
# Define files to be downloaded # Define files to be downloaded
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'] file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
@ -365,12 +341,7 @@ def ld_cifar10(test_only=False):
def ld_mnist(test_only=False): def ld_mnist(test_only=False):
""" """Load the MNIST dataset."""
Load the MNIST dataset
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
# Define files to be downloaded # Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below) # WARNING: changing the order of this list will break indices (cf. below)
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
@ -398,12 +369,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id):
""" """
Simple partitioning algorithm that returns the right portion of the data Simple partitioning algorithm that returns the right portion of the data
needed by a given teacher out of a certain nb of teachers needed by a given teacher out of a certain nb of teachers
:param data: input data to be partitioned
:param labels: output data to be partitioned Args:
:param nb_teachers: number of teachers in the ensemble (affects size of each data: input data to be partitioned
labels: output data to be partitioned
nb_teachers: number of teachers in the ensemble (affects size of each
partition) partition)
:param teacher_id: id of partition to retrieve teacher_id: id of partition to retrieve
:return:
""" """
# Sanity check # Sanity check