forked from 626_privacy/tensorflow_privacy
Clarify documentation of labels_train/test usage wrt loss_train/test.
PiperOrigin-RevId: 670641843
This commit is contained in:
parent
66d05a22a3
commit
e3fd3afdf8
1 changed files with 3 additions and 2 deletions
|
@ -261,7 +261,8 @@ class AttackInputData:
|
||||||
# Contains ground-truth classes. For single-label classification, classes are
|
# Contains ground-truth classes. For single-label classification, classes are
|
||||||
# assumed to be integers starting from 0. For multi-label classification,
|
# assumed to be integers starting from 0. For multi-label classification,
|
||||||
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
|
# label is assumed to be multi-hot, i.e., labels is a binary array of shape
|
||||||
# (num_examples, num_classes).
|
# (num_examples, num_classes). Additionally used to compute the loss when
|
||||||
|
# loss_train/test is not provided. Leave empty for non-classification models.
|
||||||
labels_train: Optional[np.ndarray] = None
|
labels_train: Optional[np.ndarray] = None
|
||||||
labels_test: Optional[np.ndarray] = None
|
labels_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
@ -270,7 +271,7 @@ class AttackInputData:
|
||||||
sample_weight_test: Optional[np.ndarray] = None
|
sample_weight_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
# Explicitly specified loss. If provided, this is used instead of deriving
|
# Explicitly specified loss. If provided, this is used instead of deriving
|
||||||
# loss from logits and labels
|
# loss from logits and labels.
|
||||||
loss_train: Optional[np.ndarray] = None
|
loss_train: Optional[np.ndarray] = None
|
||||||
loss_test: Optional[np.ndarray] = None
|
loss_test: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue