diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 6b79abe..a784d51 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -82,10 +82,8 @@ def create_attacker_data(attack_input_data: AttackInputData, labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest)))) # Perform a train-test split - features_train, features_test, \ - is_training_labels_train, is_training_labels_test = \ - model_selection.train_test_split( - features_all, labels_all, test_size=test_fraction, stratify=labels_all) + features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split( + features_all, labels_all, test_size=test_fraction, stratify=labels_all) return AttackerData(features_train, is_training_labels_train, features_test, is_training_labels_test, DataSize(ntrain=ntrain, ntest=ntest))