Clarify documentation of labels_train/test usage wrt loss_train/test.

PiperOrigin-RevId: 670641843
This commit is contained in:
A. Unique TensorFlower 2024-09-03 11:38:59 -07:00
parent 66d05a22a3
commit e3fd3afdf8

View file

@ -261,7 +261,8 @@ class AttackInputData:
# Contains ground-truth classes. For single-label classification, classes are # Contains ground-truth classes. For single-label classification, classes are
# assumed to be integers starting from 0. For multi-label classification, # assumed to be integers starting from 0. For multi-label classification,
# label is assumed to be multi-hot, i.e., labels is a binary array of shape # label is assumed to be multi-hot, i.e., labels is a binary array of shape
# (num_examples, num_classes). # (num_examples, num_classes). Additionally used to compute the loss when
# loss_train/test is not provided. Leave empty for non-classification models.
labels_train: Optional[np.ndarray] = None labels_train: Optional[np.ndarray] = None
labels_test: Optional[np.ndarray] = None labels_test: Optional[np.ndarray] = None
@ -270,7 +271,7 @@ class AttackInputData:
sample_weight_test: Optional[np.ndarray] = None sample_weight_test: Optional[np.ndarray] = None
# Explicitly specified loss. If provided, this is used instead of deriving # Explicitly specified loss. If provided, this is used instead of deriving
# loss from logits and labels # loss from logits and labels.
loss_train: Optional[np.ndarray] = None loss_train: Optional[np.ndarray] = None
loss_test: Optional[np.ndarray] = None loss_test: Optional[np.ndarray] = None