PiperOrigin-RevId: 478591776
This commit is contained in:
A. Unique TensorFlower 2022-10-03 13:32:58 -07:00
parent 3f6d0acdef
commit 0738d6f555

View file

@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData,
left_out_indices = prepared_attacker_data.left_out_indices left_out_indices = prepared_attacker_data.left_out_indices
features = prepared_attacker_data.features_all features = prepared_attacker_data.features_all
labels = prepared_attacker_data.labels_all labels = prepared_attacker_data.labels_all
sample_weights = prepared_attacker_data.sample_weights_all
# We are going to train multiple models on disjoint subsets of the data # We are going to train multiple models on disjoint subsets of the data
# (`features`, `labels`), so we can get the membership scores of all samples, # (`features`, `labels`), so we can get the membership scores of all samples,
@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData,
# Make sure one sample only got score predicted once # Make sure one sample only got score predicted once
assert np.all(np.isnan(scores[test_indices])) assert np.all(np.isnan(scores[test_indices]))
# Setup sample weights if provided.
if sample_weights is not None:
# If sample weights are provided, only the weights at the training indices
# are used for training. The weights at the test indices are not used
# during prediction. Not that 'train' and 'test' refer to the data for the
# attack models, not the data for the original models.
sample_weights_train = np.squeeze(sample_weights[train_indices])
else:
sample_weights_train = None
attacker = models.create_attacker(attack_type, backend=backend) attacker = models.create_attacker(attack_type, backend=backend)
attacker.train_model(features[train_indices], labels[train_indices]) attacker.train_model(
features[train_indices],
labels[train_indices],
sample_weight=sample_weights_train)
predictions = attacker.predict(features[test_indices]) predictions = attacker.predict(features[test_indices])
scores[test_indices] = predictions scores[test_indices] = predictions