Migrates compute_dp_sgd_privacy to print new privacy statement from compute_dp_sgd_privacy_lib.

PiperOrigin-RevId: 520147633
This commit is contained in:
Galen Andrew 2023-03-28 15:13:21 -07:00 committed by A. Unique TensorFlower
parent 781483d1f2
commit abb0c3f9f6

View file

@ -31,26 +31,48 @@ The output states that DP-SGD with these parameters satisfies (2.92, 1e-5)-DP.
from absl import app
from absl import flags
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy_statement
FLAGS = flags.FLAGS
flags.DEFINE_integer('N', None, 'Total number of examples')
flags.DEFINE_integer('batch_size', None, 'Batch size')
flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD')
flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
flags.DEFINE_float('delta', 1e-6, 'Target delta')
_NUM_EXAMPLES = flags.DEFINE_integer('N', None, 'Total number of examples.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'Batch size.')
_NOISE_MULTIPLIER = flags.DEFINE_float(
'noise_multiplier', None, 'Noise multiplier for DP-SGD.'
)
_NUM_EPOCHS = flags.DEFINE_float(
'epochs', None, 'Number of epochs (may be fractional).'
)
_DELTA = flags.DEFINE_float('delta', 1e-6, 'Target delta.')
_USED_MICROBATCHING = flags.DEFINE_bool(
'used_microbatching',
True,
'Whether microbatching was used (with microbatch size greater than one).',
)
_MAX_EXAMPLES_PER_USER = flags.DEFINE_integer(
'max_examples_per_user',
None,
(
'Maximum number of examples per user, applicable. Used to compute a'
' user-level DP guarantee.'
),
)
flags.mark_flags_as_required(['N', 'batch_size', 'noise_multiplier', 'epochs'])
def main(argv):
del argv # argv is not used.
assert FLAGS.N is not None, 'Flag N is missing.'
assert FLAGS.batch_size is not None, 'Flag batch_size is missing.'
assert FLAGS.noise_multiplier is not None, 'Flag noise_multiplier is missing.'
assert FLAGS.epochs is not None, 'Flag epochs is missing.'
compute_dp_sgd_privacy(FLAGS.N, FLAGS.batch_size, FLAGS.noise_multiplier,
FLAGS.epochs, FLAGS.delta)
statement = compute_dp_sgd_privacy_statement(
_NUM_EXAMPLES.value,
_BATCH_SIZE.value,
_NUM_EPOCHS.value,
_NOISE_MULTIPLIER.value,
_DELTA.value,
_USED_MICROBATCHING.value,
_MAX_EXAMPLES_PER_USER.value,
)
print(statement)
if __name__ == '__main__':