Bugfix for logits_or_probs with an accompanying test.

PiperOrigin-RevId: 341604420
This commit is contained in:
David Marn 2020-11-10 06:07:58 -08:00 committed by A. Unique TensorFlower
parent f0daaf085f
commit caf71c11bc
3 changed files with 22 additions and 7 deletions

View file

@ -207,12 +207,16 @@ class AttackInputData:
@property
def logits_or_probs_train(self):
    """Return the train logits if set, otherwise the train probabilities.

    Returns:
        `self.logits_train` when it is not None, else `self.probs_train`
        (which may itself be None if neither was provided).
    """
    # Compare against None explicitly: `a or b` would evaluate the truth
    # value of the first operand, which raises ValueError for a NumPy array
    # holding more than one element.
    if self.logits_train is not None:
        return self.logits_train
    return self.probs_train
@property
def logits_or_probs_test(self):
    """Return the test logits if set, otherwise the test probabilities.

    Returns:
        `self.logits_test` when it is not None, else `self.probs_test`
        (which may itself be None if neither was provided).
    """
    # Compare against None explicitly: `a or b` would evaluate the truth
    # value of the first operand, which raises ValueError for a NumPy array
    # holding more than one element.
    if self.logits_test is not None:
        return self.logits_test
    return self.probs_test
@staticmethod
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):

View file

@ -82,6 +82,16 @@ class AttackInputDataTest(absltest.TestCase):
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
[1.0, 4.0, 6.0])
def test_get_probs_sizes(self):
    """Checks the reported train/test sizes when only probs are supplied."""
    train_probs = np.array([[0.1, 0.1, 0.8], [0.8, 0.2, 0]])
    test_probs = np.array([[0, 0.0001, 0.9999]])
    data = AttackInputData(
        probs_train=train_probs,
        probs_test=test_probs,
        labels_train=np.array([1, 0]),
        labels_test=np.array([0]))
    # Two train rows and one test row were passed in above.
    np.testing.assert_equal(data.get_train_size(), 2)
    np.testing.assert_equal(data.get_test_size(), 1)
def test_get_entropy(self):
attack_input = AttackInputData(
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),

View file

@ -76,17 +76,18 @@ def _slice_by_percentiles(data: AttackInputData, from_percentile: float,
return _slice_data_by_indices(data, idx_train, idx_test)
def _indices_by_classification(logits_or_probs, labels, correctly_classified):
    """Return a boolean mask selecting examples by classification correctness.

    Args:
        logits_or_probs: Per-example class scores; the predicted class is
            taken as the argmax along axis 1.
        labels: True labels, compared elementwise with the predicted classes.
        correctly_classified: If truthy, select examples whose prediction
            equals the label; otherwise select the remaining examples.

    Returns:
        The result of `labels == argmax(logits_or_probs, axis=1)`, or its
        elementwise inversion when `correctly_classified` is falsy.
    """
    idx_correct = labels == np.argmax(logits_or_probs, axis=1)
    return idx_correct if correctly_classified else np.invert(idx_correct)
def _slice_by_classification_correctness(data: AttackInputData,
                                         correctly_classified: bool):
    """Slice `data` down to correctly (or incorrectly) classified examples.

    Args:
        data: Attack input whose train and test examples are filtered.
        correctly_classified: Passed to `_indices_by_classification` to pick
            either the correctly or the incorrectly classified examples.

    Returns:
        Whatever `_slice_data_by_indices` returns for the computed train and
        test index masks.
    """
    # Go through the logits_or_probs_* accessors rather than logits_* directly
    # so that inputs constructed with only probabilities are handled too.
    idx_train = _indices_by_classification(data.logits_or_probs_train,
                                           data.labels_train,
                                           correctly_classified)
    idx_test = _indices_by_classification(data.logits_or_probs_test,
                                          data.labels_test,
                                          correctly_classified)
    return _slice_data_by_indices(data, idx_train, idx_test)