forked from 626_privacy/tensorflow_privacy
Fixed train/test_size calculation.
PiperOrigin-RevId: 337886488
This commit is contained in:
parent
19ae5c9669
commit
4143957701
1 changed files with 4 additions and 4 deletions
|
@ -205,12 +205,12 @@ class AttackInputData:
|
||||||
@property
|
@property
|
||||||
def logits_or_probs_train(self):
|
def logits_or_probs_train(self):
|
||||||
"""Returns train logits or probs whatever is not None."""
|
"""Returns train logits or probs whatever is not None."""
|
||||||
return self.logits_train if self.probs_train is None else self.probs_train
|
return self.probs_train or self.logits_train
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logits_or_probs_test(self):
|
def logits_or_probs_test(self):
|
||||||
"""Returns test logits or probs whatever is not None."""
|
"""Returns test logits or probs whatever is not None."""
|
||||||
return self.logits_test if self.probs_test is None else self.probs_test
|
return self.probs_test or self.logits_test
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
||||||
|
@ -278,13 +278,13 @@ class AttackInputData:
|
||||||
"""Returns size of the training set."""
|
"""Returns size of the training set."""
|
||||||
if self.loss_train is not None:
|
if self.loss_train is not None:
|
||||||
return self.loss_train.size
|
return self.loss_train.size
|
||||||
return self.logits_train.shape[0]
|
return self.logits_or_probs_train.shape[0]
|
||||||
|
|
||||||
def get_test_size(self):
|
def get_test_size(self):
|
||||||
"""Returns size of the test set."""
|
"""Returns size of the test set."""
|
||||||
if self.loss_test is not None:
|
if self.loss_test is not None:
|
||||||
return self.loss_test.size
|
return self.loss_test.size
|
||||||
return self.logits_test.shape[0]
|
return self.logits_or_probs_test.shape[0]
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Validates the inputs."""
|
"""Validates the inputs."""
|
||||||
|
|
Loading…
Reference in a new issue