forked from 626_privacy/tensorflow_privacy
Merge pull request #45 from splatonline:ledger
PiperOrigin-RevId: 244412992
commit 134b7d2093
1 changed file with 11 additions and 3 deletions
@@ -40,13 +40,14 @@ import numpy as np
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
+from privacy.analysis import privacy_ledger
 from privacy.analysis.rdp_accountant import compute_rdp
 from privacy.analysis.rdp_accountant import get_privacy_spent
 from privacy.optimizers import dp_optimizer
 
 tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
                         'train with vanilla SGD.')
-tf.flags.DEFINE_float('learning_rate', .001, 'Learning rate for training')
+tf.flags.DEFINE_float('learning_rate', 0.001, 'Learning rate for training')
 tf.flags.DEFINE_float('noise_multiplier', 0.001,
                       'Ratio of the standard deviation to the clipping norm')
 tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
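For orientation (an illustration, not part of this change): noise_multiplier is the ratio of the Gaussian noise standard deviation to the clipping norm, so with the defaults above the noise added by DP-SGD has standard deviation noise_multiplier * l2_norm_clip = 0.001 * 1.0. A minimal NumPy sketch of the clip-and-noise step, with a made-up gradient:

# Illustrative sketch only; values mirror the flag defaults above.
import numpy as np

l2_norm_clip = 1.0        # FLAGS.l2_norm_clip
noise_multiplier = 0.001  # FLAGS.noise_multiplier

grad = np.array([3.0, 4.0])  # hypothetical per-example gradient, L2 norm 5.0
# Scale the gradient down so its L2 norm is at most l2_norm_clip.
clipped = grad / max(1.0, np.linalg.norm(grad) / l2_norm_clip)
# Add Gaussian noise with stddev = noise_multiplier * l2_norm_clip.
noised = clipped + np.random.normal(
    scale=noise_multiplier * l2_norm_clip, size=grad.shape)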
@@ -84,13 +85,20 @@ def rnn_model_fn(features, labels, mode):  # pylint: disable=unused-argument
   # Configure the training op (for TRAIN mode).
   if mode == tf.estimator.ModeKeys.TRAIN:
     if FLAGS.dpsgd:
+
+      ledger = privacy_ledger.PrivacyLedger(
+          population_size=NB_TRAIN,
+          selection_probability=(FLAGS.batch_size / NB_TRAIN),
+          max_samples=1e6,
+          max_queries=1e6)
+
       optimizer = dp_optimizer.DPAdamGaussianOptimizer(
           l2_norm_clip=FLAGS.l2_norm_clip,
           noise_multiplier=FLAGS.noise_multiplier,
           num_microbatches=FLAGS.microbatches,
+          ledger=ledger,
           learning_rate=FLAGS.learning_rate,
-          unroll_microbatches=True,
-          population_size=NB_TRAIN)
+          unroll_microbatches=True)
       opt_loss = vector_loss
     else:
       optimizer = tf.train.AdamOptimizer(
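The PrivacyLedger introduced here records each sampling event (with selection probability batch_size / NB_TRAIN over a population of NB_TRAIN examples) and each noised query made by the optimizer, so the privacy cost can be audited after training; population_size accordingly moves off the optimizer and onto the ledger. The compute_rdp / get_privacy_spent imports are the usual way to turn those parameters into an (epsilon, delta) statement. A hedged sketch of that accounting, assuming target_delta = 1e-5 and a caller-supplied step count (the helper name compute_epsilon is illustrative, not from this diff):

# Minimal RDP accounting sketch; mirrors the parameters used above.
from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent

def compute_epsilon(steps, batch_size, noise_multiplier, population_size):
  """Returns epsilon after `steps` DP-SGD steps, at target_delta=1e-5."""
  orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
  q = batch_size / float(population_size)  # per-step sampling probability
  rdp = compute_rdp(q=q,
                    noise_multiplier=noise_multiplier,
                    steps=steps,
                    orders=orders)
  eps, _, _ = get_privacy_spent(orders, rdp, target_delta=1e-5)
  return eps

Attaching the ledger means this same accounting can later be driven from the queries the ledger actually recorded during training, rather than re-derived from the flags.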