b0df24ef25
PiperOrigin-RevId: 297199727
396 lines
13 KiB
Python
396 lines
13 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import gzip
|
|
import math
|
|
import os
|
|
import sys
|
|
import tarfile
|
|
|
|
import numpy as np
|
|
from scipy.io import loadmat as loadmat
|
|
from six.moves import cPickle as pickle
|
|
from six.moves import urllib
|
|
from six.moves import xrange
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
# Global flag registry. The loaders below read FLAGS.data_dir as the download
# and cache directory (the flag is presumably defined by the calling script).
FLAGS = tf.flags.FLAGS
|
|
|
|
|
|
def create_dir_if_needed(dest_directory):
  """Ensure that `dest_directory` exists, creating it when it is missing.

  Args:
    dest_directory: path of the directory that must exist.

  Returns:
    Always True.
  """
  already_there = tf.gfile.IsDirectory(dest_directory)
  if not already_there:
    tf.gfile.MakeDirs(dest_directory)
  return True
|
|
|
|
|
|
def maybe_download(file_urls, directory):
  """Download a set of files into a local folder, skipping existing ones.

  Args:
    file_urls: iterable of URLs to fetch.
    directory: local directory where the files are stored; created if needed.

  Returns:
    List of local file paths, one per URL and in the same order as
    `file_urls`. Paths are returned for already-present files too.
  """
  # Called for its side effect. It must not be wrapped in `assert`: asserts
  # are stripped under `python -O`, which would skip creating the directory.
  create_dir_if_needed(directory)

  # Local paths of all files, in the same order as file_urls.
  result = []

  for file_url in file_urls:
    filename = file_url.split('/')[-1]

    # If downloading from GitHub, remove the "?raw=true" suffix from the
    # local filename.
    raw_suffix = '?raw=true'
    if filename.endswith(raw_suffix):
      filename = filename[:-len(raw_suffix)]

    # Plain concatenation keeps the path format used by the rest of this
    # module (callers build cache paths the same way).
    filepath = directory + '/' + filename
    result.append(filepath)

    if tf.gfile.Exists(filepath):
      continue

    # `_name` is bound explicitly so the closure does not depend on the loop
    # variable (pylint cell-var-from-loop).
    def _progress(count, block_size, total_size, _name=filename):
      # urlretrieve reports a non-positive total_size when the server does
      # not send a size; avoid dividing by it in that case.
      if total_size > 0:
        percent = float(count * block_size) / float(total_size) * 100.0
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (_name, percent))
      else:
        sys.stdout.write('\r>> Downloading %s' % _name)
      sys.stdout.flush()

    urllib.request.urlretrieve(file_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  return result
|
|
|
|
|
|
def image_whitening(data):
  """Standardize every image to zero mean and (roughly) unit variance.

  For each image (entry along the first axis), subtracts that image's mean
  over the remaining three axes, then divides by
  max(std, 1 / sqrt(values_per_image)). The lower bound keeps constant images
  from causing a division by zero. `data` is modified in place and returned.

  Args:
    data: 4-D floating-point numpy array; the first axis indexes images.

  Returns:
    The same `data` array, standardized.

  Raises:
    ValueError: if `data` is not 4-dimensional.
  """
  if np.ndim(data) != 4:
    raise ValueError('image_whitening expects a 4-D array, got shape %s'
                     % (np.shape(data),))

  # Number of values per image (product of the three non-batch axes).
  nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]

  # Per-image mean, reshaped to broadcast over the three image axes.
  mean = np.mean(data, axis=(1, 2, 3)).reshape(-1, 1, 1, 1)
  data -= mean

  # Adjusted standard deviation, bounded below by 1/sqrt(nb_pixels).
  min_std = np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels)
  adj_std_var = np.maximum(min_std, np.std(data, axis=(1, 2, 3)))
  data /= adj_std_var.reshape(-1, 1, 1, 1)

  return data
