diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index c992dc8..32a22c7 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -82,6 +82,25 @@ class SlicingSpec: # examples will be generated. by_classification_correctness: bool = False + def __str__(self): + """Only keeps the True values.""" + result = ['SlicingSpec('] + if self.entire_dataset: + result.append(' Entire dataset,') + if self.by_class: + if isinstance(self.by_class, Iterable): + result.append(' Into classes %s,' % self.by_class) + elif isinstance(self.by_class, int): + result.append(' Up to class %d,' % self.by_class) + else: + result.append(' By classes,') + if self.by_percentiles: + result.append(' By percentiles,') + if self.by_classification_correctness: + result.append(' By classification correctness,') + result.append(')') + return '\n'.join(result) + class AttackType(enum.Enum): """An enum define attack types.""" @@ -96,8 +115,8 @@ class AttackType(enum.Enum): """Returns whether this type of attack requires training a model.""" return self != AttackType.THRESHOLD_ATTACK - # Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION def __str__(self): + """Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION.""" return '%s' % self.name @@ -220,6 +239,23 @@ class AttackInputData: _is_array_one_dimensional(self.labels_train, 'labels_train') _is_array_one_dimensional(self.labels_test, 'labels_test') + def __str__(self): + """Return the shapes of variables that are not None.""" + result = ['AttackInputData('] + _append_array_shape(self.loss_train, 'loss_train', result) + _append_array_shape(self.loss_test, 'loss_test', result) + _append_array_shape(self.logits_train, 'logits_train', result) + _append_array_shape(self.logits_test, 'logits_test', result) + _append_array_shape(self.labels_train, 'labels_train', result) + _append_array_shape(self.labels_test, 'labels_test', result) + result.append(')') + return '\n'.join(result) + + +def _append_array_shape(arr: np.array, arr_name: str, result): + if arr is not None: + result.append(' %s with shape: %s,' % (arr_name, arr.shape)) + @dataclass class RocCurve: @@ -251,6 +287,14 @@ class RocCurve: """ return max(np.abs(self.tpr - self.fpr)) + def __str__(self): + """Returns AUC and advantage metrics.""" + return '\n'.join([ + 'RocCurve(', + ' AUC: %f.02' % self.get_auc(), + ' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' + ]) + @dataclass class SingleAttackResult: @@ -272,6 +316,16 @@ class SingleAttackResult: def get_auc(self): return self.roc_curve.get_auc() + def __str__(self): + """Returns SliceSpec, AttackType, AUC and advantage metrics.""" + return '\n'.join([ + 'SingleAttackResult(', + ' SliceSpec: %s' % str(self.slice_spec), + ' AttackType: %s' % str(self.attack_type), + ' AUC: %f.02' % self.get_auc(), + ' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')' + ]) + @dataclass class AttackResults: