forked from 626_privacy/tensorflow_privacy
Bugfix for logits_or_probs with an accompanying test.
PiperOrigin-RevId: 341604420
parent f0daaf085f
commit caf71c11bc
3 changed files with 22 additions and 7 deletions
@@ -207,12 +207,16 @@ class AttackInputData:
   @property
   def logits_or_probs_train(self):
     """Returns train logits or probs whatever is not None."""
-    return self.probs_train or self.logits_train
+    if self.logits_train is not None:
+      return self.logits_train
+    return self.probs_train

   @property
   def logits_or_probs_test(self):
     """Returns test logits or probs whatever is not None."""
-    return self.probs_test or self.logits_test
+    if self.logits_test is not None:
+      return self.logits_test
+    return self.probs_test

   @staticmethod
   def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
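For context, a minimal standalone sketch (not part of the commit; values are made up) of the failure the fix addresses: Python's `or` coerces its left operand through bool(), which NumPy rejects for arrays with more than one element, so the old expression raised whenever only probabilities were supplied. The explicit None checks never coerce an array to bool (and, as the diff shows, logits now take precedence when both fields are set).

import numpy as np

probs_train = np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0.0]])  # hypothetical probs-only input
logits_train = None

try:
  # Old behaviour: bool(probs_train) is evaluated first and NumPy raises.
  _ = probs_train or logits_train
except ValueError as e:
  print(e)  # "The truth value of an array with more than one element is ambiguous..."

# New behaviour: explicit None check, so the array is never coerced to bool.
result = logits_train if logits_train is not None else probs_train
print(result.shape)  # (2, 3)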
@@ -82,6 +82,16 @@ class AttackInputDataTest(absltest.TestCase):
     np.testing.assert_equal(attack_input.get_loss_test().tolist(),
                             [1.0, 4.0, 6.0])

+  def test_get_probs_sizes(self):
+    attack_input = AttackInputData(
+        probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
+        probs_test=np.array([[0, 0.0001, 0.9999]]),
+        labels_train=np.array([1, 0]),
+        labels_test=np.array([0]))
+
+    np.testing.assert_equal(attack_input.get_train_size(), 2)
+    np.testing.assert_equal(attack_input.get_test_size(), 1)
+
   def test_get_entropy(self):
     attack_input = AttackInputData(
         logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
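The new test mirrors what a probs-only caller does. A rough usage sketch, with the same shapes and expected sizes as the test; the import path is an assumption based on the upstream tensorflow_privacy layout at the time:

import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData

attack_input = AttackInputData(
    probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
    probs_test=np.array([[0, 0.0001, 0.9999]]),
    labels_train=np.array([1, 0]),
    labels_test=np.array([0]))

# With logits absent, the patched properties fall back to the probability arrays.
print(attack_input.logits_or_probs_train.shape)  # (2, 3)
print(attack_input.get_train_size(), attack_input.get_test_size())  # 2 1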
@@ -76,17 +76,18 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
   return _slice_data_by_indices(data, idx_train, idx_test)


-def _indices_by_classification(logits, labels, correctly_classified):
-  idx_correct = labels == np.argmax(logits, axis=1)
+def _indices_by_classification(logits_or_probs, labels, correctly_classified):
+  idx_correct = labels == np.argmax(logits_or_probs, axis=1)
   return idx_correct if correctly_classified else np.invert(idx_correct)


 def _slice_by_classification_correctness(data: AttackInputData,
                                          correctly_classified: bool):
-  idx_train = _indices_by_classification(data.logits_train, data.labels_train,
+  idx_train = _indices_by_classification(data.logits_or_probs_train,
+                                         data.labels_train,
                                          correctly_classified)
-  idx_test = _indices_by_classification(data.logits_test, data.labels_test,
-                                        correctly_classified)
+  idx_test = _indices_by_classification(data.logits_or_probs_test,
+                                        data.labels_test, correctly_classified)
   return _slice_data_by_indices(data, idx_train, idx_test)
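The renamed parameter makes explicit that the slicer may now receive probabilities instead of logits. A small standalone sketch (hypothetical values, local softmax helper) of why the argmax-based correctness check is unaffected by which of the two is passed:

import numpy as np

def softmax(x):
  e = np.exp(x - x.max(axis=1, keepdims=True))
  return e / e.sum(axis=1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])
labels = np.array([0, 2])

# argmax is invariant under softmax, so slicing by classification correctness
# yields the same indices whether logits or probabilities are fed in.
assert np.array_equal(np.argmax(logits, axis=1),
                      np.argmax(softmax(logits), axis=1))

idx_correct = labels == np.argmax(softmax(logits), axis=1)
print(idx_correct)  # [ True False]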