Fix MIA readme: labels are not needed in basic usage.

PiperOrigin-RevId: 430230630
This commit is contained in:
Shuang Song 2022-02-22 09:37:16 -08:00 committed by A. Unique TensorFlower
parent 7d5a57f0a8
commit 12541c23d4

View file

@ -32,20 +32,15 @@ The simplest possible usage is
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
# Suppose we have the labels as integers starting from 0 # Suppose we have evaluated the model on training and test examples to get the
# labels_train shape: (n_train, ) # per-example losses:
# labels_test shape: (n_test, )
# Evaluate your model on training and test examples to get
# loss_train shape: (n_train, ) # loss_train shape: (n_train, )
# loss_test shape: (n_test, ) # loss_test shape: (n_test, )
attacks_result = mia.run_attacks( attacks_result = mia.run_attacks(
AttackInputData( AttackInputData(
loss_train = loss_train, loss_train = loss_train,
loss_test = loss_test, loss_test = loss_test))
labels_train = labels_train,
labels_test = labels_test))
``` ```
This example calls `run_attacks` with the default options to run a host of This example calls `run_attacks` with the default options to run a host of
@ -94,6 +89,10 @@ First, similar as before, we specify the input for the attack as an
`AttackInputData` object: `AttackInputData` object:
```python ```python
# Suppose we have the labels as integers starting from 0
# labels_train shape: (n_train, )
# labels_test shape: (n_test, )
# Evaluate your model on training and test examples to get # Evaluate your model on training and test examples to get
# logits_train shape: (n_train, n_classes) # logits_train shape: (n_train, n_classes)
# logits_test shape: (n_test, n_classes) # logits_test shape: (n_test, n_classes)