Fixed train/test_size calculation.
PiperOrigin-RevId: 337886488
This commit is contained in:
parent
19ae5c9669
commit
4143957701
1 changed files with 4 additions and 4 deletions
|
@ -205,12 +205,12 @@ class AttackInputData:
|
|||
@property
|
||||
def logits_or_probs_train(self):
|
||||
"""Returns train logits or probs whatever is not None."""
|
||||
return self.logits_train if self.probs_train is None else self.probs_train
|
||||
return self.probs_train or self.logits_train
|
||||
|
||||
@property
|
||||
def logits_or_probs_test(self):
|
||||
"""Returns test logits or probs whatever is not None."""
|
||||
return self.logits_test if self.probs_test is None else self.probs_test
|
||||
return self.probs_test or self.logits_test
|
||||
|
||||
@staticmethod
|
||||
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
||||
|
@ -278,13 +278,13 @@ class AttackInputData:
|
|||
"""Returns size of the training set."""
|
||||
if self.loss_train is not None:
|
||||
return self.loss_train.size
|
||||
return self.logits_train.shape[0]
|
||||
return self.logits_or_probs_train.shape[0]
|
||||
|
||||
def get_test_size(self):
|
||||
"""Returns size of the test set."""
|
||||
if self.loss_test is not None:
|
||||
return self.loss_test.size
|
||||
return self.logits_test.shape[0]
|
||||
return self.logits_or_probs_test.shape[0]
|
||||
|
||||
def validate(self):
|
||||
"""Validates the inputs."""
|
||||
|
|
Loading…
Reference in a new issue