This commit is contained in:
npapernot 2019-03-18 16:49:34 +00:00
parent b6c932ec66
commit e55a832d54

View file

@ -148,13 +148,13 @@ def extract_svhn(local_url):
return data, labels return data, labels
def unpickle_cifar_dic(file): # pylint: disable=redefined-builtin def unpickle_cifar_dic(file_path):
""" """
Helper function: unpickles a dictionary (used for loading CIFAR) Helper function: unpickles a dictionary (used for loading CIFAR)
:param file: filename of the pickle :param file_path: filename of the pickle
:return: tuple of (images, labels) :return: tuple of (images, labels)
""" """
file_obj = open(file, 'rb') file_obj = open(file_path, 'rb')
data_dict = pickle.load(file_obj) data_dict = pickle.load(file_obj)
file_obj.close() file_obj.close()
return data_dict['data'], data_dict['labels'] return data_dict['data'], data_dict['labels']
@ -215,9 +215,9 @@ def extract_cifar10(local_url, data_dir):
# Load training images and labels # Load training images and labels
images = [] images = []
labels = [] labels = []
for file in train_files: for train_file in train_files:
# Construct filename # Construct filename
filename = data_dir + '/cifar-10-batches-py/' + file filename = data_dir + '/cifar-10-batches-py/' + train_file
# Unpickle dictionary and extract images and labels # Unpickle dictionary and extract images and labels
images_tmp, labels_tmp = unpickle_cifar_dic(filename) images_tmp, labels_tmp = unpickle_cifar_dic(filename)
@ -259,7 +259,6 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
Values are rescaled from [0, 255] down to [-0.5, 0.5]. Values are rescaled from [0, 255] down to [-0.5, 0.5].
""" """
# if not os.path.exists(file):
if not tf.gfile.Exists(filename+'.npy'): if not tf.gfile.Exists(filename+'.npy'):
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
bytestream.read(16) bytestream.read(16)
@ -278,7 +277,6 @@ def extract_mnist_labels(filename, num_images):
""" """
Extract the labels into a vector of int64 label IDs. Extract the labels into a vector of int64 label IDs.
""" """
# if not os.path.exists(file):
if not tf.gfile.Exists(filename+'.npy'): if not tf.gfile.Exists(filename+'.npy'):
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
bytestream.read(8) bytestream.read(8)