Merge pull request #31 from an1006634493:patch-1
PiperOrigin-RevId: 239029264
This commit is contained in:
commit
8fc35f9ca3
1 changed files with 53 additions and 81 deletions
|
@ -17,28 +17,24 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import cPickle
|
|
||||||
import gzip
|
import gzip
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
from scipy.io import loadmat as loadmat
|
|
||||||
from six.moves import urllib
|
|
||||||
from six.moves import xrange
|
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.io import loadmat as loadmat
|
||||||
|
from six.moves import cPickle as pickle
|
||||||
|
from six.moves import urllib
|
||||||
|
from six.moves import xrange
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
FLAGS = tf.flags.FLAGS
|
FLAGS = tf.flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
def create_dir_if_needed(dest_directory):
|
def create_dir_if_needed(dest_directory):
|
||||||
"""
|
"""Create directory if doesn't exist."""
|
||||||
Create directory if doesn't exist
|
|
||||||
:param dest_directory:
|
|
||||||
:return: True if everything went well
|
|
||||||
"""
|
|
||||||
if not tf.gfile.IsDirectory(dest_directory):
|
if not tf.gfile.IsDirectory(dest_directory):
|
||||||
tf.gfile.MakeDirs(dest_directory)
|
tf.gfile.MakeDirs(dest_directory)
|
||||||
|
|
||||||
|
@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory):
|
||||||
|
|
||||||
|
|
||||||
def maybe_download(file_urls, directory):
|
def maybe_download(file_urls, directory):
|
||||||
"""
|
"""Download a set of files in temporary local folder."""
|
||||||
Download a set of files in temporary local folder
|
|
||||||
:param directory: the directory where to download
|
|
||||||
:return: a tuple of filepaths corresponding to the files given as input
|
|
||||||
"""
|
|
||||||
# Create directory if doesn't exist
|
# Create directory if doesn't exist
|
||||||
assert create_dir_if_needed(directory)
|
assert create_dir_if_needed(directory)
|
||||||
|
|
||||||
|
@ -91,8 +84,6 @@ def image_whitening(data):
|
||||||
"""
|
"""
|
||||||
Subtracts mean of image and divides by adjusted standard variance (for
|
Subtracts mean of image and divides by adjusted standard variance (for
|
||||||
stability). Operations are per image but performed for the entire array.
|
stability). Operations are per image but performed for the entire array.
|
||||||
:param image: 4D array (ID, Height, Weight, Channel)
|
|
||||||
:return: 4D array (ID, Height, Weight, Channel)
|
|
||||||
"""
|
"""
|
||||||
assert len(np.shape(data)) == 4
|
assert len(np.shape(data)) == 4
|
||||||
|
|
||||||
|
@ -100,14 +91,14 @@ def image_whitening(data):
|
||||||
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
|
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
|
||||||
|
|
||||||
# Subtract mean
|
# Subtract mean
|
||||||
mean = np.mean(data, axis=(1,2,3))
|
mean = np.mean(data, axis=(1, 2, 3))
|
||||||
|
|
||||||
ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
|
ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
|
||||||
for i in xrange(len(data)):
|
for i in xrange(len(data)):
|
||||||
data[i, :, :, :] -= mean[i] * ones
|
data[i, :, :, :] -= mean[i] * ones
|
||||||
|
|
||||||
# Compute adjusted standard variance
|
# Compute adjusted standard variance
|
||||||
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line)
|
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long
|
||||||
|
|
||||||
# Divide image
|
# Divide image
|
||||||
for i in xrange(len(data)):
|
for i in xrange(len(data)):
|
||||||
|
@ -119,18 +110,14 @@ def image_whitening(data):
|
||||||
|
|
||||||
|
|
||||||
def extract_svhn(local_url):
|
def extract_svhn(local_url):
|
||||||
"""
|
"""Extract a MATLAB matrix into two numpy arrays with data and labels."""
|
||||||
Extract a MATLAB matrix into two numpy arrays with data and labels
|
|
||||||
:param local_url:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
||||||
# Load MATLAB matrix using scipy IO
|
# Load MATLAB matrix using scipy IO
|
||||||
dict = loadmat(file_obj)
|
data_dict = loadmat(file_obj)
|
||||||
|
|
||||||
# Extract each dictionary (one for data, one for labels)
|
# Extract each dictionary (one for data, one for labels)
|
||||||
data, labels = dict["X"], dict["y"]
|
data, labels = data_dict['X'], data_dict['y']
|
||||||
|
|
||||||
# Set np type
|
# Set np type
|
||||||
data = np.asarray(data, dtype=np.float32)
|
data = np.asarray(data, dtype=np.float32)
|
||||||
|
@ -148,25 +135,17 @@ def extract_svhn(local_url):
|
||||||
return data, labels
|
return data, labels
|
||||||
|
|
||||||
|
|
||||||
def unpickle_cifar_dic(file):
|
def unpickle_cifar_dic(file_path):
|
||||||
"""
|
"""Helper function: unpickles a dictionary (used for loading CIFAR)."""
|
||||||
Helper function: unpickles a dictionary (used for loading CIFAR)
|
file_obj = open(file_path, 'rb')
|
||||||
:param file: filename of the pickle
|
data_dict = pickle.load(file_obj)
|
||||||
:return: tuple of (images, labels)
|
file_obj.close()
|
||||||
"""
|
return data_dict['data'], data_dict['labels']
|
||||||
fo = open(file, 'rb')
|
|
||||||
dict = cPickle.load(fo)
|
|
||||||
fo.close()
|
|
||||||
return dict['data'], dict['labels']
|
|
||||||
|
|
||||||
|
|
||||||
def extract_cifar10(local_url, data_dir):
|
def extract_cifar10(local_url, data_dir):
|
||||||
"""
|
"""Extracts CIFAR-10 and return numpy arrays with the different sets."""
|
||||||
Extracts the CIFAR-10 dataset and return numpy arrays with the different sets
|
|
||||||
:param local_url: where the tar.gz archive is located locally
|
|
||||||
:param data_dir: where to extract the archive's file
|
|
||||||
:return: a tuple (train data, train labels, test data, test labels)
|
|
||||||
"""
|
|
||||||
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
||||||
# if they exist in the working directory.
|
# if they exist in the working directory.
|
||||||
# Changing the order of this list will ruin the indices below.
|
# Changing the order of this list will ruin the indices below.
|
||||||
|
@ -176,8 +155,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
'/cifar10_test_labels.npy']
|
'/cifar10_test_labels.npy']
|
||||||
|
|
||||||
all_preprocessed = True
|
all_preprocessed = True
|
||||||
for file in preprocessed_files:
|
for file_name in preprocessed_files:
|
||||||
if not tf.gfile.Exists(data_dir + file):
|
if not tf.gfile.Exists(data_dir + file_name):
|
||||||
all_preprocessed = False
|
all_preprocessed = False
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -197,14 +176,14 @@ def extract_cifar10(local_url, data_dir):
|
||||||
else:
|
else:
|
||||||
# Do everything from scratch
|
# Do everything from scratch
|
||||||
# Define lists of all files we should extract
|
# Define lists of all files we should extract
|
||||||
train_files = ["data_batch_" + str(i) for i in xrange(1,6)]
|
train_files = ['data_batch_' + str(i) for i in xrange(1, 6)]
|
||||||
test_file = ["test_batch"]
|
test_file = ['test_batch']
|
||||||
cifar10_files = train_files + test_file
|
cifar10_files = train_files + test_file
|
||||||
|
|
||||||
# Check if all files have already been extracted
|
# Check if all files have already been extracted
|
||||||
need_to_unpack = False
|
need_to_unpack = False
|
||||||
for file in cifar10_files:
|
for file_name in cifar10_files:
|
||||||
if not tf.gfile.Exists(file):
|
if not tf.gfile.Exists(file_name):
|
||||||
need_to_unpack = True
|
need_to_unpack = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -215,9 +194,9 @@ def extract_cifar10(local_url, data_dir):
|
||||||
# Load training images and labels
|
# Load training images and labels
|
||||||
images = []
|
images = []
|
||||||
labels = []
|
labels = []
|
||||||
for file in train_files:
|
for train_file in train_files:
|
||||||
# Construct filename
|
# Construct filename
|
||||||
filename = data_dir + "/cifar-10-batches-py/" + file
|
filename = data_dir + '/cifar-10-batches-py/' + train_file
|
||||||
|
|
||||||
# Unpickle dictionary and extract images and labels
|
# Unpickle dictionary and extract images and labels
|
||||||
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
||||||
|
@ -227,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
labels.append(labels_tmp)
|
labels.append(labels_tmp)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32))
|
train_data = np.asarray(images, dtype=np.float32)
|
||||||
|
train_data = train_data.reshape((50000, 3, 32, 32))
|
||||||
train_data = np.swapaxes(train_data, 1, 3)
|
train_data = np.swapaxes(train_data, 1, 3)
|
||||||
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
||||||
|
|
||||||
|
@ -236,13 +216,14 @@ def extract_cifar10(local_url, data_dir):
|
||||||
np.save(data_dir + preprocessed_files[1], train_labels)
|
np.save(data_dir + preprocessed_files[1], train_labels)
|
||||||
|
|
||||||
# Construct filename for test file
|
# Construct filename for test file
|
||||||
filename = data_dir + "/cifar-10-batches-py/" + test_file[0]
|
filename = data_dir + '/cifar-10-batches-py/' + test_file[0]
|
||||||
|
|
||||||
# Load test images and labels
|
# Load test images and labels
|
||||||
test_data, test_images = unpickle_cifar_dic(filename)
|
test_data, test_images = unpickle_cifar_dic(filename)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32))
|
test_data = np.asarray(test_data, dtype=np.float32)
|
||||||
|
test_data = test_data.reshape((10000, 3, 32, 32))
|
||||||
test_data = np.swapaxes(test_data, 1, 3)
|
test_data = np.swapaxes(test_data, 1, 3)
|
||||||
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
||||||
|
|
||||||
|
@ -259,8 +240,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
|
||||||
|
|
||||||
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
if not tf.gfile.Exists(filename+".npy"):
|
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(16)
|
bytestream.read(16)
|
||||||
buf = bytestream.read(image_size * image_size * num_images)
|
buf = bytestream.read(image_size * image_size * num_images)
|
||||||
|
@ -270,7 +250,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
|
||||||
np.save(filename, data)
|
np.save(filename, data)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
|
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
|
||||||
return np.load(file_obj)
|
return np.load(file_obj)
|
||||||
|
|
||||||
|
|
||||||
|
@ -278,8 +258,7 @@ def extract_mnist_labels(filename, num_images):
|
||||||
"""
|
"""
|
||||||
Extract the labels into a vector of int64 label IDs.
|
Extract the labels into a vector of int64 label IDs.
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
if not tf.gfile.Exists(filename+".npy"):
|
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(8)
|
bytestream.read(8)
|
||||||
buf = bytestream.read(1 * num_images)
|
buf = bytestream.read(1 * num_images)
|
||||||
|
@ -287,16 +266,17 @@ def extract_mnist_labels(filename, num_images):
|
||||||
np.save(filename, labels)
|
np.save(filename, labels)
|
||||||
return labels
|
return labels
|
||||||
else:
|
else:
|
||||||
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
|
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
|
||||||
return np.load(file_obj)
|
return np.load(file_obj)
|
||||||
|
|
||||||
|
|
||||||
def ld_svhn(extended=False, test_only=False):
|
def ld_svhn(extended=False, test_only=False):
|
||||||
"""
|
"""
|
||||||
Load the original SVHN data
|
Load the original SVHN data
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
Args:
|
||||||
:return: tuple of arrays which depend on the parameters
|
extended: include extended training data in the returned array
|
||||||
|
test_only: disables loading of both train and extra -> large speed up
|
||||||
"""
|
"""
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
# WARNING: changing the order of this list will break indices (cf. below)
|
# WARNING: changing the order of this list will break indices (cf. below)
|
||||||
|
@ -332,16 +312,12 @@ def ld_svhn(extended=False, test_only=False):
|
||||||
return train_data, train_labels, test_data, test_labels
|
return train_data, train_labels, test_data, test_labels
|
||||||
else:
|
else:
|
||||||
# Return training and extended training data separately
|
# Return training and extended training data separately
|
||||||
return train_data,train_labels, test_data,test_labels, ext_data,ext_labels
|
return train_data, train_labels, test_data, test_labels, ext_data, ext_labels
|
||||||
|
|
||||||
|
|
||||||
def ld_cifar10(test_only=False):
|
def ld_cifar10(test_only=False):
|
||||||
"""
|
"""Load the original CIFAR10 data."""
|
||||||
Load the original CIFAR10 data
|
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
|
||||||
:return: tuple of arrays which depend on the parameters
|
|
||||||
"""
|
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
||||||
|
|
||||||
|
@ -365,12 +341,7 @@ def ld_cifar10(test_only=False):
|
||||||
|
|
||||||
|
|
||||||
def ld_mnist(test_only=False):
|
def ld_mnist(test_only=False):
|
||||||
"""
|
"""Load the MNIST dataset."""
|
||||||
Load the MNIST dataset
|
|
||||||
:param extended: include extended training data in the returned array
|
|
||||||
:param test_only: disables loading of both train and extra -> large speed up
|
|
||||||
:return: tuple of arrays which depend on the parameters
|
|
||||||
"""
|
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
# WARNING: changing the order of this list will break indices (cf. below)
|
# WARNING: changing the order of this list will break indices (cf. below)
|
||||||
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
||||||
|
@ -398,12 +369,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id):
|
||||||
"""
|
"""
|
||||||
Simple partitioning algorithm that returns the right portion of the data
|
Simple partitioning algorithm that returns the right portion of the data
|
||||||
needed by a given teacher out of a certain nb of teachers
|
needed by a given teacher out of a certain nb of teachers
|
||||||
:param data: input data to be partitioned
|
|
||||||
:param labels: output data to be partitioned
|
Args:
|
||||||
:param nb_teachers: number of teachers in the ensemble (affects size of each
|
data: input data to be partitioned
|
||||||
|
labels: output data to be partitioned
|
||||||
|
nb_teachers: number of teachers in the ensemble (affects size of each
|
||||||
partition)
|
partition)
|
||||||
:param teacher_id: id of partition to retrieve
|
teacher_id: id of partition to retrieve
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
|
|
Loading…
Reference in a new issue