|
|
|
|
|
|
def extract_svhn(local_url):
  """Extract a MATLAB matrix into two numpy arrays with data and labels.

  Args:
    local_url: path of an SVHN .mat file containing keys 'X' (images) and
      'y' (labels), opened through tf.gfile.

  Returns:
    data: float32 array with the last axis of 'X' moved to the front
      (presumably [N, H, W, C] for SVHN's [H, W, C, N] layout).
    labels: 1-D int32 array; label 10 is remapped to 0.
  """
  # .mat files are binary: open in 'rb' so loadmat receives bytes. Text mode
  # breaks under Python 3.
  with tf.gfile.Open(local_url, mode='rb') as file_obj:
    data_dict = loadmat(file_obj)

  # Set np types.
  data = np.asarray(data_dict['X'], dtype=np.float32)
  labels = np.asarray(data_dict['y'], dtype=np.int32)

  # Transpose data to match the TF model input format (image index first).
  data = data.transpose(3, 0, 1, 2)

  # SVHN labels the digit 0 as 10.
  labels[labels == 10] = 0

  # Flatten labels to one entry per image.
  labels = labels.reshape(len(labels))

  return data, labels
|
|
|
|
|
|
def unpickle_cifar_dic(file_path):
  """Unpickle a CIFAR batch dictionary and return its data and labels.

  Args:
    file_path: local path of a pickled CIFAR batch file.

  Returns:
    Tuple (data, labels) taken from the dictionary's 'data' and 'labels'
    entries.
  """
  # NOTE(review): unpickling can execute arbitrary code; only use this on the
  # trusted CIFAR archive downloaded by this module.
  with open(file_path, 'rb') as file_obj:
    if sys.version_info[0] >= 3:
      # The CIFAR-10 "python version" batches are Python 2 pickles (per the
      # dataset page). 'latin1' lets Python 3 decode their byte strings,
      # including numpy buffers, and yields str keys like 'data'.
      data_dict = pickle.load(file_obj, encoding='latin1')
    else:
      data_dict = pickle.load(file_obj)
  return data_dict['data'], data_dict['labels']
|
|
|
|
|
|
def _load_npy(path):
  """Load a numpy dump through tf.gfile, opened in binary mode."""
  with tf.gfile.Open(path, mode='rb') as file_obj:
    return np.load(file_obj)


def _cifar10_batch_to_arrays(raw_images, raw_labels, num_images):
  """Convert raw CIFAR-10 batch contents into (float32 images, int32 labels).

  `raw_images` is reshaped to [num_images, 3, 32, 32] and axes 1 and 3 are
  swapped. The fixed `num_images` makes the reshape fail loudly on an
  unexpected size.
  """
  images = np.asarray(raw_images, dtype=np.float32)
  images = images.reshape((num_images, 3, 32, 32))
  # NOTE(review): assuming rows are stored channel-major (C, H, W),
  # swapaxes(1, 3) yields (N, W, H, C), i.e. channels-last with height and
  # width transposed relative to transpose(0, 2, 3, 1). It is kept as-is so
  # results stay consistent with previously saved .npy caches.
  images = np.swapaxes(images, 1, 3)
  labels = np.asarray(raw_labels, dtype=np.int32).reshape(num_images)
  return images, labels


def extract_cifar10(local_url, data_dir):
  """Extract CIFAR-10 and return numpy arrays with the different sets.

  Preprocessed arrays are cached as .npy files in `data_dir`. When all four
  caches exist they are reloaded instead of re-reading the archive.

  Args:
    local_url: path of the downloaded CIFAR-10 tar.gz archive.
    data_dir: directory holding the extracted batches and the .npy caches.

  Returns:
    Tuple (train_data, train_labels, test_data, test_labels).
  """
  # Changing the order of this list will ruin the indices below.
  preprocessed_files = ['/cifar10_train.npy',
                        '/cifar10_train_labels.npy',
                        '/cifar10_test.npy',
                        '/cifar10_test_labels.npy']
  cache_paths = [data_dir + name for name in preprocessed_files]

  if all(tf.gfile.Exists(path) for path in cache_paths):
    # Reload pre-processed data from numpy dumps (binary mode).
    train_data, train_labels, test_data, test_labels = [
        _load_npy(path) for path in cache_paths]
    return train_data, train_labels, test_data, test_labels

  # Do everything from scratch.
  batches_dir = data_dir + '/cifar-10-batches-py/'
  train_files = ['data_batch_' + str(i) for i in range(1, 6)]
  test_file = 'test_batch'

  # Unpack only if some batch is missing. The check must look in the
  # extraction directory, not at bare names resolved against the cwd.
  if not all(tf.gfile.Exists(batches_dir + name)
             for name in train_files + [test_file]):
    # NOTE(review): extractall trusts member paths inside the archive; this
    # is acceptable only for the trusted CIFAR download.
    with tarfile.open(local_url, 'r:gz') as archive:
      archive.extractall(data_dir)

  # Load training images and labels from the five training batches.
  images = []
  labels = []
  for train_file in train_files:
    images_tmp, labels_tmp = unpickle_cifar_dic(batches_dir + train_file)
    images.append(images_tmp)
    labels.append(labels_tmp)
  train_data, train_labels = _cifar10_batch_to_arrays(images, labels, 50000)

  # Save so we don't have to do this again.
  np.save(cache_paths[0], train_data)
  np.save(cache_paths[1], train_labels)

  # Load test images and labels.
  raw_test_data, raw_test_labels = unpickle_cifar_dic(batches_dir + test_file)
  test_data, test_labels = _cifar10_batch_to_arrays(
      raw_test_data, raw_test_labels, 10000)

  np.save(cache_paths[2], test_data)
  np.save(cache_paths[3], test_labels)

  return train_data, train_labels, test_data, test_labels
