This commit is contained in:
npapernot 2019-03-18 17:01:25 +00:00
parent a209988d87
commit 2aa9debb91

View file

@ -110,7 +110,7 @@ def image_whitening(data):
def extract_svhn(local_url): def extract_svhn(local_url):
"""Extract a MATLAB matrix into two numpy arrays with data and labels""" """Extract a MATLAB matrix into two numpy arrays with data and labels."""
with tf.gfile.Open(local_url, mode='r') as file_obj: with tf.gfile.Open(local_url, mode='r') as file_obj:
# Load MATLAB matrix using scipy IO # Load MATLAB matrix using scipy IO
@ -136,7 +136,7 @@ def extract_svhn(local_url):
def unpickle_cifar_dic(file_path): def unpickle_cifar_dic(file_path):
"""Helper function: unpickles a dictionary (used for loading CIFAR)""" """Helper function: unpickles a dictionary (used for loading CIFAR)."""
file_obj = open(file_path, 'rb') file_obj = open(file_path, 'rb')
data_dict = pickle.load(file_obj) data_dict = pickle.load(file_obj)
file_obj.close() file_obj.close()
@ -144,7 +144,7 @@ def unpickle_cifar_dic(file_path):
def extract_cifar10(local_url, data_dir): def extract_cifar10(local_url, data_dir):
"""Extracts CIFAR-10 and return numpy arrays with the different sets""" """Extracts CIFAR-10 and return numpy arrays with the different sets."""
# These numpy dumps can be reloaded to avoid performing the pre-processing # These numpy dumps can be reloaded to avoid performing the pre-processing
# if they exist in the working directory. # if they exist in the working directory.
@ -206,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
labels.append(labels_tmp) labels.append(labels_tmp)
# Convert to numpy arrays and reshape in the expected format # Convert to numpy arrays and reshape in the expected format
train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32)) train_data = np.asarray(images, dtype=np.float32)
train_data = train_data.reshape((50000, 3, 32, 32))
train_data = np.swapaxes(train_data, 1, 3) train_data = np.swapaxes(train_data, 1, 3)
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000) train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
@ -221,7 +222,8 @@ def extract_cifar10(local_url, data_dir):
test_data, test_images = unpickle_cifar_dic(filename) test_data, test_images = unpickle_cifar_dic(filename)
# Convert to numpy arrays and reshape in the expected format # Convert to numpy arrays and reshape in the expected format
test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32)) test_data = np.asarray(test_data, dtype=np.float32)
test_data = test_data.reshape((10000, 3, 32, 32))
test_data = np.swapaxes(test_data, 1, 3) test_data = np.swapaxes(test_data, 1, 3)
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000) test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
@ -314,7 +316,7 @@ def ld_svhn(extended=False, test_only=False):
def ld_cifar10(test_only=False): def ld_cifar10(test_only=False):
"""Load the original CIFAR10 data""" """Load the original CIFAR10 data."""
# Define files to be downloaded # Define files to be downloaded
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'] file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']