forked from 626_privacy/tensorflow_privacy
Minor update to mnist_lr_tutorial.py
to avoid (some) deprecated items.
PiperOrigin-RevId: 339327388
This commit is contained in:
parent
67f7f35383
commit
f0daaf085f
1 changed files with 5 additions and 13 deletions
|
@ -30,8 +30,6 @@ import math
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
@ -39,10 +37,7 @@ from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
|||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
@ -62,14 +57,11 @@ flags.DEFINE_float('data_l2_norm', 8, 'Bound on the L2 norm of normalized data')
|
|||
def lr_model_fn(features, labels, mode, nclasses, dim):
|
||||
"""Model function for logistic regression."""
|
||||
input_layer = tf.reshape(features['x'], tuple([-1]) + dim)
|
||||
|
||||
logits = tf.layers.dense(
|
||||
inputs=input_layer,
|
||||
logits = tf.keras.layers.Dense(
|
||||
units=nclasses,
|
||||
kernel_regularizer=tf.contrib.layers.l2_regularizer(
|
||||
scale=FLAGS.regularizer),
|
||||
bias_regularizer=tf.contrib.layers.l2_regularizer(
|
||||
scale=FLAGS.regularizer))
|
||||
kernel_regularizer=tf.keras.regularizers.L2(l2=FLAGS.regularizer),
|
||||
bias_regularizer=tf.keras.regularizers.L2(l2=FLAGS.regularizer)).apply(
|
||||
input_layer)
|
||||
|
||||
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
||||
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
|
|
Loading…
Reference in a new issue