Use hard-coded attack input for the metadata calculation test
This commit is contained in:
parent
6c7d607e65
commit
31c747cdd8
1 changed files with 80 additions and 42 deletions
|
@ -99,48 +99,41 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=3,
|
||||
|
@ -168,52 +161,44 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
], dtype=object),
|
||||
np.array([np.array([2, 1], dtype=np.float32)])
|
||||
]),
|
||||
vocab_size=3,
|
||||
|
@ -335,19 +320,72 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
|||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||
|
||||
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||
result = run_seq2seq_attack(get_seq2seq_test_input(
|
||||
n_train=20,
|
||||
n_test=10,
|
||||
max_seq_in_batch=3,
|
||||
max_tokens_in_sequence=5,
|
||||
vocab_size=3,
|
||||
seed=12345),
|
||||
balance_attacker_training=False)
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
], dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 0], dtype=np.float32),
|
||||
np.array([0, 1], dtype=np.float32)
|
||||
], dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
], dtype=object),
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=4,
|
||||
test_size=4)
|
||||
result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
|
||||
metadata = result.privacy_report_metadata
|
||||
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in a new issue