diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 36f7b75..efabdd8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -101,6 +101,10 @@ class AttackType(enum.Enum): return '%s' % self.name +def _is_integer_type_array(a): + return np.issubdtype(a.dtype, np.integer) + + @dataclass class AttackInputData: """Input data for running an attack. @@ -175,6 +179,14 @@ class AttackInputData: self.logits_train is None): raise ValueError('At least one of labels, logits or losses should be set') + if self.labels_train is not None and not _is_integer_type_array( + self.labels_train): + raise ValueError('labels_train elements should have integer type') + + if self.labels_test is not None and not _is_integer_type_array( + self.labels_test): + raise ValueError('labels_test elements should have integer type') + # TODO(b/161366709): Add checks for equal sizes diff --git a/tensorflow_privacy/privacy/membership_inference_attack/example.py b/tensorflow_privacy/privacy/membership_inference_attack/example.py index eb32c6d..f13383f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/example.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/example.py @@ -66,7 +66,7 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1): np.ones(samples_per_cluster), np.ones(samples_per_cluster) * 2, np.ones(samples_per_cluster) * 3, - )) + )).astype("uint8") return (features, labels) @@ -115,6 +115,8 @@ attack_results = mia.run_attacks( AttackInputData( labels_train=training_labels, labels_test=test_labels, + logits_train=training_pred, + logits_test=test_pred, loss_train=crossentropy(training_labels, training_pred), loss_test=crossentropy(test_labels, test_pred)), SlicingSpec(entire_dataset=True, by_class=True),