forked from 626_privacy/tensorflow_privacy
Move initialization for privacy_report_metadata to args
This commit is contained in:
parent
981d5a95f5
commit
6c7d607e65
1 changed files with 4 additions and 4 deletions
|
@ -248,9 +248,10 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
|
||||||
|
|
||||||
|
|
||||||
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||||
privacy_report_metadata: PrivacyReportMetadata,
|
|
||||||
test_fraction: float = 0.25,
|
test_fraction: float = 0.25,
|
||||||
balance: bool = True) -> AttackerData:
|
balance: bool = True,
|
||||||
|
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
|
||||||
|
) -> AttackerData:
|
||||||
"""Prepares Seq2SeqAttackInputData to train ML attackers.
|
"""Prepares Seq2SeqAttackInputData to train ML attackers.
|
||||||
|
|
||||||
Uses logits and losses to generate ranks and performs a random train-test
|
Uses logits and losses to generate ranks and performs a random train-test
|
||||||
|
@ -261,11 +262,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
attack_input_data: Original Seq2SeqAttackInputData
|
attack_input_data: Original Seq2SeqAttackInputData
|
||||||
privacy_report_metadata: the metadata of the model under attack.
|
|
||||||
test_fraction: Fraction of the dataset to include in the test split.
|
test_fraction: Fraction of the dataset to include in the test split.
|
||||||
balance: Whether the training and test sets for the membership inference
|
balance: Whether the training and test sets for the membership inference
|
||||||
attacker should have a balanced (roughly equal) number of samples from the
|
attacker should have a balanced (roughly equal) number of samples from the
|
||||||
training and test sets used to develop the model under attack.
|
training and test sets used to develop the model under attack.
|
||||||
|
privacy_report_metadata: the metadata of the model under attack.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AttackerData.
|
AttackerData.
|
||||||
|
@ -297,7 +298,6 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||||
|
|
||||||
# Populate accuracy, loss fields in privacy report metadata
|
# Populate accuracy, loss fields in privacy report metadata
|
||||||
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
|
||||||
privacy_report_metadata.loss_train = loss_train
|
privacy_report_metadata.loss_train = loss_train
|
||||||
privacy_report_metadata.loss_test = loss_test
|
privacy_report_metadata.loss_test = loss_test
|
||||||
privacy_report_metadata.accuracy_train = accuracy_train
|
privacy_report_metadata.accuracy_train = accuracy_train
|
||||||
|
|
Loading…
Reference in a new issue