Adds more validation checks for AttackInputData.
PiperOrigin-RevId: 327191245
This commit is contained in:
parent
a69b013390
commit
193ac3b1c8
1 changed files with 32 additions and 1 deletions
|
@ -105,6 +105,25 @@ def _is_integer_type_array(a):
|
||||||
return np.issubdtype(a.dtype, np.integer)
|
return np.issubdtype(a.dtype, np.integer)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_last_dim_equal(arr1, arr1_name, arr2, arr2_name):
|
||||||
|
"""Checks whether the last dimension of the arrays is the same."""
|
||||||
|
if arr1 is not None and arr2 is not None and arr1.shape[-1] != arr2.shape[-1]:
|
||||||
|
raise ValueError('%s and %s should have the same number of features.' %
|
||||||
|
(arr1_name, arr2_name))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_array_one_dimensional(arr, arr_name):
|
||||||
|
"""Checks whether the array is one dimensional."""
|
||||||
|
if arr is not None and len(arr.shape) != 1:
|
||||||
|
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_np_array(arr, arr_name):
|
||||||
|
"""Checks whether array is a numpy array."""
|
||||||
|
if arr is not None and not isinstance(arr, np.ndarray):
|
||||||
|
raise ValueError('%s should be a numpy array.' % arr_name)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttackInputData:
|
class AttackInputData:
|
||||||
"""Input data for running an attack.
|
"""Input data for running an attack.
|
||||||
|
@ -187,7 +206,19 @@ class AttackInputData:
|
||||||
self.labels_test):
|
self.labels_test):
|
||||||
raise ValueError('labels_test elements should have integer type')
|
raise ValueError('labels_test elements should have integer type')
|
||||||
|
|
||||||
# TODO(b/161366709): Add checks for equal sizes
|
_is_np_array(self.logits_train, 'logits_train')
|
||||||
|
_is_np_array(self.logits_test, 'logits_test')
|
||||||
|
_is_np_array(self.labels_train, 'labels_train')
|
||||||
|
_is_np_array(self.labels_test, 'labels_test')
|
||||||
|
_is_np_array(self.loss_train, 'loss_train')
|
||||||
|
_is_np_array(self.loss_test, 'loss_test')
|
||||||
|
|
||||||
|
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
|
||||||
|
'logits_test')
|
||||||
|
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||||
|
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||||
|
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||||
|
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
Loading…
Reference in a new issue