|
|
|
|
|
|
def extract_mnist_data(filename, num_images, image_size, pixel_depth):
  """Decode MNIST images into a 4-D float32 tensor [image index, y, x, channels].

  Each byte v becomes (v - pixel_depth / 2) / pixel_depth, so with
  pixel_depth=255 values are rescaled from [0, 255] to [-0.5, 0.5]. The result
  is cached as '<filename>.npy' and that cache is returned on later calls.

  Args:
    filename: path of the gzipped image file.
    num_images: number of images to read.
    image_size: height and width of each (square) image, in pixels.
    pixel_depth: normalization scale (see formula above).

  Returns:
    float32 array of shape (num_images, image_size, image_size, 1).
  """
  cache_path = filename + '.npy'
  if tf.gfile.Exists(cache_path):
    with tf.gfile.Open(cache_path, mode='rb') as cached:
      return np.load(cached)

  with gzip.open(filename) as stream:
    stream.read(16)  # discard the 16-byte file header
    raw = stream.read(image_size * image_size * num_images)
    pixels = np.frombuffer(raw, dtype=np.uint8).astype(np.float32)
    pixels = (pixels - (pixel_depth / 2.0)) / pixel_depth
    pixels = pixels.reshape(num_images, image_size, image_size, 1)
    # np.save appends '.npy', so this writes cache_path.
    np.save(filename, pixels)
    return pixels
|
|
|
|
|
|
def extract_mnist_labels(filename, num_images):
  """Decode MNIST labels into a 1-D vector of int32 label IDs.

  The result is cached as '<filename>.npy' and that cache is returned on
  later calls.

  Args:
    filename: path of the gzipped label file.
    num_images: number of labels to read (one byte each).

  Returns:
    1-D int32 array of up to `num_images` labels.
  """
  cache_path = filename + '.npy'
  if tf.gfile.Exists(cache_path):
    with tf.gfile.Open(cache_path, mode='rb') as cached:
      return np.load(cached)

  with gzip.open(filename) as stream:
    stream.read(8)  # discard the 8-byte file header
    raw = stream.read(num_images)
    label_ids = np.frombuffer(raw, dtype=np.uint8).astype(np.int32)
    # np.save appends '.npy', so this writes cache_path.
    np.save(filename, label_ids)
    return label_ids
|
|
|
|
|
|
def ld_svhn(extended=False, test_only=False):
  """Load the original SVHN data, whitened per image.

  Args:
    extended: include extended ("extra") training data in the returned
      training arrays instead of returning it separately.
    test_only: download and load only the test split (large speed up).

  Returns:
    If test_only: (test_data, test_labels).
    Elif extended: (train_data, train_labels, test_data, test_labels), with
      the extra split stacked onto the training split.
    Else: (train_data, train_labels, test_data, test_labels, ext_data,
      ext_labels).
  """
  base_url = 'http://ufldl.stanford.edu/housenumbers/'
  train_url = base_url + 'train_32x32.mat'
  test_url = base_url + 'test_32x32.mat'
  extra_url = base_url + 'extra_32x32.mat'

  if test_only:
    # Do not download the train and extra splits at all when they are unused.
    test_local = maybe_download([test_url], FLAGS.data_dir)[0]
    test_data, test_labels = extract_svhn(test_local)
    test_data = image_whitening(test_data)
    return test_data, test_labels

  train_local, test_local, extra_local = maybe_download(
      [train_url, test_url, extra_url], FLAGS.data_dir)

  # Load and apply whitening to train data.
  train_data, train_labels = extract_svhn(train_local)
  train_data = image_whitening(train_data)

  # Load and apply whitening to extended train data.
  ext_data, ext_labels = extract_svhn(extra_local)
  ext_data = image_whitening(ext_data)

  # Load and apply whitening to test data.
  test_data, test_labels = extract_svhn(test_local)
  test_data = image_whitening(test_data)

  if extended:
    # Stack train data with the extended training data.
    train_data = np.vstack((train_data, ext_data))
    train_labels = np.hstack((train_labels, ext_labels))
    return train_data, train_labels, test_data, test_labels

  # Return training and extended training data separately.
  return train_data, train_labels, test_data, test_labels, ext_data, ext_labels
