diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index efabdd8..c992dc8 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -105,6 +105,25 @@ def _is_integer_type_array(a): return np.issubdtype(a.dtype, np.integer) +def _is_last_dim_equal(arr1, arr1_name, arr2, arr2_name): + """Checks whether the last dimension of the arrays is the same.""" + if arr1 is not None and arr2 is not None and arr1.shape[-1] != arr2.shape[-1]: + raise ValueError('%s and %s should have the same number of features.' % + (arr1_name, arr2_name)) + + +def _is_array_one_dimensional(arr, arr_name): + """Checks whether the array is one dimensional.""" + if arr is not None and len(arr.shape) != 1: + raise ValueError('%s should be a one dimensional numpy array.' % arr_name) + + +def _is_np_array(arr, arr_name): + """Checks whether array is a numpy array.""" + if arr is not None and not isinstance(arr, np.ndarray): + raise ValueError('%s should be a numpy array.' % arr_name) + + @dataclass class AttackInputData: """Input data for running an attack. @@ -187,7 +206,19 @@ class AttackInputData: self.labels_test): raise ValueError('labels_test elements should have integer type') - # TODO(b/161366709): Add checks for equal sizes + _is_np_array(self.logits_train, 'logits_train') + _is_np_array(self.logits_test, 'logits_test') + _is_np_array(self.labels_train, 'labels_train') + _is_np_array(self.labels_test, 'labels_test') + _is_np_array(self.loss_train, 'loss_train') + _is_np_array(self.loss_test, 'loss_test') + + _is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test, + 'logits_test') + _is_array_one_dimensional(self.loss_train, 'loss_train') + _is_array_one_dimensional(self.loss_test, 'loss_test') + _is_array_one_dimensional(self.labels_train, 'labels_train') + _is_array_one_dimensional(self.labels_test, 'labels_test') @dataclass