Make validation that labels are integers.

PiperOrigin-RevId: 326216555
This commit is contained in:
A. Unique TensorFlower 2020-08-12 06:00:03 -07:00
parent 0fd06493cc
commit 59192e6f5c
2 changed files with 15 additions and 1 deletions

View file

@ -101,6 +101,10 @@ class AttackType(enum.Enum):
return '%s' % self.name return '%s' % self.name
def _is_integer_type_array(a):
return np.issubdtype(a.dtype, np.integer)
@dataclass @dataclass
class AttackInputData: class AttackInputData:
"""Input data for running an attack. """Input data for running an attack.
@ -175,6 +179,14 @@ class AttackInputData:
self.logits_train is None): self.logits_train is None):
raise ValueError('At least one of labels, logits or losses should be set') raise ValueError('At least one of labels, logits or losses should be set')
if self.labels_train is not None and not _is_integer_type_array(
self.labels_train):
raise ValueError('labels_train elements should have integer type')
if self.labels_test is not None and not _is_integer_type_array(
self.labels_test):
raise ValueError('labels_test elements should have integer type')
# TODO(b/161366709): Add checks for equal sizes # TODO(b/161366709): Add checks for equal sizes

View file

@ -66,7 +66,7 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1):
np.ones(samples_per_cluster), np.ones(samples_per_cluster),
np.ones(samples_per_cluster) * 2, np.ones(samples_per_cluster) * 2,
np.ones(samples_per_cluster) * 3, np.ones(samples_per_cluster) * 3,
)) )).astype("uint8")
return (features, labels) return (features, labels)
@ -115,6 +115,8 @@ attack_results = mia.run_attacks(
AttackInputData( AttackInputData(
labels_train=training_labels, labels_train=training_labels,
labels_test=test_labels, labels_test=test_labels,
logits_train=training_pred,
logits_test=test_pred,
loss_train=crossentropy(training_labels, training_pred), loss_train=crossentropy(training_labels, training_pred),
loss_test=crossentropy(test_labels, test_pred)), loss_test=crossentropy(test_labels, test_pred)),
SlicingSpec(entire_dataset=True, by_class=True), SlicingSpec(entire_dataset=True, by_class=True),