Adds more validation checks for AttackInputData.

PiperOrigin-RevId: 327191245
This commit is contained in:
A. Unique TensorFlower 2020-08-18 02:34:35 -07:00
parent a69b013390
commit 193ac3b1c8

View file

@ -105,6 +105,25 @@ def _is_integer_type_array(a):
return np.issubdtype(a.dtype, np.integer) return np.issubdtype(a.dtype, np.integer)
def _is_last_dim_equal(arr1, arr1_name, arr2, arr2_name):
"""Checks whether the last dimension of the arrays is the same."""
if arr1 is not None and arr2 is not None and arr1.shape[-1] != arr2.shape[-1]:
raise ValueError('%s and %s should have the same number of features.' %
(arr1_name, arr2_name))
def _is_array_one_dimensional(arr, arr_name):
"""Checks whether the array is one dimensional."""
if arr is not None and len(arr.shape) != 1:
raise ValueError('%s should be a one dimensional numpy array.' % arr_name)
def _is_np_array(arr, arr_name):
"""Checks whether array is a numpy array."""
if arr is not None and not isinstance(arr, np.ndarray):
raise ValueError('%s should be a numpy array.' % arr_name)
@dataclass @dataclass
class AttackInputData: class AttackInputData:
"""Input data for running an attack. """Input data for running an attack.
@ -187,7 +206,19 @@ class AttackInputData:
self.labels_test): self.labels_test):
raise ValueError('labels_test elements should have integer type') raise ValueError('labels_test elements should have integer type')
# TODO(b/161366709): Add checks for equal sizes _is_np_array(self.logits_train, 'logits_train')
_is_np_array(self.logits_test, 'logits_test')
_is_np_array(self.labels_train, 'labels_train')
_is_np_array(self.labels_test, 'labels_test')
_is_np_array(self.loss_train, 'loss_train')
_is_np_array(self.loss_test, 'loss_test')
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
'logits_test')
_is_array_one_dimensional(self.loss_train, 'loss_train')
_is_array_one_dimensional(self.loss_test, 'loss_test')
_is_array_one_dimensional(self.labels_train, 'labels_train')
_is_array_one_dimensional(self.labels_test, 'labels_test')
@dataclass @dataclass