Move initialization for privacy_report_metadata to args

This commit is contained in:
amad-person 2020-11-27 18:03:18 +08:00
parent 981d5a95f5
commit 6c7d607e65

View file

@ -248,9 +248,10 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
privacy_report_metadata: PrivacyReportMetadata,
test_fraction: float = 0.25, test_fraction: float = 0.25,
balance: bool = True) -> AttackerData: balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
) -> AttackerData:
"""Prepares Seq2SeqAttackInputData to train ML attackers. """Prepares Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test Uses logits and losses to generate ranks and performs a random train-test
@ -261,11 +262,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
Args: Args:
attack_input_data: Original Seq2SeqAttackInputData attack_input_data: Original Seq2SeqAttackInputData
privacy_report_metadata: the metadata of the model under attack.
test_fraction: Fraction of the dataset to include in the test split. test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack. training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack.
Returns: Returns:
AttackerData. AttackerData.
@ -297,7 +298,6 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
features_all, labels_all, test_size=test_fraction, stratify=labels_all) features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate accuracy, loss fields in privacy report metadata # Populate accuracy, loss fields in privacy report metadata
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
privacy_report_metadata.loss_train = loss_train privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train privacy_report_metadata.accuracy_train = accuracy_train