glint
This commit is contained in:
parent
a209988d87
commit
2aa9debb91
1 changed files with 8 additions and 6 deletions
|
@ -110,7 +110,7 @@ def image_whitening(data):
|
|||
|
||||
|
||||
def extract_svhn(local_url):
|
||||
"""Extract a MATLAB matrix into two numpy arrays with data and labels"""
|
||||
"""Extract a MATLAB matrix into two numpy arrays with data and labels."""
|
||||
|
||||
with tf.gfile.Open(local_url, mode='r') as file_obj:
|
||||
# Load MATLAB matrix using scipy IO
|
||||
|
@ -136,7 +136,7 @@ def extract_svhn(local_url):
|
|||
|
||||
|
||||
def unpickle_cifar_dic(file_path):
|
||||
"""Helper function: unpickles a dictionary (used for loading CIFAR)"""
|
||||
"""Helper function: unpickles a dictionary (used for loading CIFAR)."""
|
||||
file_obj = open(file_path, 'rb')
|
||||
data_dict = pickle.load(file_obj)
|
||||
file_obj.close()
|
||||
|
@ -144,7 +144,7 @@ def unpickle_cifar_dic(file_path):
|
|||
|
||||
|
||||
def extract_cifar10(local_url, data_dir):
|
||||
"""Extracts CIFAR-10 and return numpy arrays with the different sets"""
|
||||
"""Extracts CIFAR-10 and return numpy arrays with the different sets."""
|
||||
|
||||
# These numpy dumps can be reloaded to avoid performing the pre-processing
|
||||
# if they exist in the working directory.
|
||||
|
@ -206,7 +206,8 @@ def extract_cifar10(local_url, data_dir):
|
|||
labels.append(labels_tmp)
|
||||
|
||||
# Convert to numpy arrays and reshape in the expected format
|
||||
train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
|
||||
train_data = np.asarray(images, dtype=np.float32)
|
||||
train_data = train_data.reshape((50000, 3, 32, 32))
|
||||
train_data = np.swapaxes(train_data, 1, 3)
|
||||
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
||||
|
||||
|
@ -221,7 +222,8 @@ def extract_cifar10(local_url, data_dir):
|
|||
test_data, test_images = unpickle_cifar_dic(filename)
|
||||
|
||||
# Convert to numpy arrays and reshape in the expected format
|
||||
test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
|
||||
test_data = np.asarray(test_data, dtype=np.float32)
|
||||
test_data = test_data.reshape((10000, 3, 32, 32))
|
||||
test_data = np.swapaxes(test_data, 1, 3)
|
||||
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
||||
|
||||
|
@ -314,7 +316,7 @@ def ld_svhn(extended=False, test_only=False):
|
|||
|
||||
|
||||
def ld_cifar10(test_only=False):
|
||||
"""Load the original CIFAR10 data"""
|
||||
"""Load the original CIFAR10 data."""
|
||||
|
||||
# Define files to be downloaded
|
||||
file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
|
||||
|
|
Loading…
Reference in a new issue