diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
index 87552b9..c60e936 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
@@ -248,9 +248,10 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
 
 
 def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
-                                 privacy_report_metadata: PrivacyReportMetadata,
                                  test_fraction: float = 0.25,
-                                 balance: bool = True) -> AttackerData:
+                                 balance: bool = True,
+                                 privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
+                                ) -> AttackerData:
   """Prepares Seq2SeqAttackInputData to train ML attackers.
 
   Uses logits and losses to generate ranks and performs a random train-test
@@ -261,11 +262,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
 
   Args:
     attack_input_data: Original Seq2SeqAttackInputData
-    privacy_report_metadata: the metadata of the model under attack.
     test_fraction: Fraction of the dataset to include in the test split.
     balance: Whether the training and test sets for the membership inference
       attacker should have a balanced (roughly equal) number of samples from the
       training and test sets used to develop the model under attack.
+    privacy_report_metadata: the metadata of the model under attack.
 
   Returns:
     AttackerData.
@@ -297,7 +298,6 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
       features_all, labels_all, test_size=test_fraction, stratify=labels_all)
 
   # Populate accuracy, loss fields in privacy report metadata
-  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   privacy_report_metadata.loss_train = loss_train
   privacy_report_metadata.loss_test = loss_test
   privacy_report_metadata.accuracy_train = accuracy_train
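
Call-site sketch (not part of the patch). A minimal illustration of how the reordered signature affects callers; the construction of attack_input is elided, and the import paths are assumed from the module touched above. One caveat worth knowing: Python evaluates a default argument once at function-definition time, so every call that omits privacy_report_metadata shares the single PrivacyReportMetadata() instance created as the default.

    from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
    from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data

    # attack_input is assumed to be an already-populated Seq2SeqAttackInputData;
    # how it is built is out of scope for this sketch.
    attack_input = ...

    # Old signature: metadata was the required second positional argument:
    #   create_seq2seq_attacker_data(attack_input, metadata, 0.25, True)
    # New signature: metadata is an optional trailing parameter. Passing an
    # explicit instance keeps per-call results separate (avoids the shared
    # default-argument instance noted above).
    metadata = PrivacyReportMetadata()
    attacker_data = create_seq2seq_attacker_data(
        attack_input,
        test_fraction=0.25,
        balance=True,
        privacy_report_metadata=metadata)

    # The loss/accuracy fields are populated in place on the object passed in.
    print(metadata.loss_train, metadata.loss_test)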