harmonize analysis parameters with current DPSGD API
PiperOrigin-RevId: 255080643
This commit is contained in:
parent
45bcb3a0e4
commit
6171474465
1 changed files with 4 additions and 3 deletions
|
@ -61,6 +61,10 @@ flags.mark_flag_as_required('epochs')
|
||||||
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
|
def apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
|
||||||
"""Compute and print results of DP-SGD analysis."""
|
"""Compute and print results of DP-SGD analysis."""
|
||||||
|
|
||||||
|
# compute_rdp requires that sigma be the ratio of the standard deviation of
|
||||||
|
# the Gaussian noise to the l2-sensitivity of the function to which it is
|
||||||
|
# added. Hence, sigma here corresponds to the `noise_multiplier` parameter
|
||||||
|
# in the DP-SGD implementation found in privacy.optimizers.dp_optimizer
|
||||||
rdp = compute_rdp(q, sigma, steps, orders)
|
rdp = compute_rdp(q, sigma, steps, orders)
|
||||||
|
|
||||||
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
|
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta)
|
||||||
|
@ -80,13 +84,10 @@ def main(argv):
|
||||||
del argv # argv is not used.
|
del argv # argv is not used.
|
||||||
|
|
||||||
q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio.
|
q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio.
|
||||||
|
|
||||||
if q > 1:
|
if q > 1:
|
||||||
raise app.UsageError('N must be larger than the batch size.')
|
raise app.UsageError('N must be larger than the batch size.')
|
||||||
|
|
||||||
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] +
|
||||||
list(range(5, 64)) + [128, 256, 512])
|
list(range(5, 64)) + [128, 256, 512])
|
||||||
|
|
||||||
steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size))
|
steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size))
|
||||||
|
|
||||||
apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta)
|
apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta)
|
||||||
|
|
Loading…
Reference in a new issue