diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
index 082a90b..e323b2f 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
@@ -22,6 +22,7 @@
 from dataclasses import dataclass
 import numpy as np
 import pandas as pd
+from scipy import special
 from sklearn import metrics
 
 ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
@@ -144,6 +145,20 @@ def _is_np_array(arr, arr_name):
   if arr is not None and not isinstance(arr, np.ndarray):
     raise ValueError('%s should be a numpy array.' % arr_name)
 
+
+def _log_value(probs, small_value=1e-30):
+  """Computes the negative log of probabilities, clipped away from zero.
+
+  Args:
+    probs: array of probabilities.
+    small_value: floor applied to `probs` before the log so that entries at or
+      near 0 yield a large finite value instead of infinity.
+
+  Returns:
+    -log(max(probs, small_value)), elementwise.
+  """
+  return -np.log(np.maximum(probs, small_value))
+
 
 @dataclass
 class AttackInputData:
@@ -165,6 +180,12 @@ class AttackInputData:
   loss_train: np.ndarray = None
   loss_test: np.ndarray = None
 
+  # Explicitly specified prediction entropy. If provided, this is used instead
+  # of deriving entropy from logits and labels
+  # (https://arxiv.org/pdf/2003.10595.pdf by Song and Mittal).
+  entropy_train: np.ndarray = None
+  entropy_test: np.ndarray = None
+
   @property
   def num_classes(self):
     if self.labels_train is None or self.labels_test is None:
@@ -177,6 +198,49 @@ class AttackInputData:
   def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
     return logits[range(logits.shape[0]), true_labels]
 
+  @staticmethod
+  def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
+    """Computes the (modified) prediction entropy of each example.
+
+    See https://arxiv.org/pdf/2003.10595.pdf by Song and Mittal.
+
+    Args:
+      logits: array of shape (num_examples, num_classes). If every row sums
+        to 1 (within 1e-3) and no entry is negative, it is used as
+        probabilities as is; otherwise softmax is applied along axis 1.
+      true_labels: integer class index of each example, or None.
+
+    Returns:
+      One value per example: the prediction entropy (Equation (7)) when
+      `true_labels` is None, else the modified prediction entropy
+      (Equation (8)).
+    """
+    # Inputs may already be probabilities. Summing to 1 is not enough on its
+    # own: raw logits with negative entries can also sum to 1.
+    sums_to_one = np.absolute(np.sum(logits, axis=1) - 1) <= 1e-3
+    if sums_to_one.all() and (logits >= 0).all():
+      probs = logits
+    else:
+      # Use softmax to compute probabilities from logits.
+      probs = special.softmax(logits, axis=1)
+
+    if true_labels is None:
+      # Without ground truth labels, compute the normal prediction entropy.
+      return np.sum(np.multiply(probs, _log_value(probs)), axis=1)
+
+    # With ground truth labels, compute the modified prediction entropy: the
+    # true class contributes (1 - p) * -log(p), every other class contributes
+    # p * -log(1 - p).
+    log_probs = _log_value(probs)
+    reverse_probs = 1 - probs
+    log_reverse_probs = _log_value(reverse_probs)
+    rows = np.arange(true_labels.size)
+    modified_probs = np.copy(probs)
+    modified_probs[rows, true_labels] = reverse_probs[rows, true_labels]
+    modified_log_probs = np.copy(log_reverse_probs)
+    modified_log_probs[rows, true_labels] = log_probs[rows, true_labels]
+    return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1)
+
   def get_loss_train(self):
     """Calculates cross-entropy losses for the training set."""
     if self.loss_train is not None:
@@ -189,6 +253,18 @@ class AttackInputData:
       return self.loss_test
     return self._get_loss(self.logits_test, self.labels_test)
 
+  def get_entropy_train(self):
+    """Calculates prediction entropy for the training set."""
+    if self.entropy_train is not None:
+      return self.entropy_train
+    return self._get_entropy(self.logits_train, self.labels_train)
+
+  def get_entropy_test(self):
+    """Calculates prediction entropy for the test set."""
+    if self.entropy_test is not None:
+      return self.entropy_test
+    return self._get_entropy(self.logits_test, self.labels_test)
+
   def get_train_size(self):
     """Returns size of the training set."""
     if self.loss_train is not None:
@@ -206,6 +282,10 @@ class AttackInputData:
     if (self.loss_train is None) != (self.loss_test is None):
       raise ValueError(
           'loss_test and loss_train should both be either set or unset')
+
+    if (self.entropy_train is None) != (self.entropy_test is None):
+      raise ValueError(
+          'entropy_test and entropy_train should both be either set or unset')
 
     if (self.logits_train is None) != (self.logits_test is None):
       raise ValueError(
@@ -216,8 +296,9 @@ class AttackInputData:
           'labels_train and labels_test should both be either set or unset')
 
     if (self.labels_train is None and self.loss_train is None and
-        self.logits_train is None):
-      raise ValueError('At least one of labels, logits or losses should be set')
+        self.logits_train is None and self.entropy_train is None):
+      raise ValueError(
+          'At least one of labels, logits, losses or entropy should be set')
 
     if self.labels_train is not None and not _is_integer_type_array(
         self.labels_train):
@@ -233,11 +314,15 @@ class AttackInputData:
     _is_np_array(self.labels_test, 'labels_test')
     _is_np_array(self.loss_train, 'loss_train')
     _is_np_array(self.loss_test, 'loss_test')
+    _is_np_array(self.entropy_train, 'entropy_train')
+    _is_np_array(self.entropy_test, 'entropy_test')
 
     _is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
                        'logits_test')
     _is_array_one_dimensional(self.loss_train, 'loss_train')
     _is_array_one_dimensional(self.loss_test, 'loss_test')
+    _is_array_one_dimensional(self.entropy_train, 'entropy_train')
+    _is_array_one_dimensional(self.entropy_test, 'entropy_test')
     _is_array_one_dimensional(self.labels_train, 'labels_train')
     _is_array_one_dimensional(self.labels_test, 'labels_test')
 
@@ -246,6 +331,8 @@ class AttackInputData:
     result = ['AttackInputData(']
     _append_array_shape(self.loss_train, 'loss_train', result)
     _append_array_shape(self.loss_test, 'loss_test', result)
+    _append_array_shape(self.entropy_train, 'entropy_train', result)
+    _append_array_shape(self.entropy_test, 'entropy_test', result)
     _append_array_shape(self.logits_train, 'logits_train', result)
     _append_array_shape(self.logits_test, 'logits_test', result)
     _append_array_shape(self.labels_train, 'labels_train', result)