diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py index 86197f9..37a8fb2 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py @@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData, left_out_indices = prepared_attacker_data.left_out_indices features = prepared_attacker_data.features_all labels = prepared_attacker_data.labels_all + sample_weights = prepared_attacker_data.sample_weights_all # We are going to train multiple models on disjoint subsets of the data # (`features`, `labels`), so we can get the membership scores of all samples, @@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData, # Make sure one sample only got score predicted once assert np.all(np.isnan(scores[test_indices])) + # Setup sample weights if provided. + if sample_weights is not None: + # If sample weights are provided, only the weights at the training indices + # are used for training. The weights at the test indices are not used + # during prediction. Not that 'train' and 'test' refer to the data for the + # attack models, not the data for the original models. + sample_weights_train = np.squeeze(sample_weights[train_indices]) + else: + sample_weights_train = None + attacker = models.create_attacker(attack_type, backend=backend) - attacker.train_model(features[train_indices], labels[train_indices]) + attacker.train_model( + features[train_indices], + labels[train_indices], + sample_weight=sample_weights_train) predictions = attacker.predict(features[test_indices]) scores[test_indices] = predictions