diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index aea8e78..fbafcd6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -205,12 +205,12 @@ class AttackInputData: @property def logits_or_probs_train(self): """Returns train logits or probs whatever is not None.""" - return self.logits_train if self.probs_train is None else self.probs_train + return self.probs_train or self.logits_train @property def logits_or_probs_test(self): """Returns test logits or probs whatever is not None.""" - return self.logits_test if self.probs_test is None else self.probs_test + return self.probs_test or self.logits_test @staticmethod def _get_entropy(logits: np.ndarray, true_labels: np.ndarray): @@ -278,13 +278,13 @@ class AttackInputData: """Returns size of the training set.""" if self.loss_train is not None: return self.loss_train.size - return self.logits_train.shape[0] + return self.logits_or_probs_train.shape[0] def get_test_size(self): """Returns size of the test set.""" if self.loss_test is not None: return self.loss_test.size - return self.logits_test.shape[0] + return self.logits_or_probs_test.shape[0] def validate(self): """Validates the inputs."""