Bugfix.
PiperOrigin-RevId: 478591776
This commit is contained in:
parent
3f6d0acdef
commit
0738d6f555
1 changed files with 15 additions and 1 deletions
|
@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
left_out_indices = prepared_attacker_data.left_out_indices
|
left_out_indices = prepared_attacker_data.left_out_indices
|
||||||
features = prepared_attacker_data.features_all
|
features = prepared_attacker_data.features_all
|
||||||
labels = prepared_attacker_data.labels_all
|
labels = prepared_attacker_data.labels_all
|
||||||
|
sample_weights = prepared_attacker_data.sample_weights_all
|
||||||
|
|
||||||
# We are going to train multiple models on disjoint subsets of the data
|
# We are going to train multiple models on disjoint subsets of the data
|
||||||
# (`features`, `labels`), so we can get the membership scores of all samples,
|
# (`features`, `labels`), so we can get the membership scores of all samples,
|
||||||
|
@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData,
|
||||||
# Make sure one sample only got score predicted once
|
# Make sure one sample only got score predicted once
|
||||||
assert np.all(np.isnan(scores[test_indices]))
|
assert np.all(np.isnan(scores[test_indices]))
|
||||||
|
|
||||||
|
# Setup sample weights if provided.
|
||||||
|
if sample_weights is not None:
|
||||||
|
# If sample weights are provided, only the weights at the training indices
|
||||||
|
# are used for training. The weights at the test indices are not used
|
||||||
|
# during prediction. Not that 'train' and 'test' refer to the data for the
|
||||||
|
# attack models, not the data for the original models.
|
||||||
|
sample_weights_train = np.squeeze(sample_weights[train_indices])
|
||||||
|
else:
|
||||||
|
sample_weights_train = None
|
||||||
|
|
||||||
attacker = models.create_attacker(attack_type, backend=backend)
|
attacker = models.create_attacker(attack_type, backend=backend)
|
||||||
attacker.train_model(features[train_indices], labels[train_indices])
|
attacker.train_model(
|
||||||
|
features[train_indices],
|
||||||
|
labels[train_indices],
|
||||||
|
sample_weight=sample_weights_train)
|
||||||
predictions = attacker.predict(features[test_indices])
|
predictions = attacker.predict(features[test_indices])
|
||||||
scores[test_indices] = predictions
|
scores[test_indices] = predictions
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue