From 6c7d607e6531add628b6c365d5adb222faa6dcc3 Mon Sep 17 00:00:00 2001 From: amad-person Date: Fri, 27 Nov 2020 18:03:18 +0800 Subject: [PATCH] Move initialization for privacy_report_metadata to args --- .../privacy/membership_inference_attack/seq2seq_mia.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py index 87552b9..c60e936 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py @@ -248,9 +248,10 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray, def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, - privacy_report_metadata: PrivacyReportMetadata, test_fraction: float = 0.25, - balance: bool = True) -> AttackerData: + balance: bool = True, + privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata() + ) -> AttackerData: """Prepares Seq2SeqAttackInputData to train ML attackers. Uses logits and losses to generate ranks and performs a random train-test @@ -261,11 +262,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, Args: attack_input_data: Original Seq2SeqAttackInputData - privacy_report_metadata: the metadata of the model under attack. test_fraction: Fraction of the dataset to include in the test split. balance: Whether the training and test sets for the membership inference attacker should have a balanced (roughly equal) number of samples from the training and test sets used to develop the model under attack. + privacy_report_metadata: the metadata of the model under attack. Returns: AttackerData. @@ -297,7 +298,6 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, features_all, labels_all, test_size=test_fraction, stratify=labels_all) # Populate accuracy, loss fields in privacy report metadata - privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata() privacy_report_metadata.loss_train = loss_train privacy_report_metadata.loss_test = loss_test privacy_report_metadata.accuracy_train = accuracy_train