forked from 626_privacy/tensorflow_privacy
Bugfix for logits_or_probs with an accompanying test.
PiperOrigin-RevId: 341604420
This commit is contained in:
parent
f0daaf085f
commit
caf71c11bc
3 changed files with 22 additions and 7 deletions
|
@ -207,12 +207,16 @@ class AttackInputData:
|
||||||
@property
|
@property
|
||||||
def logits_or_probs_train(self):
|
def logits_or_probs_train(self):
|
||||||
"""Returns train logits or probs whatever is not None."""
|
"""Returns train logits or probs whatever is not None."""
|
||||||
return self.probs_train or self.logits_train
|
if self.logits_train is not None:
|
||||||
|
return self.logits_train
|
||||||
|
return self.probs_train
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logits_or_probs_test(self):
|
def logits_or_probs_test(self):
|
||||||
"""Returns test logits or probs whatever is not None."""
|
"""Returns test logits or probs whatever is not None."""
|
||||||
return self.probs_test or self.logits_test
|
if self.logits_test is not None:
|
||||||
|
return self.logits_test
|
||||||
|
return self.probs_test
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
||||||
|
|
|
@ -82,6 +82,16 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||||
[1.0, 4.0, 6.0])
|
[1.0, 4.0, 6.0])
|
||||||
|
|
||||||
|
def test_get_probs_sizes(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
probs_train=np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]]),
|
||||||
|
probs_test=np.array([[0, 0.0001, 0.9999]]),
|
||||||
|
labels_train=np.array([1, 0]),
|
||||||
|
labels_test=np.array([0]))
|
||||||
|
|
||||||
|
np.testing.assert_equal(attack_input.get_train_size(), 2)
|
||||||
|
np.testing.assert_equal(attack_input.get_test_size(), 1)
|
||||||
|
|
||||||
def test_get_entropy(self):
|
def test_get_entropy(self):
|
||||||
attack_input = AttackInputData(
|
attack_input = AttackInputData(
|
||||||
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||||
|
|
|
@ -76,17 +76,18 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
|
||||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
|
||||||
|
|
||||||
def _indices_by_classification(logits, labels, correctly_classified):
|
def _indices_by_classification(logits_or_probs, labels, correctly_classified):
|
||||||
idx_correct = labels == np.argmax(logits, axis=1)
|
idx_correct = labels == np.argmax(logits_or_probs, axis=1)
|
||||||
return idx_correct if correctly_classified else np.invert(idx_correct)
|
return idx_correct if correctly_classified else np.invert(idx_correct)
|
||||||
|
|
||||||
|
|
||||||
def _slice_by_classification_correctness(data: AttackInputData,
|
def _slice_by_classification_correctness(data: AttackInputData,
|
||||||
correctly_classified: bool):
|
correctly_classified: bool):
|
||||||
idx_train = _indices_by_classification(data.logits_train, data.labels_train,
|
idx_train = _indices_by_classification(data.logits_or_probs_train,
|
||||||
|
data.labels_train,
|
||||||
correctly_classified)
|
correctly_classified)
|
||||||
idx_test = _indices_by_classification(data.logits_test, data.labels_test,
|
idx_test = _indices_by_classification(data.logits_or_probs_test,
|
||||||
correctly_classified)
|
data.labels_test, correctly_classified)
|
||||||
return _slice_data_by_indices(data, idx_train, idx_test)
|
return _slice_data_by_indices(data, idx_train, idx_test)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue