Make validation that labels are integers.
PiperOrigin-RevId: 326216555
This commit is contained in:
parent
0fd06493cc
commit
59192e6f5c
2 changed files with 15 additions and 1 deletions
|
@ -101,6 +101,10 @@ class AttackType(enum.Enum):
|
||||||
return '%s' % self.name
|
return '%s' % self.name
|
||||||
|
|
||||||
|
|
||||||
|
def _is_integer_type_array(a):
|
||||||
|
return np.issubdtype(a.dtype, np.integer)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttackInputData:
|
class AttackInputData:
|
||||||
"""Input data for running an attack.
|
"""Input data for running an attack.
|
||||||
|
@ -175,6 +179,14 @@ class AttackInputData:
|
||||||
self.logits_train is None):
|
self.logits_train is None):
|
||||||
raise ValueError('At least one of labels, logits or losses should be set')
|
raise ValueError('At least one of labels, logits or losses should be set')
|
||||||
|
|
||||||
|
if self.labels_train is not None and not _is_integer_type_array(
|
||||||
|
self.labels_train):
|
||||||
|
raise ValueError('labels_train elements should have integer type')
|
||||||
|
|
||||||
|
if self.labels_test is not None and not _is_integer_type_array(
|
||||||
|
self.labels_test):
|
||||||
|
raise ValueError('labels_test elements should have integer type')
|
||||||
|
|
||||||
# TODO(b/161366709): Add checks for equal sizes
|
# TODO(b/161366709): Add checks for equal sizes
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ def generate_features_and_labels(samples_per_cluster=250, scale=0.1):
|
||||||
np.ones(samples_per_cluster),
|
np.ones(samples_per_cluster),
|
||||||
np.ones(samples_per_cluster) * 2,
|
np.ones(samples_per_cluster) * 2,
|
||||||
np.ones(samples_per_cluster) * 3,
|
np.ones(samples_per_cluster) * 3,
|
||||||
))
|
)).astype("uint8")
|
||||||
|
|
||||||
return (features, labels)
|
return (features, labels)
|
||||||
|
|
||||||
|
@ -115,6 +115,8 @@ attack_results = mia.run_attacks(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
labels_train=training_labels,
|
labels_train=training_labels,
|
||||||
labels_test=test_labels,
|
labels_test=test_labels,
|
||||||
|
logits_train=training_pred,
|
||||||
|
logits_test=test_pred,
|
||||||
loss_train=crossentropy(training_labels, training_pred),
|
loss_train=crossentropy(training_labels, training_pred),
|
||||||
loss_test=crossentropy(test_labels, test_pred)),
|
loss_test=crossentropy(test_labels, test_pred)),
|
||||||
SlicingSpec(entire_dataset=True, by_class=True),
|
SlicingSpec(entire_dataset=True, by_class=True),
|
||||||
|
|
Loading…
Reference in a new issue