Introduce vectorized DP optimizer
PiperOrigin-RevId: 262414086
parent c08f3ebdc7
commit c7ca8092fb
3 changed files with 564 additions and 0 deletions
privacy/optimizers/dp_optimizer_vectorized.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vectorized differentially private optimizers for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from distutils.version import LooseVersion
import tensorflow as tf

if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
  nest = tf.contrib.framework.nest
  AdagradOptimizer = tf.train.AdagradOptimizer
  AdamOptimizer = tf.train.AdamOptimizer
  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
  parent_code = tf.train.Optimizer.compute_gradients.__code__
  GATE_OP = tf.train.Optimizer.GATE_OP  # pylint: disable=invalid-name
else:
  nest = tf.nest
  AdagradOptimizer = tf.optimizers.Adagrad
  AdamOptimizer = tf.optimizers.Adam
  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
  parent_code = tf.optimizers.Optimizer._compute_gradients.__code__  # pylint: disable=protected-access
  GATE_OP = None  # pylint: disable=invalid-name


def make_vectorized_optimizer_class(cls):
  """Constructs a vectorized DP optimizer class from an existing one."""
  if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
    child_code = cls.compute_gradients.__code__
  else:
    child_code = cls._compute_gradients.__code__  # pylint: disable=protected-access
  if child_code is not parent_code:
    tf.logging.warning(
        'WARNING: Calling make_optimizer_class() on class %s that overrides '
        'method compute_gradients(). Check to ensure that '
        'make_optimizer_class() does not interfere with overridden version.',
        cls.__name__)

  class DPOptimizerClass(cls):
    """Differentially private subclass of given class cls."""

    def __init__(
        self,
        l2_norm_clip,
        noise_multiplier,
        num_microbatches=None,
        *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
        **kwargs):
      """Initialize the DPOptimizerClass.

      Args:
        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients)
        noise_multiplier: Ratio of the standard deviation to the clipping norm
        num_microbatches: How many microbatches into which the minibatch is
          split. If None, will default to the size of the minibatch, and
          per-example gradients will be computed.
      """
      super(DPOptimizerClass, self).__init__(*args, **kwargs)
      self._l2_norm_clip = l2_norm_clip
      self._noise_multiplier = noise_multiplier
      self._num_microbatches = num_microbatches

    def compute_gradients(self,
                          loss,
                          var_list,
                          gate_gradients=GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None,
                          gradient_tape=None):
      if callable(loss):
        # TF is running in Eager mode
        raise NotImplementedError('Vectorized optimizer unavailable for TF2.')
      else:
        # TF is running in graph mode, check we did not receive a gradient tape.
        if gradient_tape:
          raise ValueError('When in graph mode, a tape should not be passed.')

        batch_size = tf.shape(loss)[0]
        if self._num_microbatches is None:
          self._num_microbatches = batch_size

        # Note: it would be closer to the correct i.i.d. sampling of records if
        # we sampled each microbatch from the appropriate binomial distribution,
        # although that still wouldn't be quite correct because it would be
        # sampling from the dataset without replacement.
        microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])

        if var_list is None:
          var_list = (
              tf.trainable_variables() + tf.get_collection(
                  tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

        def process_microbatch(microbatch_loss):
          """Compute clipped grads for one microbatch."""
          microbatch_loss = tf.reduce_mean(microbatch_loss)
          grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
              microbatch_loss,
              var_list,
              gate_gradients,
              aggregation_method,
              colocate_gradients_with_ops,
              grad_loss))
          grads_list = [
              g if g is not None else tf.zeros_like(v)
              for (g, v) in zip(list(grads), var_list)
          ]
          # Clip gradients to have L2 norm of l2_norm_clip.
          # Here, we use TF primitives rather than the built-in
          # tf.clip_by_global_norm() so that operations can be vectorized
          # across microbatches.
          grads_flat = nest.flatten(grads_list)
          squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat]
          global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
          div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
          clipped_flat = [g / div for g in grads_flat]
          clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat)
          return clipped_grads

        clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses)

        def reduce_noise_normalize_batch(stacked_grads):
          summed_grads = tf.reduce_sum(stacked_grads, axis=0)
          noise_stddev = self._l2_norm_clip * self._noise_multiplier
          noise = tf.random.normal(tf.shape(summed_grads),
                                   stddev=noise_stddev)
          noised_grads = summed_grads + noise
          return noised_grads / tf.cast(self._num_microbatches, tf.float32)

        final_grads = nest.map_structure(reduce_noise_normalize_batch,
                                         clipped_grads)

        return list(zip(final_grads, var_list))

  return DPOptimizerClass


VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer)
VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer)
VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer)
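Note (not part of the diff): the snippet below is a minimal usage sketch of the new wrapper in TF 1.x graph mode, modeled on the tests in the next file. The loss is passed as a per-example vector so it can be split into microbatches, and minimize() routes through the DP compute_gradients() above, which clips, noises, and averages gradients. Variable names and hyperparameter values are illustrative only.

import tensorflow as tf
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD

# Toy regression problem: one loss value per record.
var0 = tf.Variable([1.0, 2.0])
data0 = tf.constant([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
vector_loss = 0.5 * tf.reduce_sum(tf.squared_difference(data0, var0), axis=1)

opt = VectorizedDPSGD(l2_norm_clip=1.0,
                      noise_multiplier=1.1,
                      num_microbatches=4,
                      learning_rate=0.15)
# minimize() calls the overridden compute_gradients(), then apply_gradients().
train_op = opt.minimize(loss=vector_loss,
                        global_step=tf.train.get_or_create_global_step())

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)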
privacy/optimizers/dp_optimizer_vectorized_test.py (new file, 204 lines)
@@ -0,0 +1,204 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for differentially private optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import mock
import numpy as np
import tensorflow as tf

from privacy.optimizers import dp_optimizer_vectorized
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam
from privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD


class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _loss(self, val0, val1):
    """Loss function that is minimized at the mean of the input points."""
    return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)

  # Parameters for testing: optimizer, num_microbatches, expected answer.
  @parameterized.named_parameters(
      ('DPGradientDescent 1', VectorizedDPSGD, 1, [-2.5, -2.5]),
      ('DPGradientDescent 2', VectorizedDPSGD, 2, [-2.5, -2.5]),
      ('DPGradientDescent 4', VectorizedDPSGD, 4, [-2.5, -2.5]),
      ('DPAdagrad 1', VectorizedDPAdagrad, 1, [-2.5, -2.5]),
      ('DPAdagrad 2', VectorizedDPAdagrad, 2, [-2.5, -2.5]),
      ('DPAdagrad 4', VectorizedDPAdagrad, 4, [-2.5, -2.5]),
      ('DPAdam 1', VectorizedDPAdam, 1, [-2.5, -2.5]),
      ('DPAdam 2', VectorizedDPAdam, 2, [-2.5, -2.5]),
      ('DPAdam 4', VectorizedDPAdam, 4, [-2.5, -2.5]))
  def testBaseline(self, cls, num_microbatches, expected_answer):
    with self.cached_session() as sess:
      var0 = tf.Variable([1.0, 2.0])
      data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])

      opt = cls(
          l2_norm_clip=1.0e9,
          noise_multiplier=0.0,
          num_microbatches=num_microbatches,
          learning_rate=2.0)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))

      # Expected gradient is sum of differences divided by number of
      # microbatches.
      gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
      grads_and_vars = sess.run(gradient_op)
      self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0])

  @parameterized.named_parameters(
      ('DPGradientDescent', VectorizedDPSGD),
      ('DPAdagrad', VectorizedDPAdagrad),
      ('DPAdam', VectorizedDPAdam))
  def testClippingNorm(self, cls):
    with self.cached_session() as sess:
      var0 = tf.Variable([0.0, 0.0])
      data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])

      opt = cls(l2_norm_clip=1.0,
                noise_multiplier=0.,
                num_microbatches=1,
                learning_rate=2.0)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([0.0, 0.0], self.evaluate(var0))

      # Expected gradient is sum of differences.
      gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
      grads_and_vars = sess.run(gradient_op)
      self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])

  @parameterized.named_parameters(
      ('DPGradientDescent', VectorizedDPSGD),
      ('DPAdagrad', VectorizedDPAdagrad),
      ('DPAdam', VectorizedDPAdam))
  def testNoiseMultiplier(self, cls):
    with self.cached_session() as sess:
      var0 = tf.Variable([0.0])
      data0 = tf.Variable([[0.0]])

      opt = cls(l2_norm_clip=4.0,
                noise_multiplier=8.0,
                num_microbatches=1,
                learning_rate=2.0)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([0.0], self.evaluate(var0))

      gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
      grads = []
      for _ in range(1000):
        grads_and_vars = sess.run(gradient_op)
        grads.append(grads_and_vars[0][0])

      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
      self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)

  @mock.patch.object(tf, 'logging')
  def testComputeGradientsOverrideWarning(self, mock_logging):

    class SimpleOptimizer(tf.train.Optimizer):

      def compute_gradients(self):
        return 0

    dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer)
    mock_logging.warning.assert_called_once_with(
        'WARNING: Calling make_optimizer_class() on class %s that overrides '
        'method compute_gradients(). Check to ensure that '
        'make_optimizer_class() does not interfere with overridden version.',
        'SimpleOptimizer')

  def testEstimator(self):
    """Tests that DP optimizers work with tf.estimator."""

    def linear_model_fn(features, labels, mode):
      preds = tf.keras.layers.Dense(
          1, activation='linear', name='dense').apply(features['x'])

      vector_loss = tf.squared_difference(labels, preds)
      scalar_loss = tf.reduce_mean(vector_loss)
      optimizer = VectorizedDPSGD(
          l2_norm_clip=1.0,
          noise_multiplier=0.,
          num_microbatches=1,
          learning_rate=1.0)
      global_step = tf.train.get_global_step()
      train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=scalar_loss, train_op=train_op)

    linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn)
    true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
    true_bias = 6.0
    train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32)

    train_labels = np.matmul(train_data,
                             true_weights) + true_bias + np.random.normal(
                                 scale=0.1, size=(200, 1)).astype(np.float32)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': train_data},
        y=train_labels,
        batch_size=20,
        num_epochs=10,
        shuffle=True)
    linear_regressor.train(input_fn=train_input_fn, steps=100)
    self.assertAllClose(
        linear_regressor.get_variable_value('dense/kernel'),
        true_weights,
        atol=1.0)

  @parameterized.named_parameters(
      ('DPGradientDescent', VectorizedDPSGD),
      ('DPAdagrad', VectorizedDPAdagrad),
      ('DPAdam', VectorizedDPAdam))
  def testDPGaussianOptimizerClass(self, cls):
    with self.cached_session() as sess:
      var0 = tf.Variable([0.0])
      data0 = tf.Variable([[0.0]])

      opt = cls(
          l2_norm_clip=4.0,
          noise_multiplier=2.0,
          num_microbatches=1,
          learning_rate=2.0)

      self.evaluate(tf.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([0.0], self.evaluate(var0))

      gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0])
      grads = []
      for _ in range(1000):
        grads_and_vars = sess.run(gradient_op)
        grads.append(grads_and_vars[0][0])

      # Test standard deviation is close to l2_norm_clip * noise_multiplier.
      self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)


if __name__ == '__main__':
  tf.test.main()
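Note (not part of the diff): a quick sanity check on the two noise tests above. With a single microbatch the final normalization divides by one, so the standard deviation of the returned gradient should equal the noise standard deviation itself, l2_norm_clip * noise_multiplier. The values below are the ones used in testNoiseMultiplier; testDPGaussianOptimizerClass uses 4.0 and 2.0.

l2_norm_clip, noise_multiplier, num_microbatches = 4.0, 8.0, 1
noise_stddev = l2_norm_clip * noise_multiplier     # std of the Gaussian added to the summed gradient
expected_std = noise_stddev / num_microbatches     # 32.0, matching assertNear(np.std(grads), 4.0 * 8.0, 0.5)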
tutorials/mnist_dpsgd_tutorial_vectorized.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training a CNN on MNIST with vectorized DP-SGD optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

from distutils.version import LooseVersion

import numpy as np
import tensorflow as tf

from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent
from privacy.optimizers import dp_optimizer_vectorized


flags.DEFINE_boolean(
    'dpsgd', True, 'If True, train with DP-SGD. If False, '
    'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
                   'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 200, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer(
    'microbatches', 200, 'Number of microbatches '
    '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')


FLAGS = flags.FLAGS


NUM_TRAIN_EXAMPLES = 60000


if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
else:
  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name


def compute_epsilon(steps):
  """Computes epsilon value for given hyperparameters."""
  if FLAGS.noise_multiplier == 0.0:
    return float('inf')
  orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
  sampling_probability = FLAGS.batch_size / NUM_TRAIN_EXAMPLES
  rdp = compute_rdp(q=sampling_probability,
                    noise_multiplier=FLAGS.noise_multiplier,
                    steps=steps,
                    orders=orders)
  # Delta is set to approximate 1 / (number of training points).
  return get_privacy_spent(orders, rdp, target_delta=1e-5)[0]


def cnn_model_fn(features, labels, mode):
  """Model function for a CNN."""

  # Define CNN architecture using tf.keras.layers.
  input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
  y = tf.keras.layers.Conv2D(16, 8,
                             strides=2,
                             padding='same',
                             activation='relu').apply(input_layer)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Conv2D(32, 4,
                             strides=2,
                             padding='valid',
                             activation='relu').apply(y)
  y = tf.keras.layers.MaxPool2D(2, 1).apply(y)
  y = tf.keras.layers.Flatten().apply(y)
  y = tf.keras.layers.Dense(32, activation='relu').apply(y)
  logits = tf.keras.layers.Dense(10).apply(y)

  # Calculate loss as a vector (to support microbatches in DP-SGD).
  vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # Define mean of loss across minibatch (for reporting through tf.Estimator).
  scalar_loss = tf.reduce_mean(vector_loss)

  # Configure the training op (for TRAIN mode).
  if mode == tf.estimator.ModeKeys.TRAIN:

    if FLAGS.dpsgd:
      # Use DP version of GradientDescentOptimizer. Other optimizers are
      # available in dp_optimizer. Most optimizers inheriting from
      # tf.train.Optimizer should be wrappable in differentially private
      # counterparts by calling dp_optimizer.optimizer_from_args().
      optimizer = dp_optimizer_vectorized.VectorizedDPSGD(
          l2_norm_clip=FLAGS.l2_norm_clip,
          noise_multiplier=FLAGS.noise_multiplier,
          num_microbatches=FLAGS.microbatches,
          learning_rate=FLAGS.learning_rate)
      opt_loss = vector_loss
    else:
      optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
      opt_loss = scalar_loss
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
    # In the following, we pass the mean of the loss (scalar_loss) rather than
    # the vector_loss because tf.estimator requires a scalar loss. This is only
    # used for evaluation and debugging by tf.estimator. The actual loss being
    # minimized is opt_loss defined above and passed to optimizer.minimize().
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=scalar_loss,
                                      train_op=train_op)

  # Add evaluation metrics (for EVAL mode).
  elif mode == tf.estimator.ModeKeys.EVAL:
    eval_metric_ops = {
        'accuracy':
            tf.metrics.accuracy(
                labels=labels,
                predictions=tf.argmax(input=logits, axis=1))
    }

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=scalar_loss,
                                      eval_metric_ops=eval_metric_ops)


def load_mnist():
  """Loads MNIST and preprocesses to combine training and validation data."""
  train, test = tf.keras.datasets.mnist.load_data()
  train_data, train_labels = train
  test_data, test_labels = test

  train_data = np.array(train_data, dtype=np.float32) / 255
  test_data = np.array(test_data, dtype=np.float32) / 255

  train_labels = np.array(train_labels, dtype=np.int32)
  test_labels = np.array(test_labels, dtype=np.int32)

  assert train_data.min() == 0.
  assert train_data.max() == 1.
  assert test_data.min() == 0.
  assert test_data.max() == 1.
  assert train_labels.ndim == 1
  assert test_labels.ndim == 1

  return train_data, train_labels, test_data, test_labels


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
    raise ValueError('Number of microbatches should divide evenly batch_size')

  # Load training and test data.
  train_data, train_labels, test_data, test_labels = load_mnist()

  # Instantiate the tf.Estimator.
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn,
                                            model_dir=FLAGS.model_dir)

  # Create tf.Estimator input functions for the training and test data.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': train_data},
      y=train_labels,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.epochs,
      shuffle=True)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={'x': test_data},
      y=test_labels,
      num_epochs=1,
      shuffle=False)

  # Training loop.
  steps_per_epoch = NUM_TRAIN_EXAMPLES // FLAGS.batch_size
  for epoch in range(1, FLAGS.epochs + 1):
    # Train the model for one epoch.
    mnist_classifier.train(input_fn=train_input_fn, steps=steps_per_epoch)

    # Evaluate the model and print results
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    test_accuracy = eval_results['accuracy']
    print('Test accuracy after %d epochs is: %.3f' % (epoch, test_accuracy))

    # Compute the privacy budget expended.
    if FLAGS.dpsgd:
      eps = compute_epsilon(epoch * NUM_TRAIN_EXAMPLES // FLAGS.batch_size)
      print('For delta=1e-5, the current epsilon is: %.2f' % eps)
    else:
      print('Trained with vanilla non-private SGD optimizer')

if __name__ == '__main__':
  app.run(main)
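Note (not part of the diff): the privacy accounting performed by compute_epsilon() can be reproduced standalone with the same RDP-accountant calls the tutorial imports. The sketch below uses the tutorial's default flag values and the MNIST training-set size; no particular epsilon value is implied.

from __future__ import division

from privacy.analysis.rdp_accountant import compute_rdp
from privacy.analysis.rdp_accountant import get_privacy_spent

batch_size, epochs, num_examples = 200, 60, 60000   # tutorial defaults and NUM_TRAIN_EXAMPLES
steps = epochs * num_examples // batch_size          # total DP-SGD steps over the full run
orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))
rdp = compute_rdp(q=batch_size / num_examples,
                  noise_multiplier=1.1,
                  steps=steps,
                  orders=orders)
# Delta approximates 1 / (number of training points), as in the tutorial.
eps = get_privacy_spent(orders, rdp, target_delta=1e-5)[0]
print('Epsilon at delta=1e-5 after %d steps: %.2f' % (steps, eps))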