forked from 626_privacy/tensorflow_privacy
fnames
This commit is contained in:
parent
b6c932ec66
commit
e55a832d54
1 changed files with 5 additions and 7 deletions
|
@ -148,13 +148,13 @@ def extract_svhn(local_url):
|
||||||
return data, labels
|
return data, labels
|
||||||
|
|
||||||
|
|
||||||
def unpickle_cifar_dic(file): # pylint: disable=redefined-builtin
|
def unpickle_cifar_dic(file_path):
|
||||||
"""
|
"""
|
||||||
Helper function: unpickles a dictionary (used for loading CIFAR)
|
Helper function: unpickles a dictionary (used for loading CIFAR)
|
||||||
:param file: filename of the pickle
|
:param file_path: filename of the pickle
|
||||||
:return: tuple of (images, labels)
|
:return: tuple of (images, labels)
|
||||||
"""
|
"""
|
||||||
file_obj = open(file, 'rb')
|
file_obj = open(file_path, 'rb')
|
||||||
data_dict = pickle.load(file_obj)
|
data_dict = pickle.load(file_obj)
|
||||||
file_obj.close()
|
file_obj.close()
|
||||||
return data_dict['data'], data_dict['labels']
|
return data_dict['data'], data_dict['labels']
|
||||||
|
@ -215,9 +215,9 @@ def extract_cifar10(local_url, data_dir):
|
||||||
# Load training images and labels
|
# Load training images and labels
|
||||||
images = []
|
images = []
|
||||||
labels = []
|
labels = []
|
||||||
for file in train_files:
|
for train_file in train_files:
|
||||||
# Construct filename
|
# Construct filename
|
||||||
filename = data_dir + '/cifar-10-batches-py/' + file
|
filename = data_dir + '/cifar-10-batches-py/' + train_file
|
||||||
|
|
||||||
# Unpickle dictionary and extract images and labels
|
# Unpickle dictionary and extract images and labels
|
||||||
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
images_tmp, labels_tmp = unpickle_cifar_dic(filename)
|
||||||
|
@ -259,7 +259,6 @@ def extract_mnist_data(filename, num_images, image_size, pixel_depth):
|
||||||
|
|
||||||
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
|
||||||
if not tf.gfile.Exists(filename+'.npy'):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(16)
|
bytestream.read(16)
|
||||||
|
@ -278,7 +277,6 @@ def extract_mnist_labels(filename, num_images):
|
||||||
"""
|
"""
|
||||||
Extract the labels into a vector of int64 label IDs.
|
Extract the labels into a vector of int64 label IDs.
|
||||||
"""
|
"""
|
||||||
# if not os.path.exists(file):
|
|
||||||
if not tf.gfile.Exists(filename+'.npy'):
|
if not tf.gfile.Exists(filename+'.npy'):
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(8)
|
bytestream.read(8)
|
||||||
|
|
Loading…
Reference in a new issue