Make validation that labels are integers.
PiperOrigin-RevId: 326216555
This commit is contained in:
parent
0fd06493cc
commit
59192e6f5c
2 changed files with 15 additions and 1 deletions
|
@ -101,6 +101,10 @@ class AttackType(enum.Enum):
|
|||
return '%s' % self.name
|
||||
|
||||
|
||||
def _is_integer_type_array(a):
|
||||
return np.issubdtype(a.dtype, np.integer)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackInputData:
|
||||
"""Input data for running an attack.
|
||||
|
@ -175,6 +179,14 @@ class AttackInputData:
|
|||
self.logits_train is None):
|
||||
raise ValueError('At least one of labels, logits or losses should be set')
|
||||
|
||||
if self.labels_train is not None and not _is_integer_type_array(
|
||||
self.labels_train):
|
||||
raise ValueError('labels_train elements should have integer type')
|
||||
|
||||
if self.labels_test is not None and not _is_integer_type_array(
|
||||
self.labels_test):
|
||||
raise ValueError('labels_test elements should have integer type')
|
||||
|
||||
# TODO(b/161366709): Add checks for equal sizes
|
||||
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1):
|
|||
np.ones(samples_per_cluster),
|
||||
np.ones(samples_per_cluster) * 2,
|
||||
np.ones(samples_per_cluster) * 3,
|
||||
))
|
||||
)).astype("uint8")
|
||||
|
||||
return (features, labels)
|
||||
|
||||
|
@ -115,6 +115,8 @@ attack_results = mia.run_attacks(
|
|||
AttackInputData(
|
||||
labels_train=training_labels,
|
||||
labels_test=test_labels,
|
||||
logits_train=training_pred,
|
||||
logits_test=test_pred,
|
||||
loss_train=crossentropy(training_labels, training_pred),
|
||||
loss_test=crossentropy(test_labels, test_pred)),
|
||||
SlicingSpec(entire_dataset=True, by_class=True),
|
||||
|
|
Loading…
Reference in a new issue