Use hard-coded attack input for the metadata calculation test

This commit is contained in:
amad-person 2020-12-02 21:17:33 +08:00
parent 6c7d607e65
commit 31c747cdd8

View file

@ -99,48 +99,41 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
np.array([ np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32), np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32) np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array( np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)], [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object), dtype=object),
np.array([ np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
logits_test=iter([ logits_test=iter([
np.array([ np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([ np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32), np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32) np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
labels_train=iter([ labels_train=iter([
np.array([ np.array([
np.array([2, 0], dtype=np.float32), np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32) np.array([1], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object), np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([ np.array([
np.array([0, 1], dtype=np.float32), np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32) np.array([1, 2], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
labels_test=iter([ labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]), np.array([np.array([2, 1], dtype=np.float32)]),
np.array([ np.array([
np.array([2, 0], dtype=np.float32), np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32) np.array([1], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
vocab_size=3, vocab_size=3,
train_size=3, train_size=3,
@ -168,52 +161,44 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
np.array([ np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32), np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32) np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array( np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)], [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object), dtype=object),
np.array([ np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32), np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32) np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
logits_test=iter([ logits_test=iter([
np.array([ np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([ np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32), np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32) np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([ np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32) np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
labels_train=iter([ labels_train=iter([
np.array([ np.array([
np.array([2, 0], dtype=np.float32), np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32) np.array([1], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object), np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([ np.array([
np.array([0, 1], dtype=np.float32), np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32) np.array([1, 2], dtype=np.float32)
], ], dtype=object)
dtype=object)
]), ]),
labels_test=iter([ labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]), np.array([np.array([2, 1], dtype=np.float32)]),
np.array([ np.array([
np.array([2, 0], dtype=np.float32), np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32) np.array([1], dtype=np.float32)
], ], dtype=object),
dtype=object),
np.array([np.array([2, 1], dtype=np.float32)]) np.array([np.array([2, 1], dtype=np.float32)])
]), ]),
vocab_size=3, vocab_size=3,
@ -335,19 +320,72 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self): def test_run_seq2seq_attack_calculates_correct_metadata(self):
result = run_seq2seq_attack(get_seq2seq_test_input( attack_input = Seq2SeqAttackInputData(
n_train=20, logits_train=iter([
n_test=10, np.array([
max_seq_in_batch=3, np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
max_tokens_in_sequence=5, np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
], dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
], dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
], dtype=object),
np.array([
np.array([0, 0], dtype=np.float32),
np.array([0, 1], dtype=np.float32)
], dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
], dtype=object),
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([np.array([2, 1], dtype=np.float32)]),
]),
vocab_size=3, vocab_size=3,
seed=12345), train_size=4,
balance_attacker_training=False) test_size=4)
result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
metadata = result.privacy_report_metadata metadata = result.privacy_report_metadata
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2) np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2) np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2) np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2) np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2)
if __name__ == '__main__': if __name__ == '__main__':