diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
index f10c49a..3eba2f8 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py
@@ -261,7 +261,8 @@ class AttackInputData:
   # Contains ground-truth classes. For single-label classification, classes are
   # assumed to be integers starting from 0. For multi-label classification,
   # label is assumed to be multi-hot, i.e., labels is a binary array of shape
-  # (num_examples, num_classes).
+  # (num_examples, num_classes). Additionally used to compute the loss when
+  # loss_train/test is not provided. Leave empty for non-classification models.
   labels_train: Optional[np.ndarray] = None
   labels_test: Optional[np.ndarray] = None
 
@@ -270,7 +271,7 @@ class AttackInputData:
   sample_weight_train: Optional[np.ndarray] = None
   sample_weight_test: Optional[np.ndarray] = None
 
   # Explicitly specified loss. If provided, this is used instead of deriving
-  # loss from logits and labels
+  # loss from logits and labels.
   loss_train: Optional[np.ndarray] = None
   loss_test: Optional[np.ndarray] = None