From 41439577017f1e5e652ea5b583de8a0688f0c148 Mon Sep 17 00:00:00 2001 From: Vadym Doroshenko Date: Mon, 19 Oct 2020 10:37:47 -0700 Subject: [PATCH] Fixed train/test_size calculation. PiperOrigin-RevId: 337886488 --- .../membership_inference_attack/data_structures.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index aea8e78..fbafcd6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -205,12 +205,12 @@ class AttackInputData: @property def logits_or_probs_train(self): """Returns train logits or probs whatever is not None.""" - return self.logits_train if self.probs_train is None else self.probs_train + return self.probs_train or self.logits_train @property def logits_or_probs_test(self): """Returns test logits or probs whatever is not None.""" - return self.logits_test if self.probs_test is None else self.probs_test + return self.probs_test or self.logits_test @staticmethod def _get_entropy(logits: np.ndarray, true_labels: np.ndarray): @@ -278,13 +278,13 @@ class AttackInputData: """Returns size of the training set.""" if self.loss_train is not None: return self.loss_train.size - return self.logits_train.shape[0] + return self.logits_or_probs_train.shape[0] def get_test_size(self): """Returns size of the test set.""" if self.loss_test is not None: return self.loss_test.size - return self.logits_test.shape[0] + return self.logits_or_probs_test.shape[0] def validate(self): """Validates the inputs."""