diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py index 650f696..aef28d8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py @@ -99,48 +99,41 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase): np.array([ np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32), np.array([[0.4, 0.5, 0.1]], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array( [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)], dtype=object), np.array([ np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), logits_test=iter([ np.array([ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([ np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32), np.array([[0.3, 0.35, 0.35]], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), labels_train=iter([ np.array([ np.array([2, 0], dtype=np.float32), np.array([1], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([np.array([1, 0], dtype=np.float32)], dtype=object), np.array([ np.array([0, 1], dtype=np.float32), np.array([1, 2], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), labels_test=iter([ np.array([np.array([2, 1], dtype=np.float32)]), np.array([ np.array([2, 0], dtype=np.float32), np.array([1], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), vocab_size=3, train_size=3, @@ -168,52 +161,44 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase): np.array([ np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32), np.array([[0.4, 0.5, 0.1]], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array( [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)], dtype=object), np.array([ np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), logits_test=iter([ np.array([ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([ np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32), np.array([[0.3, 0.35, 0.35]], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), labels_train=iter([ np.array([ np.array([2, 0], dtype=np.float32), np.array([1], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([np.array([1, 0], dtype=np.float32)], dtype=object), np.array([ np.array([0, 1], dtype=np.float32), np.array([1, 2], dtype=np.float32) - ], - dtype=object) + ], dtype=object) ]), labels_test=iter([ np.array([np.array([2, 1], dtype=np.float32)]), np.array([ np.array([2, 0], dtype=np.float32), np.array([1], dtype=np.float32) - ], - dtype=object), + ], dtype=object), np.array([np.array([2, 1], dtype=np.float32)]) ]), vocab_size=3, @@ -335,19 +320,72 @@ class RunSeq2SeqAttackTest(absltest.TestCase): seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) def test_run_seq2seq_attack_calculates_correct_metadata(self): - result = run_seq2seq_attack(get_seq2seq_test_input( - n_train=20, - n_test=10, - max_seq_in_batch=3, - max_tokens_in_sequence=5, - vocab_size=3, - seed=12345), - balance_attacker_training=False) + attack_input = Seq2SeqAttackInputData( + logits_train=iter([ + np.array([ + np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32), + np.array([[0.4, 0.5, 0.1]], dtype=np.float32) + ], dtype=object), + np.array( + [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)], + dtype=object), + np.array([ + np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), + np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) + ], dtype=object), + np.array([ + np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), + np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) + ], dtype=object) + ]), + logits_test=iter([ + np.array([ + np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) + ], dtype=object), + np.array([ + np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32), + np.array([[0.3, 0.35, 0.35]], dtype=np.float32) + ], dtype=object), + np.array([ + np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) + ], dtype=object), + np.array([ + np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) + ], dtype=object) + ]), + labels_train=iter([ + np.array([ + np.array([2, 0], dtype=np.float32), + np.array([1], dtype=np.float32) + ], dtype=object), + np.array([np.array([1, 0], dtype=np.float32)], dtype=object), + np.array([ + np.array([0, 1], dtype=np.float32), + np.array([1, 2], dtype=np.float32) + ], dtype=object), + np.array([ + np.array([0, 0], dtype=np.float32), + np.array([0, 1], dtype=np.float32) + ], dtype=object) + ]), + labels_test=iter([ + np.array([np.array([2, 1], dtype=np.float32)]), + np.array([ + np.array([2, 0], dtype=np.float32), + np.array([1], dtype=np.float32) + ], dtype=object), + np.array([np.array([2, 1], dtype=np.float32)]), + np.array([np.array([2, 1], dtype=np.float32)]), + ]), + vocab_size=3, + train_size=4, + test_size=4) + result = run_seq2seq_attack(attack_input, balance_attacker_training=False) metadata = result.privacy_report_metadata - np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2) - np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2) - np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2) - np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2) + np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2) + np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2) + np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2) + np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2) if __name__ == '__main__':