Move initialization for privacy_report_metadata to args

This commit is contained in:
amad-person 2020-11-27 18:03:18 +08:00
parent 981d5a95f5
commit 6c7d607e65

View file

@ -248,9 +248,10 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
privacy_report_metadata: PrivacyReportMetadata,
test_fraction: float = 0.25,
balance: bool = True) -> AttackerData:
balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
) -> AttackerData:
"""Prepares Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test
@ -261,11 +262,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
Args:
attack_input_data: Original Seq2SeqAttackInputData
privacy_report_metadata: the metadata of the model under attack.
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack.
Returns:
AttackerData.
@ -297,7 +298,6 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate accuracy, loss fields in privacy report metadata
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train