Migrates compute_dp_sgd_privacy
to print new privacy statement from compute_dp_sgd_privacy_lib
.
PiperOrigin-RevId: 520147633
This commit is contained in:
parent
781483d1f2
commit
abb0c3f9f6
1 changed files with 35 additions and 13 deletions
|
@ -31,26 +31,48 @@ The output states that DP-SGD with these parameters satisfies (2.92, 1e-5)-DP.
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
|
||||
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy_statement
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_integer('N', None, 'Total number of examples')
|
||||
flags.DEFINE_integer('batch_size', None, 'Batch size')
|
||||
flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD')
|
||||
flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
|
||||
flags.DEFINE_float('delta', 1e-6, 'Target delta')
|
||||
_NUM_EXAMPLES = flags.DEFINE_integer('N', None, 'Total number of examples.')
|
||||
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'Batch size.')
|
||||
_NOISE_MULTIPLIER = flags.DEFINE_float(
|
||||
'noise_multiplier', None, 'Noise multiplier for DP-SGD.'
|
||||
)
|
||||
_NUM_EPOCHS = flags.DEFINE_float(
|
||||
'epochs', None, 'Number of epochs (may be fractional).'
|
||||
)
|
||||
_DELTA = flags.DEFINE_float('delta', 1e-6, 'Target delta.')
|
||||
_USED_MICROBATCHING = flags.DEFINE_bool(
|
||||
'used_microbatching',
|
||||
True,
|
||||
'Whether microbatching was used (with microbatch size greater than one).',
|
||||
)
|
||||
_MAX_EXAMPLES_PER_USER = flags.DEFINE_integer(
|
||||
'max_examples_per_user',
|
||||
None,
|
||||
(
|
||||
'Maximum number of examples per user, applicable. Used to compute a'
|
||||
' user-level DP guarantee.'
|
||||
),
|
||||
)
|
||||
|
||||
flags.mark_flags_as_required(['N', 'batch_size', 'noise_multiplier', 'epochs'])
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # argv is not used.
|
||||
|
||||
assert FLAGS.N is not None, 'Flag N is missing.'
|
||||
assert FLAGS.batch_size is not None, 'Flag batch_size is missing.'
|
||||
assert FLAGS.noise_multiplier is not None, 'Flag noise_multiplier is missing.'
|
||||
assert FLAGS.epochs is not None, 'Flag epochs is missing.'
|
||||
compute_dp_sgd_privacy(FLAGS.N, FLAGS.batch_size, FLAGS.noise_multiplier,
|
||||
FLAGS.epochs, FLAGS.delta)
|
||||
statement = compute_dp_sgd_privacy_statement(
|
||||
_NUM_EXAMPLES.value,
|
||||
_BATCH_SIZE.value,
|
||||
_NUM_EPOCHS.value,
|
||||
_NOISE_MULTIPLIER.value,
|
||||
_DELTA.value,
|
||||
_USED_MICROBATCHING.value,
|
||||
_MAX_EXAMPLES_PER_USER.value,
|
||||
)
|
||||
print(statement)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in a new issue