forked from 626_privacy/tensorflow_privacy
Bugfix.
PiperOrigin-RevId: 478591776
This commit is contained in:
parent
3f6d0acdef
commit
0738d6f555
1 changed files with 15 additions and 1 deletions
|
@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
left_out_indices = prepared_attacker_data.left_out_indices
|
||||
features = prepared_attacker_data.features_all
|
||||
labels = prepared_attacker_data.labels_all
|
||||
sample_weights = prepared_attacker_data.sample_weights_all
|
||||
|
||||
# We are going to train multiple models on disjoint subsets of the data
|
||||
# (`features`, `labels`), so we can get the membership scores of all samples,
|
||||
|
@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData,
|
|||
# Make sure one sample only got score predicted once
|
||||
assert np.all(np.isnan(scores[test_indices]))
|
||||
|
||||
# Setup sample weights if provided.
|
||||
if sample_weights is not None:
|
||||
# If sample weights are provided, only the weights at the training indices
|
||||
# are used for training. The weights at the test indices are not used
|
||||
# during prediction. Not that 'train' and 'test' refer to the data for the
|
||||
# attack models, not the data for the original models.
|
||||
sample_weights_train = np.squeeze(sample_weights[train_indices])
|
||||
else:
|
||||
sample_weights_train = None
|
||||
|
||||
attacker = models.create_attacker(attack_type, backend=backend)
|
||||
attacker.train_model(features[train_indices], labels[train_indices])
|
||||
attacker.train_model(
|
||||
features[train_indices],
|
||||
labels[train_indices],
|
||||
sample_weight=sample_weights_train)
|
||||
predictions = attacker.predict(features[test_indices])
|
||||
scores[test_indices] = predictions
|
||||
|
||||
|
|
Loading…
Reference in a new issue