Adds more validation checks for AttackInputData.
PiperOrigin-RevId: 327191245
This commit is contained in:
parent
a69b013390
commit
193ac3b1c8
1 changed files with 32 additions and 1 deletions
|
@ -105,6 +105,25 @@ def _is_integer_type_array(a):
|
|||
return np.issubdtype(a.dtype, np.integer)
|
||||
|
||||
|
||||
def _is_last_dim_equal(arr1, arr1_name, arr2, arr2_name):
|
||||
"""Checks whether the last dimension of the arrays is the same."""
|
||||
if arr1 is not None and arr2 is not None and arr1.shape[-1] != arr2.shape[-1]:
|
||||
raise ValueError('%s and %s should have the same number of features.' %
|
||||
(arr1_name, arr2_name))
|
||||
|
||||
|
||||
def _is_array_one_dimensional(arr, arr_name):
|
||||
"""Checks whether the array is one dimensional."""
|
||||
if arr is not None and len(arr.shape) != 1:
|
||||
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
|
||||
|
||||
|
||||
def _is_np_array(arr, arr_name):
|
||||
"""Checks whether array is a numpy array."""
|
||||
if arr is not None and not isinstance(arr, np.ndarray):
|
||||
raise ValueError('%s should be a numpy array.' % arr_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttackInputData:
|
||||
"""Input data for running an attack.
|
||||
|
@ -187,7 +206,19 @@ class AttackInputData:
|
|||
self.labels_test):
|
||||
raise ValueError('labels_test elements should have integer type')
|
||||
|
||||
# TODO(b/161366709): Add checks for equal sizes
|
||||
_is_np_array(self.logits_train, 'logits_train')
|
||||
_is_np_array(self.logits_test, 'logits_test')
|
||||
_is_np_array(self.labels_train, 'labels_train')
|
||||
_is_np_array(self.labels_test, 'labels_test')
|
||||
_is_np_array(self.loss_train, 'loss_train')
|
||||
_is_np_array(self.loss_test, 'loss_test')
|
||||
|
||||
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
|
||||
'logits_test')
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
Loading…
Reference in a new issue