PiperOrigin-RevId: 486344068
This commit is contained in:
A. Unique TensorFlower 2022-11-05 05:18:26 -07:00
parent f7e1e61823
commit e334633466
2 changed files with 10 additions and 7 deletions

View file

@ -42,9 +42,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
result.loss_train = _slice_if_not_none(data.loss_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train) result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
# Copy over sample weights if provided. # Slice sample weights if provided.
result.sample_weight_train = data.sample_weight_train result.sample_weight_train = _slice_if_not_none(data.sample_weight_train,
result.sample_weight_test = data.sample_weight_test idx_train)
# Slice test data. # Slice test data.
result.logits_test = _slice_if_not_none(data.logits_test, idx_test) result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
@ -52,6 +52,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
result.labels_test = _slice_if_not_none(data.labels_test, idx_test) result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
result.loss_test = _slice_if_not_none(data.loss_test, idx_test) result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test) result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
# Slice sample weights if provided.
result.sample_weight_test = _slice_if_not_none(data.sample_weight_test,
idx_test)
# A slice has the same multilabel status as the original data. This is because # A slice has the same multilabel status as the original data. This is because
# of the way multilabel status is computed. A dataset is multilabel if at # of the way multilabel status is computed. A dataset is multilabel if at

View file

@ -115,8 +115,8 @@ class GetSliceTest(absltest.TestCase):
loss_test = np.array([0.5, 3.5, 7, 4.5]) loss_test = np.array([0.5, 3.5, 7, 4.5])
entropy_train = np.array([0.4, 8, 0.6, 10]) entropy_train = np.array([0.4, 8, 0.6, 10])
entropy_test = np.array([15, 10.5, 4.5, 0.3]) entropy_test = np.array([15, 10.5, 4.5, 0.3])
sample_weight_train = np.array([1.0, 0.5]) sample_weight_train = np.array([1.0, 0.2, 0.5, 0.8])
sample_weight_test = np.array([0.5, 1.0]) sample_weight_test = np.array([0.5, 1.0, 0.1, 0.8])
self.input_data = AttackInputData( self.input_data = AttackInputData(
logits_train=logits_train, logits_train=logits_train,
@ -175,8 +175,8 @@ class GetSliceTest(absltest.TestCase):
# Check sample weights # Check sample weights
self.assertLen(output.sample_weight_train, 2) self.assertLen(output.sample_weight_train, 2)
np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5]) np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5])
self.assertLen(output.sample_weight_test, 2) self.assertLen(output.sample_weight_test, 1)
np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0]) np.testing.assert_array_equal(output.sample_weight_test, [0.5])
def test_slice_by_percentile(self): def test_slice_by_percentile(self):
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))