diff --git a/research/pate_2017/input.py b/research/pate_2017/input.py index c325941..65c300b 100644 --- a/research/pate_2017/input.py +++ b/research/pate_2017/input.py @@ -100,14 +100,14 @@ def image_whitening(data): nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3] # Subtract mean - mean = np.mean(data, axis=(1,2,3)) + mean = np.mean(data, axis=(1, 2, 3)) ones = np.ones(np.shape(data)[1:4], dtype=np.float32) for i in xrange(len(data)): data[i, :, :, :] -= mean[i] * ones # Compute adjusted standard variance - adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line) + adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1, 2, 3))) # pylint: disable=line-too-long # Divide image for i in xrange(len(data)): @@ -148,15 +148,15 @@ def extract_svhn(local_url): return data, labels -def unpickle_cifar_dic(file): +def unpickle_cifar_dic(file): # pylint: disable=redefined-builtin """ Helper function: unpickles a dictionary (used for loading CIFAR) :param file: filename of the pickle :return: tuple of (images, labels) """ - fo = open(file, 'rb') - data_dict = pickle.load(fo) - fo.close() + file_obj = open(file, 'rb') + data_dict = pickle.load(file_obj) + file_obj.close() return data_dict['data'], data_dict['labels'] @@ -176,8 +176,8 @@ def extract_cifar10(local_url, data_dir): '/cifar10_test_labels.npy'] all_preprocessed = True - for file in preprocessed_files: - if not tf.gfile.Exists(data_dir + file): + for file_name in preprocessed_files: + if not tf.gfile.Exists(data_dir + file_name): all_preprocessed = False break @@ -197,7 +197,7 @@ def extract_cifar10(local_url, data_dir): else: # Do everything from scratch # Define lists of all files we should extract - train_files = ["data_batch_" + str(i) for i in xrange(1,6)] + train_files = ["data_batch_" + str(i) for i in xrange(1, 6)] test_file = ["test_batch"] cifar10_files = train_files + test_file @@ -227,7 +227,7 @@ def extract_cifar10(local_url, data_dir): labels.append(labels_tmp) # Convert to numpy arrays and reshape in the expected format - train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32)) + train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32)) train_data = np.swapaxes(train_data, 1, 3) train_labels = np.asarray(labels, dtype=np.int32).reshape(50000) @@ -242,7 +242,7 @@ def extract_cifar10(local_url, data_dir): test_data, test_images = unpickle_cifar_dic(filename) # Convert to numpy arrays and reshape in the expected format - test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32)) + test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32)) test_data = np.swapaxes(test_data, 1, 3) test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000) @@ -332,7 +332,7 @@ def ld_svhn(extended=False, test_only=False): return train_data, train_labels, test_data, test_labels else: # Return training and extended training data separately - return train_data,train_labels, test_data,test_labels, ext_data,ext_labels + return train_data, train_labels, test_data, test_labels, ext_data, ext_labels def ld_cifar10(test_only=False): @@ -377,7 +377,7 @@ def ld_mnist(test_only=False): 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', - ] + ] # Maybe download data and retrieve local storage urls local_urls = maybe_download(file_urls, FLAGS.data_dir)