Bugfix.
PiperOrigin-RevId: 486344068
This commit is contained in:
parent
f7e1e61823
commit
e334633466
2 changed files with 10 additions and 7 deletions
|
@ -42,9 +42,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
|
||||||
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
|
||||||
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
|
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
|
||||||
# Copy over sample weights if provided.
|
# Slice sample weights if provided.
|
||||||
result.sample_weight_train = data.sample_weight_train
|
result.sample_weight_train = _slice_if_not_none(data.sample_weight_train,
|
||||||
result.sample_weight_test = data.sample_weight_test
|
idx_train)
|
||||||
|
|
||||||
# Slice test data.
|
# Slice test data.
|
||||||
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
|
||||||
|
@ -52,6 +52,9 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
|
||||||
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
|
||||||
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
|
||||||
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
|
||||||
|
# Slice sample weights if provided.
|
||||||
|
result.sample_weight_test = _slice_if_not_none(data.sample_weight_test,
|
||||||
|
idx_test)
|
||||||
|
|
||||||
# A slice has the same multilabel status as the original data. This is because
|
# A slice has the same multilabel status as the original data. This is because
|
||||||
# of the way multilabel status is computed. A dataset is multilabel if at
|
# of the way multilabel status is computed. A dataset is multilabel if at
|
||||||
|
|
|
@ -115,8 +115,8 @@ class GetSliceTest(absltest.TestCase):
|
||||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||||
entropy_train = np.array([0.4, 8, 0.6, 10])
|
entropy_train = np.array([0.4, 8, 0.6, 10])
|
||||||
entropy_test = np.array([15, 10.5, 4.5, 0.3])
|
entropy_test = np.array([15, 10.5, 4.5, 0.3])
|
||||||
sample_weight_train = np.array([1.0, 0.5])
|
sample_weight_train = np.array([1.0, 0.2, 0.5, 0.8])
|
||||||
sample_weight_test = np.array([0.5, 1.0])
|
sample_weight_test = np.array([0.5, 1.0, 0.1, 0.8])
|
||||||
|
|
||||||
self.input_data = AttackInputData(
|
self.input_data = AttackInputData(
|
||||||
logits_train=logits_train,
|
logits_train=logits_train,
|
||||||
|
@ -175,8 +175,8 @@ class GetSliceTest(absltest.TestCase):
|
||||||
# Check sample weights
|
# Check sample weights
|
||||||
self.assertLen(output.sample_weight_train, 2)
|
self.assertLen(output.sample_weight_train, 2)
|
||||||
np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5])
|
np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5])
|
||||||
self.assertLen(output.sample_weight_test, 2)
|
self.assertLen(output.sample_weight_test, 1)
|
||||||
np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0])
|
np.testing.assert_array_equal(output.sample_weight_test, [0.5])
|
||||||
|
|
||||||
def test_slice_by_percentile(self):
|
def test_slice_by_percentile(self):
|
||||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||||
|
|
Loading…
Reference in a new issue