pylint edits
This commit is contained in:
parent
ec2204ac97
commit
4784b0f31e
1 changed files with 13 additions and 13 deletions
|
@ -100,14 +100,14 @@ def image_whitening(data):
|
||||||
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
|
nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
|
||||||
|
|
||||||
# Subtract mean
|
# Subtract mean
|
||||||
mean = np.mean(data, axis=(1,2,3))
|
mean = np.mean(data, axis=(1, 2, 3))
|
||||||
|
|
||||||
ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
|
ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
|
||||||
for i in xrange(len(data)):
|
for i in xrange(len(data)):
|
||||||
data[i, :, :, :] -= mean[i] * ones
|
data[i, :, :, :] -= mean[i] * ones
|
||||||
|
|
||||||
# Compute adjusted standard variance
|
# Compute adjusted standard variance
|
||||||
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line)
|
adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long
|
||||||
|
|
||||||
# Divide image
|
# Divide image
|
||||||
for i in xrange(len(data)):
|
for i in xrange(len(data)):
|
||||||
|
@ -148,15 +148,15 @@ def extract_svhn(local_url):
|
||||||
return data, labels
|
return data, labels
|
||||||
|
|
||||||
|
|
||||||
def unpickle_cifar_dic(file):
|
def unpickle_cifar_dic(file): # pylint: disable=redefined-builtin
|
||||||
"""
|
"""
|
||||||
Helper function: unpickles a dictionary (used for loading CIFAR)
|
Helper function: unpickles a dictionary (used for loading CIFAR)
|
||||||
:param file: filename of the pickle
|
:param file: filename of the pickle
|
||||||
:return: tuple of (images, labels)
|
:return: tuple of (images, labels)
|
||||||
"""
|
"""
|
||||||
fo = open(file, 'rb')
|
file_obj = open(file, 'rb')
|
||||||
data_dict = pickle.load(fo)
|
data_dict = pickle.load(file_obj)
|
||||||
fo.close()
|
file_obj.close()
|
||||||
return data_dict['data'], data_dict['labels']
|
return data_dict['data'], data_dict['labels']
|
||||||
|
|
||||||
|
|
||||||
|
@ -176,8 +176,8 @@ def extract_cifar10(local_url, data_dir):
|
||||||
'/cifar10_test_labels.npy']
|
'/cifar10_test_labels.npy']
|
||||||
|
|
||||||
all_preprocessed = True
|
all_preprocessed = True
|
||||||
for file in preprocessed_files:
|
for file_name in preprocessed_files:
|
||||||
if not tf.gfile.Exists(data_dir + file):
|
if not tf.gfile.Exists(data_dir + file_name):
|
||||||
all_preprocessed = False
|
all_preprocessed = False
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -197,7 +197,7 @@ def extract_cifar10(local_url, data_dir):
|
||||||
else:
|
else:
|
||||||
# Do everything from scratch
|
# Do everything from scratch
|
||||||
# Define lists of all files we should extract
|
# Define lists of all files we should extract
|
||||||
train_files = ["data_batch_" + str(i) for i in xrange(1,6)]
|
train_files = ["data_batch_" + str(i) for i in xrange(1, 6)]
|
||||||
test_file = ["test_batch"]
|
test_file = ["test_batch"]
|
||||||
cifar10_files = train_files + test_file
|
cifar10_files = train_files + test_file
|
||||||
|
|
||||||
|
@ -227,7 +227,7 @@ def extract_cifar10(local_url, data_dir):
|
||||||
labels.append(labels_tmp)
|
labels.append(labels_tmp)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32))
|
train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
|
||||||
train_data = np.swapaxes(train_data, 1, 3)
|
train_data = np.swapaxes(train_data, 1, 3)
|
||||||
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
|
||||||
|
|
||||||
|
@ -242,7 +242,7 @@ def extract_cifar10(local_url, data_dir):
|
||||||
test_data, test_images = unpickle_cifar_dic(filename)
|
test_data, test_images = unpickle_cifar_dic(filename)
|
||||||
|
|
||||||
# Convert to numpy arrays and reshape in the expected format
|
# Convert to numpy arrays and reshape in the expected format
|
||||||
test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32))
|
test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
|
||||||
test_data = np.swapaxes(test_data, 1, 3)
|
test_data = np.swapaxes(test_data, 1, 3)
|
||||||
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ def ld_svhn(extended=False, test_only=False):
|
||||||
return train_data, train_labels, test_data, test_labels
|
return train_data, train_labels, test_data, test_labels
|
||||||
else:
|
else:
|
||||||
# Return training and extended training data separately
|
# Return training and extended training data separately
|
||||||
return train_data,train_labels, test_data,test_labels, ext_data,ext_labels
|
return train_data, train_labels, test_data, test_labels, ext_data, ext_labels
|
||||||
|
|
||||||
|
|
||||||
def ld_cifar10(test_only=False):
|
def ld_cifar10(test_only=False):
|
||||||
|
@ -377,7 +377,7 @@ def ld_mnist(test_only=False):
|
||||||
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
||||||
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
||||||
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
||||||
]
|
]
|
||||||
|
|
||||||
# Maybe download data and retrieve local storage urls
|
# Maybe download data and retrieve local storage urls
|
||||||
local_urls = maybe_download(file_urls, FLAGS.data_dir)
|
local_urls = maybe_download(file_urls, FLAGS.data_dir)
|
||||||
|
|
Loading…
Reference in a new issue