Fixed train/test_size calculation.

PiperOrigin-RevId: 337886488
This commit is contained in:
Vadym Doroshenko 2020-10-19 10:37:47 -07:00 committed by A. Unique TensorFlower
parent 19ae5c9669
commit 4143957701

View file

@ -205,12 +205,12 @@ class AttackInputData:
@property
def logits_or_probs_train(self):
"""Returns train logits or probs whatever is not None."""
return self.logits_train if self.probs_train is None else self.probs_train
return self.probs_train or self.logits_train
@property
def logits_or_probs_test(self):
"""Returns test logits or probs whatever is not None."""
return self.logits_test if self.probs_test is None else self.probs_test
return self.probs_test or self.logits_test
@staticmethod
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
@ -278,13 +278,13 @@ class AttackInputData:
"""Returns size of the training set."""
if self.loss_train is not None:
return self.loss_train.size
return self.logits_train.shape[0]
return self.logits_or_probs_train.shape[0]
def get_test_size(self):
"""Returns size of the test set."""
if self.loss_test is not None:
return self.loss_test.size
return self.logits_test.shape[0]
return self.logits_or_probs_test.shape[0]
def validate(self):
"""Validates the inputs."""