Use hard-coded attack input for the metadata calculation test
This commit is contained in:
parent
6c7d607e65
commit
31c747cdd8
1 changed files with 80 additions and 42 deletions
|
@ -99,48 +99,41 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array(
|
np.array(
|
||||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||||
dtype=object),
|
dtype=object),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
logits_test=iter([
|
logits_test=iter([
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
labels_train=iter([
|
labels_train=iter([
|
||||||
np.array([
|
np.array([
|
||||||
np.array([2, 0], dtype=np.float32),
|
np.array([2, 0], dtype=np.float32),
|
||||||
np.array([1], dtype=np.float32)
|
np.array([1], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([0, 1], dtype=np.float32),
|
np.array([0, 1], dtype=np.float32),
|
||||||
np.array([1, 2], dtype=np.float32)
|
np.array([1, 2], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
labels_test=iter([
|
labels_test=iter([
|
||||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([2, 0], dtype=np.float32),
|
np.array([2, 0], dtype=np.float32),
|
||||||
np.array([1], dtype=np.float32)
|
np.array([1], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
vocab_size=3,
|
vocab_size=3,
|
||||||
train_size=3,
|
train_size=3,
|
||||||
|
@ -168,52 +161,44 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array(
|
np.array(
|
||||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||||
dtype=object),
|
dtype=object),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
logits_test=iter([
|
logits_test=iter([
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([
|
np.array([
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
labels_train=iter([
|
labels_train=iter([
|
||||||
np.array([
|
np.array([
|
||||||
np.array([2, 0], dtype=np.float32),
|
np.array([2, 0], dtype=np.float32),
|
||||||
np.array([1], dtype=np.float32)
|
np.array([1], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([0, 1], dtype=np.float32),
|
np.array([0, 1], dtype=np.float32),
|
||||||
np.array([1, 2], dtype=np.float32)
|
np.array([1, 2], dtype=np.float32)
|
||||||
],
|
], dtype=object)
|
||||||
dtype=object)
|
|
||||||
]),
|
]),
|
||||||
labels_test=iter([
|
labels_test=iter([
|
||||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||||
np.array([
|
np.array([
|
||||||
np.array([2, 0], dtype=np.float32),
|
np.array([2, 0], dtype=np.float32),
|
||||||
np.array([1], dtype=np.float32)
|
np.array([1], dtype=np.float32)
|
||||||
],
|
], dtype=object),
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([2, 1], dtype=np.float32)])
|
np.array([np.array([2, 1], dtype=np.float32)])
|
||||||
]),
|
]),
|
||||||
vocab_size=3,
|
vocab_size=3,
|
||||||
|
@ -335,19 +320,72 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
||||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||||
|
|
||||||
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||||
result = run_seq2seq_attack(get_seq2seq_test_input(
|
attack_input = Seq2SeqAttackInputData(
|
||||||
n_train=20,
|
logits_train=iter([
|
||||||
n_test=10,
|
np.array([
|
||||||
max_seq_in_batch=3,
|
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||||
max_tokens_in_sequence=5,
|
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||||
vocab_size=3,
|
], dtype=object),
|
||||||
seed=12345),
|
np.array(
|
||||||
balance_attacker_training=False)
|
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||||
|
dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||||
|
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||||
|
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||||
|
], dtype=object)
|
||||||
|
]),
|
||||||
|
logits_test=iter([
|
||||||
|
np.array([
|
||||||
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||||
|
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||||
|
], dtype=object)
|
||||||
|
]),
|
||||||
|
labels_train=iter([
|
||||||
|
np.array([
|
||||||
|
np.array([2, 0], dtype=np.float32),
|
||||||
|
np.array([1], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([0, 1], dtype=np.float32),
|
||||||
|
np.array([1, 2], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([
|
||||||
|
np.array([0, 0], dtype=np.float32),
|
||||||
|
np.array([0, 1], dtype=np.float32)
|
||||||
|
], dtype=object)
|
||||||
|
]),
|
||||||
|
labels_test=iter([
|
||||||
|
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||||
|
np.array([
|
||||||
|
np.array([2, 0], dtype=np.float32),
|
||||||
|
np.array([1], dtype=np.float32)
|
||||||
|
], dtype=object),
|
||||||
|
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||||
|
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||||
|
]),
|
||||||
|
vocab_size=3,
|
||||||
|
train_size=4,
|
||||||
|
test_size=4)
|
||||||
|
result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
|
||||||
metadata = result.privacy_report_metadata
|
metadata = result.privacy_report_metadata
|
||||||
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
|
np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2)
|
||||||
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
|
np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2)
|
||||||
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
|
np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2)
|
||||||
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
|
np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue