diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py index 84f5e2e..d882887 100644 --- a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py +++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py @@ -31,26 +31,48 @@ The output states that DP-SGD with these parameters satisfies (2.92, 1e-5)-DP. from absl import app from absl import flags -from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy +from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy_statement -FLAGS = flags.FLAGS -flags.DEFINE_integer('N', None, 'Total number of examples') -flags.DEFINE_integer('batch_size', None, 'Batch size') -flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD') -flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)') -flags.DEFINE_float('delta', 1e-6, 'Target delta') +_NUM_EXAMPLES = flags.DEFINE_integer('N', None, 'Total number of examples.') +_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'Batch size.') +_NOISE_MULTIPLIER = flags.DEFINE_float( + 'noise_multiplier', None, 'Noise multiplier for DP-SGD.' +) +_NUM_EPOCHS = flags.DEFINE_float( + 'epochs', None, 'Number of epochs (may be fractional).' +) +_DELTA = flags.DEFINE_float('delta', 1e-6, 'Target delta.') +_USED_MICROBATCHING = flags.DEFINE_bool( + 'used_microbatching', + True, + 'Whether microbatching was used (with microbatch size greater than one).', +) +_MAX_EXAMPLES_PER_USER = flags.DEFINE_integer( + 'max_examples_per_user', + None, + ( + 'Maximum number of examples per user, applicable. Used to compute a' + ' user-level DP guarantee.' + ), +) + +flags.mark_flags_as_required(['N', 'batch_size', 'noise_multiplier', 'epochs']) def main(argv): del argv # argv is not used. - assert FLAGS.N is not None, 'Flag N is missing.' - assert FLAGS.batch_size is not None, 'Flag batch_size is missing.' - assert FLAGS.noise_multiplier is not None, 'Flag noise_multiplier is missing.' - assert FLAGS.epochs is not None, 'Flag epochs is missing.' - compute_dp_sgd_privacy(FLAGS.N, FLAGS.batch_size, FLAGS.noise_multiplier, - FLAGS.epochs, FLAGS.delta) + statement = compute_dp_sgd_privacy_statement( + _NUM_EXAMPLES.value, + _BATCH_SIZE.value, + _NUM_EPOCHS.value, + _NOISE_MULTIPLIER.value, + _DELTA.value, + _USED_MICROBATCHING.value, + _MAX_EXAMPLES_PER_USER.value, + ) + print(statement) if __name__ == '__main__':