Merge pull request #31 from an1006634493:patch-1

PiperOrigin-RevId: 239029264
A. Unique TensorFlower 2019-03-18 11:44:36 -07:00
commit 8fc35f9ca3


@@ -17,28 +17,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cPickle
import gzip
import math
import numpy as np
import os
from scipy.io import loadmat as loadmat
from six.moves import urllib
from six.moves import xrange
import sys
import tarfile
import numpy as np
from scipy.io import loadmat as loadmat
from six.moves import cPickle as pickle
from six.moves import urllib
from six.moves import xrange
import tensorflow as tf
FLAGS = tf.flags.FLAGS
def create_dir_if_needed(dest_directory):
"""
Create directory if doesn't exist
:param dest_directory:
:return: True if everything went well
"""
"""Create directory if doesn't exist."""
if not tf.gfile.IsDirectory(dest_directory):
tf.gfile.MakeDirs(dest_directory)
@@ -46,11 +42,8 @@ def create_dir_if_needed(dest_directory):
def maybe_download(file_urls, directory):
"""
Download a set of files in temporary local folder
:param directory: the directory where to download
:return: a tuple of filepaths corresponding to the files given as input
"""
"""Download a set of files in temporary local folder."""
# Create directory if doesn't exist
assert create_dir_if_needed(directory)
@@ -91,8 +84,6 @@ def image_whitening(data):
"""
Subtracts mean of image and divides by adjusted standard variance (for
stability). Operations are per image but performed for the entire array.
:param image: 4D array (ID, Height, Weight, Channel)
:return: 4D array (ID, Height, Weight, Channel)
"""
assert len(np.shape(data)) == 4
@@ -100,14 +91,14 @@ def image_whitening(data):
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
# Subtract mean
mean = np.mean(data, axis=(1,2,3))
mean = np.mean(data, axis=(1, 2, 3))
ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
for i in xrange(len(data)):
data[i, :, :, :] -= mean[i] * ones
# Compute adjusted standard variance
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line)
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long
# Divide image
for i in xrange(len(data)):
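For readers skimming the diff, the whitening above is plain per-image standardization with a floor on the divisor: each image has its mean subtracted and is divided by max(std, 1/sqrt(nb_pixels)) so near-constant images do not explode. A self-contained sketch of the same computation (illustrative only, not part of the commit):

import math
import numpy as np

def whiten_images(data):
  """Per-image standardization: subtract mean, divide by adjusted std."""
  assert data.ndim == 4  # (N, H, W, C)
  nb_pixels = data.shape[1] * data.shape[2] * data.shape[3]
  mean = data.mean(axis=(1, 2, 3), keepdims=True)
  std = data.std(axis=(1, 2, 3), keepdims=True)
  # Lower-bound the divisor so constant images do not blow up.
  adj_std = np.maximum(std, 1.0 / math.sqrt(nb_pixels))
  return (data - mean) / adj_std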
@@ -119,18 +110,14 @@ def image_whitening(data):
def extract_svhn(local_url):
"""
Extract a MATLAB matrix into two numpy arrays with data and labels
:param local_url:
:return:
"""
"""Extract a MATLAB matrix into two numpy arrays with data and labels."""
with tf.gfile.Open(local_url, mode='r') as file_obj:
# Load MATLAB matrix using scipy IO
dict = loadmat(file_obj)
data_dict = loadmat(file_obj)
# Extract each dictionary (one for data, one for labels)
data, labels = dict["X"], dict["y"]
data, labels = data_dict['X'], data_dict['y']
# Set np type
data = np.asarray(data, dtype=np.float32)
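For context, scipy.io.loadmat returns an ordinary dict keyed by the MATLAB variable names, which is why the rename from dict to data_dict is purely cosmetic. A minimal sketch of reading one SVHN matrix this way (file name illustrative):

import numpy as np
from scipy.io import loadmat

mat = loadmat('train_32x32.mat')            # keys include 'X' and 'y'
images = np.asarray(mat['X'], dtype=np.float32)
labels = np.asarray(mat['y'], dtype=np.int32).reshape(-1)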
@@ -148,25 +135,17 @@ def extract_svhn(local_url):
return data, labels
def unpickle_cifar_dic(file):
"""
Helper function: unpickles a dictionary (used for loading CIFAR)
:param file: filename of the pickle
:return: tuple of (images, labels)
"""
fo = open(file, 'rb')
dict = cPickle.load(fo)
fo.close()
return dict['data'], dict['labels']
def unpickle_cifar_dic(file_path):
"""Helper function: unpickles a dictionary (used for loading CIFAR)."""
file_obj = open(file_path, 'rb')
data_dict = pickle.load(file_obj)
file_obj.close()
return data_dict['data'], data_dict['labels']
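The rewritten helper still closes the file only on the happy path; a context manager is the more idiomatic shape. A minimal sketch (note, as an aside, that under Python 3 the CIFAR-10 batches, which were pickled by Python 2, usually also need an encoding argument that Python 2's cPickle does not accept):

from six.moves import cPickle as pickle

def unpickle_cifar_dic(file_path):
  """Unpickles one CIFAR-10 batch and returns (images, labels)."""
  with open(file_path, 'rb') as file_obj:
    data_dict = pickle.load(file_obj)
  return data_dict['data'], data_dict['labels']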
def extract_cifar10(local_url, data_dir):
"""
Extracts the CIFAR-10 dataset and return numpy arrays with the different sets
:param local_url: where the tar.gz archive is located locally
:param data_dir: where to extract the archive's file
:return: a tuple (train data, train labels, test data, test labels)
"""
"""Extracts CIFAR-10 and return numpy arrays with the different sets."""
# These numpy dumps can be reloaded to avoid performing the pre-processing
# if they exist in the working directory.
# Changing the order of this list will ruin the indices below.
@@ -176,8 +155,8 @@ def extract_cifar10(local_url, data_dir):
'/cifar10_test_labels.npy']
all_preprocessed = True
for file in preprocessed_files:
if not tf.gfile.Exists(data_dir + file):
for file_name in preprocessed_files:
if not tf.gfile.Exists(data_dir + file_name):
all_preprocessed = False
break
@@ -197,14 +176,14 @@ def extract_cifar10(local_url, data_dir):
else:
# Do everything from scratch
# Define lists of all files we should extract
train_files = ["data_batch_" + str(i) for i in xrange(1,6)]
test_file = ["test_batch"]
train_files = ['data_batch_' + str(i) for i in xrange(1, 6)]
test_file = ['test_batch']
cifar10_files = train_files + test_file
# Check if all files have already been extracted
need_to_unpack = False
for file in cifar10_files:
if not tf.gfile.Exists(file):
for file_name in cifar10_files:
if not tf.gfile.Exists(file_name):
need_to_unpack = True
break
@@ -215,9 +194,9 @@ def extract_cifar10(local_url, data_dir):
# Load training images and labels
images = []
labels = []
for file in train_files:
for train_file in train_files:
# Construct filename
filename = data_dir + "/cifar-10-batches-py/" + file
filename = data_dir + '/cifar-10-batches-py/' + train_file
# Unpickle dictionary and extract images and labels
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
@@ -227,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
labels.append(labels_tmp)
# Convert to numpy arrays and reshape in the expected format
train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32))
train_data = np.asarray(images, dtype=np.float32)
train_data = train_data.reshape((50000, 3, 32, 32))
train_data = np.swapaxes(train_data, 1, 3)
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
@@ -236,13 +216,14 @@ def extract_cifar10(local_url, data_dir):
np.save(data_dir + preprocessed_files[1], train_labels)
# Construct filename for test file
filename = data_dir + "/cifar-10-batches-py/" + test_file[0]
filename = data_dir + '/cifar-10-batches-py/' + test_file[0]
# Load test images and labels
test_data, test_images = unpickle_cifar_dic(filename)
# Convert to numpy arrays and reshape in the expected format
test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32))
test_data = np.asarray(test_data, dtype=np.float32)
test_data = test_data.reshape((10000, 3, 32, 32))
test_data = np.swapaxes(test_data, 1, 3)
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
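Each unpickled CIFAR-10 batch stores an image as a flat 3072-value row with the three colour planes laid out one after another, so the code reshapes to (N, 3, 32, 32) and then swaps axes 1 and 3 to put the channel dimension last. A quick shape check (illustrative):

import numpy as np

flat = np.zeros((10000, 3072), dtype=np.float32)   # one unpickled batch
nchw = flat.reshape((10000, 3, 32, 32))
nhwc = np.swapaxes(nchw, 1, 3)
print(nhwc.shape)  # (10000, 32, 32, 3)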
@@ -259,8 +240,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
Values are rescaled from [0, 255] down to [-0.5, 0.5].
"""
# if not os.path.exists(file):
if not tf.gfile.Exists(filename+".npy"):
if not tf.gfile.Exists(filename+'.npy'):
with gzip.open(filename) as bytestream:
bytestream.read(16)
buf = bytestream.read(image_size * image_size * num_images)
@@ -270,7 +250,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
np.save(filename, data)
return data
else:
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
return np.load(file_obj)
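The MNIST image files are gzipped IDX files: a 16-byte header (magic number, image count, rows, columns) followed by one unsigned byte per pixel, which the code rescales from [0, 255] to [-0.5, 0.5]; the label files, handled in the next hunk, carry an 8-byte header instead. A standalone sketch of the image parsing under those assumptions:

import gzip
import numpy as np

def read_mnist_images(path, num_images, image_size=28, pixel_depth=255):
  """Reads a gzipped MNIST image file into a (N, H, W, 1) float array."""
  with gzip.open(path) as stream:
    stream.read(16)  # skip magic number, image count, rows, cols
    buf = stream.read(image_size * image_size * num_images)
  data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
  data = (data - pixel_depth / 2.0) / pixel_depth   # -> [-0.5, 0.5]
  return data.reshape(num_images, image_size, image_size, 1)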
@@ -278,8 +258,7 @@ def extract_mnist_labels(filename, num_images):
"""
Extract the labels into a vector of int64 label IDs.
"""
# if not os.path.exists(file):
if not tf.gfile.Exists(filename+".npy"):
if not tf.gfile.Exists(filename+'.npy'):
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_images)
@@ -287,16 +266,17 @@ def extract_mnist_labels(filename, num_images):
np.save(filename, labels)
return labels
else:
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
return np.load(file_obj)
def ld_svhn(extended=False, test_only=False):
"""
Load the original SVHN data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
Args:
extended: include extended training data in the returned array
test_only: disables loading of both train and extra -> large speed up
"""
# Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below)
@@ -332,16 +312,12 @@ def ld_svhn(extended=False, test_only=False):
return train_data, train_labels, test_data, test_labels
else:
# Return training and extended training data separately
return train_data,train_labels, test_data,test_labels, ext_data,ext_labels
return train_data, train_labels, test_data, test_labels, ext_data, ext_labels
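The two return statements above change only whitespace; as a usage sketch, the tuples they produce unpack as follows (assuming the data_dir flag is configured elsewhere in the package):

# Default: train and test splits (4-tuple, first return statement above).
train_data, train_labels, test_data, test_labels = ld_svhn()

# extended=True also returns the SVHN "extra" split (6-tuple, second return).
(train_data, train_labels,
 test_data, test_labels,
 ext_data, ext_labels) = ld_svhn(extended=True)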
def ld_cifar10(test_only=False):
"""
Load the original CIFAR10 data
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
"""Load the original CIFAR10 data."""
# Define files to be downloaded
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
@@ -365,19 +341,14 @@ def ld_cifar10(test_only=False):
def ld_mnist(test_only=False):
"""
Load the MNIST dataset
:param extended: include extended training data in the returned array
:param test_only: disables loading of both train and extra -> large speed up
:return: tuple of arrays which depend on the parameters
"""
"""Load the MNIST dataset."""
# Define files to be downloaded
# WARNING: changing the order of this list will break indices (cf. below)
file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
]
# Maybe download data and retrieve local storage urls
local_urls = maybe_download(file_urls, FLAGS.data_dir)
@@ -398,12 +369,13 @@ def partition_dataset(data, labels, nb_teachers, teacher_id):
"""
Simple partitioning algorithm that returns the right portion of the data
needed by a given teacher out of a certain nb of teachers
:param data: input data to be partitioned
:param labels: output data to be partitioned
:param nb_teachers: number of teachers in the ensemble (affects size of each
Args:
data: input data to be partitioned
labels: output data to be partitioned
nb_teachers: number of teachers in the ensemble (affects size of each
partition)
:param teacher_id: id of partition to retrieve
:return:
teacher_id: id of partition to retrieve
"""
# Sanity check
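The diff cuts off inside partition_dataset, but the docstring describes a plain contiguous split: teacher i of n receives the i-th equally sized slice of the arrays. A minimal sketch of that behaviour (an illustration of what the docstring describes, not the verbatim implementation):

def partition_slice(data, labels, nb_teachers, teacher_id):
  """Returns the contiguous shard of (data, labels) for one teacher."""
  assert len(data) == len(labels)
  assert 0 <= teacher_id < nb_teachers
  batch_len = len(data) // nb_teachers   # size of each partition
  start = teacher_id * batch_len
  return data[start:start + batch_len], labels[start:start + batch_len]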