diff --git a/research/pate_2017/input.py b/research/pate_2017/input.py index 86f302c..d838806 100644 --- a/research/pate_2017/input.py +++ b/research/pate_2017/input.py @@ -110,7 +110,7 @@ def image_whitening(data): def extract_svhn(local_url): - """Extract a MATLAB matrix into two numpy arrays with data and labels""" + """Extract a MATLAB matrix into two numpy arrays with data and labels.""" with tf.gfile.Open(local_url, mode='r') as file_obj: # Load MATLAB matrix using scipy IO @@ -136,7 +136,7 @@ def extract_svhn(local_url): def unpickle_cifar_dic(file_path): - """Helper function: unpickles a dictionary (used for loading CIFAR)""" + """Helper function: unpickles a dictionary (used for loading CIFAR).""" file_obj = open(file_path, 'rb') data_dict = pickle.load(file_obj) file_obj.close() @@ -144,7 +144,7 @@ def unpickle_cifar_dic(file_path): def extract_cifar10(local_url, data_dir): - """Extracts CIFAR-10 and return numpy arrays with the different sets""" + """Extracts CIFAR-10 and return numpy arrays with the different sets.""" # These numpy dumps can be reloaded to avoid performing the pre-processing # if they exist in the working directory. @@ -206,7 +206,8 @@ def extract_cifar10(local_url, data_dir): labels.append(labels_tmp) # Convert to numpy arrays and reshape in the expected format - train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32)) + train_data = np.asarray(images, dtype=np.float32) + train_data = train_data.reshape((50000, 3, 32, 32)) train_data = np.swapaxes(train_data, 1, 3) train_labels = np.asarray(labels, dtype=np.int32).reshape(50000) @@ -221,7 +222,8 @@ def extract_cifar10(local_url, data_dir): test_data, test_images = unpickle_cifar_dic(filename) # Convert to numpy arrays and reshape in the expected format - test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32)) + test_data = np.asarray(test_data, dtype=np.float32) + test_data = test_data.reshape((10000, 3, 32, 32)) test_data = np.swapaxes(test_data, 1, 3) test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000) @@ -314,7 +316,7 @@ def ld_svhn(extended=False, test_only=False): def ld_cifar10(test_only=False): - """Load the original CIFAR10 data""" + """Load the original CIFAR10 data.""" # Define files to be downloaded file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']