From e3fd3afdf8dd79ebeac11202d00b2adadebc3553 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 3 Sep 2024 11:38:59 -0700 Subject: [PATCH] Clarify documentation of labels_train/test usage wrt loss_train/test. PiperOrigin-RevId: 670641843 --- .../membership_inference_attack/data_structures.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index f10c49a..3eba2f8 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -261,7 +261,8 @@ class AttackInputData: # Contains ground-truth classes. For single-label classification, classes are # assumed to be integers starting from 0. For multi-label classification, # label is assumed to be multi-hot, i.e., labels is a binary array of shape - # (num_examples, num_classes). + # (num_examples, num_classes). Additionally used to compute the loss when + # loss_train/test is not provided. Leave empty for non-classification models. labels_train: Optional[np.ndarray] = None labels_test: Optional[np.ndarray] = None @@ -270,7 +271,7 @@ class AttackInputData: sample_weight_test: Optional[np.ndarray] = None # Explicitly specified loss. If provided, this is used instead of deriving - # loss from logits and labels + # loss from logits and labels. loss_train: Optional[np.ndarray] = None loss_test: Optional[np.ndarray] = None