|
|
|
|
|
|
def ld_cifar10(test_only=False):
  """Load the original CIFAR10 data, whitened per image.

  Args:
    test_only: return only the test split; the training split is then not
      whitened (that work would be discarded).

  Returns:
    (test_data, test_labels) if test_only, else
    (train_data, train_labels, test_data, test_labels).
  """
  file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

  # Maybe download data and retrieve local storage paths.
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # The single archive holds both splits, so it is always extracted whole.
  train_data, train_labels, test_data, test_labels = extract_cifar10(
      local_urls[0], FLAGS.data_dir)

  test_data = image_whitening(test_data)
  if test_only:
    return test_data, test_labels

  train_data = image_whitening(train_data)
  return train_data, train_labels, test_data, test_labels
|
|
|
|
|
|
def ld_mnist(test_only=False):
  """Load the MNIST dataset.

  Args:
    test_only: download and decode only the test split.

  Returns:
    (test_data, test_labels) if test_only, else
    (train_data, train_labels, test_data, test_labels).
  """
  base_url = 'http://yann.lecun.com/exdb/mnist/'

  # NOTE(review): pixel_depth is passed as 1, so extract_mnist_data maps the
  # uint8 pixels to [-0.5, 254.5] rather than [-0.5, 0.5]. It is left
  # unchanged because existing .npy caches (and possibly trained models)
  # depend on this scaling.
  test_urls = [base_url + 't10k-images-idx3-ubyte.gz',
               base_url + 't10k-labels-idx1-ubyte.gz']
  test_images_path, test_labels_path = maybe_download(test_urls,
                                                      FLAGS.data_dir)
  test_data = extract_mnist_data(test_images_path, 10000, 28, 1)
  test_labels = extract_mnist_labels(test_labels_path, 10000)

  if test_only:
    return test_data, test_labels

  train_urls = [base_url + 'train-images-idx3-ubyte.gz',
                base_url + 'train-labels-idx1-ubyte.gz']
  train_images_path, train_labels_path = maybe_download(train_urls,
                                                        FLAGS.data_dir)
  train_data = extract_mnist_data(train_images_path, 60000, 28, 1)
  train_labels = extract_mnist_labels(train_labels_path, 60000)

  return train_data, train_labels, test_data, test_labels
|
|
|
|
|
|
def partition_dataset(data, labels, nb_teachers, teacher_id):
  """Return the contiguous slice of (data, labels) belonging to one teacher.

  The dataset is cut into `nb_teachers` contiguous partitions of
  len(data) // nb_teachers elements each. Any remainder at the end is not
  assigned to any teacher.

  Args:
    data: input data to be partitioned (any sliceable sequence or array).
    labels: output data to be partitioned; must have the same length.
    nb_teachers: number of teachers in the ensemble (sets partition size).
    teacher_id: index of the partition to retrieve, in [0, nb_teachers).

  Returns:
    Tuple (partition_data, partition_labels).

  Raises:
    ValueError: if lengths differ, nb_teachers < 1, or teacher_id is outside
      [0, nb_teachers).
  """
  # Coerce once so that every use below is integer arithmetic. The original
  # only applied int() inside its sanity check.
  nb_teachers = int(nb_teachers)
  teacher_id = int(teacher_id)

  if len(data) != len(labels):
    raise ValueError('data and labels must have the same length (%d != %d)'
                     % (len(data), len(labels)))
  if nb_teachers < 1:
    raise ValueError('nb_teachers must be positive, got %d' % nb_teachers)
  # A negative teacher_id would otherwise produce a bogus negative slice.
  if not 0 <= teacher_id < nb_teachers:
    raise ValueError('teacher_id must be in [0, %d), got %d'
                     % (nb_teachers, teacher_id))

  # Floor division: partitions are equal-sized, the remainder is dropped.
  batch_len = len(data) // nb_teachers

  start = teacher_id * batch_len
  end = start + batch_len
  return data[start:end], labels[start:end]
|