quotes
This commit is contained in:
parent
4784b0f31e
commit
b6c932ec66
1 changed files with 9 additions and 9 deletions
|
@ -130,7 +130,7 @@ def extract_svhn(local_url):
|
||||||
data_dict = loadmat(file_obj)
|
data_dict = loadmat(file_obj)
|
||||||
|
|
||||||
# Extract each dictionary (one for data, one for labels)
|
# Extract each dictionary (one for data, one for labels)
|
||||||
data, labels = data_dict["X"], data_dict["y"]
|
data, labels = data_dict['X'], data_dict['y']
|
||||||
|
|
||||||
# Set np type
|
# Set np type
|
||||||
data = np.asarray(data, dtype=np.float32)
|
data = np.asarray(data, dtype=np.float32)
|
||||||
|
@ -197,8 +197,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
else:
|
else:
|
||||||
# Do everything from scratch
|
# Do everything from scratch
|
||||||
# Define lists of all files we should extract
|
# Define lists of all files we should extract
|
||||||
train_files = ["data_batch_" + str(i) for i in xrange(1, 6)]
|
train_files = ['data_batch_' + str(i) for i in xrange(1, 6)]
|
||||||
test_file = ["test_batch"]
|
test_file = ['test_batch']
|
||||||
cifar10_files = train_files + test_file
|
cifar10_files = train_files + test_file
|
||||||
|
|
||||||
# Check if all files have already been extracted
|
# Check if all files have already been extracted
|
||||||
|
@ -217,7 +217,7 @@ def extract_cifar10(local_url, data_dir):
|
||||||
labels = []
|
labels = []
|
||||||
for file in train_files:
|
for file in train_files:
|
||||||
# Construct filename
|
# Construct filename
|
||||||
filename = data_dir + "/cifar-10-batches-py/" + file
|
filename = data_dir + '/cifar-10-batches-py/' + file
|
||||||
|
|
||||||
# Unpickle dictionary and extract images and labels
|
# Unpickle dictionary and extract images and labels
|
||||||
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
||||||
|
@ -236,7 +236,7 @@ def extract_cifar10(local_url, data_dir):
|
||||||
np.save(data_dir + preprocessed_files[1], train_labels)
|
np.save(data_dir + preprocessed_files[1], train_labels)
|
||||||
|
|
||||||
# Construct filename for test file
|
# Construct filename for test file
|
||||||
filename = data_dir + "/cifar-10-batches-py/" + test_file[0]
|
filename = data_dir + '/cifar-10-batches-py/' + test_file[0]
|
||||||
|
|
||||||
# Load test images and labels
|
# Load test images and labels
|
||||||
test_data, test_images = unpickle_cifar_dic(filename)
|
test_data, test_images = unpickle_cifar_dic(filename)
|
||||||
|
@ -260,7 +260,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
|
||||||
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
# if not os.path.exists(file):
|
||||||
if not tf.gfile.Exists(filename+".npy"):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(16)
|
bytestream.read(16)
|
||||||
buf = bytestream.read(image_size * image_size * num_images)
|
buf = bytestream.read(image_size * image_size * num_images)
|
||||||
|
@ -270,7 +270,7 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
|
||||||
np.save(filename, data)
|
np.save(filename, data)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
|
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
|
||||||
return np.load(file_obj)
|
return np.load(file_obj)
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,7 +279,7 @@ def extract_mnist_labels(filename, num_images):
|
||||||
Extract the labels into a vector of int64 label IDs.
|
Extract the labels into a vector of int64 label IDs.
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
# if not os.path.exists(file):
|
||||||
if not tf.gfile.Exists(filename+".npy"):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(8)
|
bytestream.read(8)
|
||||||
buf = bytestream.read(1 * num_images)
|
buf = bytestream.read(1 * num_images)
|
||||||
|
@ -287,7 +287,7 @@ def extract_mnist_labels(filename, num_images):
|
||||||
np.save(filename, labels)
|
np.save(filename, labels)
|
||||||
return labels
|
return labels
|
||||||
else:
|
else:
|
||||||
with tf.gfile.Open(filename+".npy", mode='r') as file_obj:
|
with tf.gfile.Open(filename+'.npy', mode='r') as file_obj:
|
||||||
return np.load(file_obj)
|
return np.load(file_obj)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue