From 59192e6f5c9a3f4302fe5a2132699cd1868dfb39 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Aug 2020 06:00:03 -0700 Subject: [PATCH] Make validation that labels are integers. PiperOrigin-RevId: 326216555 --- .../membership_inference_attack/data_structures.py | 12 ++++++++++++ .../privacy/membership_inference_attack/example.py | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 36f7b75..efabdd8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -101,6 +101,10 @@ class AttackType(enum.Enum): return '%s' % self.name +def _is_integer_type_array(a): + return np.issubdtype(a.dtype, np.integer) + + @dataclass class AttackInputData: """Input data for running an attack. @@ -175,6 +179,14 @@ class AttackInputData: self.logits_train is None): raise ValueError('At least one of labels, logits or losses should be set') + if self.labels_train is not None and not _is_integer_type_array( + self.labels_train): + raise ValueError('labels_train elements should have integer type') + + if self.labels_test is not None and not _is_integer_type_array( + self.labels_test): + raise ValueError('labels_test elements should have integer type') + # TODO(b/161366709): Add checks for equal sizes diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index eb32c6d..f13383f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -66,7 +66,7 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1): np.ones(samples_per_cluster), np.ones(samples_per_cluster) * 2, np.ones(samples_per_cluster) * 3, - )) + )).astype("uint8") return (features, labels) @@ -115,6 +115,8 @@ attack_results = mia.run_attacks( AttackInputData( labels_train=training_labels, labels_test=test_labels, + logits_train=training_pred, + logits_test=test_pred, loss_train=crossentropy(training_labels, training_pred), loss_test=crossentropy(test_labels, test_pred)), SlicingSpec(entire_dataset=True, by_class=True),