Use hard-coded attack input for the metadata calculation test

This commit is contained in:
amad-person 2020-12-02 21:17:33 +08:00
parent 6c7d607e65
commit 31c747cdd8

View file

@ -99,48 +99,41 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
vocab_size=3,
train_size=3,
@ -168,52 +161,44 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
], dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
], dtype=object),
np.array([np.array([2, 1], dtype=np.float32)])
]),
vocab_size=3,
@ -335,19 +320,72 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self):
result = run_seq2seq_attack(get_seq2seq_test_input(
n_train=20,
n_test=10,
max_seq_in_batch=3,
max_tokens_in_sequence=5,
vocab_size=3,
seed=12345),
balance_attacker_training=False)
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
], dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
], dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
], dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
], dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
], dtype=object),
np.array([
np.array([0, 0], dtype=np.float32),
np.array([0, 1], dtype=np.float32)
], dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
], dtype=object),
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([np.array([2, 1], dtype=np.float32)]),
]),
vocab_size=3,
train_size=4,
test_size=4)
result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
metadata = result.privacy_report_metadata
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2)
if __name__ == '__main__':