glint
This commit is contained in:
parent
a209988d87
commit
2aa9debb91
1 changed files with 8 additions and 6 deletions
|
@ -110,7 +110,7 @@ def image_whitening(data):
|
||||||
|
|
||||||
|
|
||||||
def extract_svhn(local_url):
|
def extract_svhn(local_url):
|
||||||
"""Extract a MATLAB matrix into two numpy arrays with data and labels"""
|
"""Extract a MATLAB matrix into two numpy arrays with data and labels."""
|
||||||
|
|
||||||
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
||||||
# Load MATLAB matrix using scipy IO
|
# Load MATLAB matrix using scipy IO
|
||||||
|
@ -136,7 +136,7 @@ def extract_svhn(local_url):
|
||||||
|
|
||||||
|
|
||||||
def unpickle_cifar_dic(file_path):
|
def unpickle_cifar_dic(file_path):
|
||||||
"""Helper function: unpickles a dictionary (used for loading CIFAR)"""
|
"""Helper function: unpickles a dictionary (used for loading CIFAR)."""
|
||||||
file_obj = open(file_path, 'rb')
|
file_obj = open(file_path, 'rb')
|
||||||
data_dict = pickle.load(file_obj)
|
data_dict = pickle.load(file_obj)
|
||||||
file_obj.close()
|
file_obj.close()
|
||||||
|
@ -144,7 +144,7 @@ def unpickle_cifar_dic(file_path):
|
||||||
|
|
||||||
|
|
||||||
def extract_cifar10(local_url, data_dir):
|
def extract_cifar10(local_url, data_dir):
|
||||||
"""Extracts CIFAR-10 and return numpy arrays with the different sets"""
|
"""Extracts CIFAR-10 and return numpy arrays with the different sets."""
|
||||||
|
|
||||||
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
||||||
# if they exist in the working directory.
|
# if they exist in the working directory.
|
||||||
|
@ -206,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
labels.append(labels_tmp)
|
labels.append(labels_tmp)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
|
train_data = np.asarray(images, dtype=np.float32)
|
||||||
|
train_data = train_data.reshape((50000, 3, 32, 32))
|
||||||
train_data = np.swapaxes(train_data, 1, 3)
|
train_data = np.swapaxes(train_data, 1, 3)
|
||||||
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
||||||
|
|
||||||
|
@ -221,7 +222,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
test_data, test_images = unpickle_cifar_dic(filename)
|
test_data, test_images = unpickle_cifar_dic(filename)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
|
test_data = np.asarray(test_data, dtype=np.float32)
|
||||||
|
test_data = test_data.reshape((10000, 3, 32, 32))
|
||||||
test_data = np.swapaxes(test_data, 1, 3)
|
test_data = np.swapaxes(test_data, 1, 3)
|
||||||
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
||||||
|
|
||||||
|
@ -314,7 +316,7 @@ def ld_svhn(extended=False, test_only=False):
|
||||||
|
|
||||||
|
|
||||||
def ld_cifar10(test_only=False):
|
def ld_cifar10(test_only=False):
|
||||||
"""Load the original CIFAR10 data"""
|
"""Load the original CIFAR10 data."""
|
||||||
|
|
||||||
# Define files to be downloaded
|
# Define files to be downloaded
|
||||||
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
||||||
|
|
Loading…
Reference in a new issue