PiperOrigin-RevId: 229430188
This commit is contained in:
Nicolas Papernot 2019-01-15 13:32:27 -08:00 committed by A. Unique TensorFlower
parent 5ee12803f3
commit 4487099296

View file

@ -60,7 +60,7 @@ def cnn_model_fn(features, labels, mode):
logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y) logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
# Calculate loss as a vector (to support microbatches in DP-SGD). # Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2( vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits) labels=labels, logits=logits)
# Define mean of loss across minibatch (for reporting through tf.Estimator). # Define mean of loss across minibatch (for reporting through tf.Estimator).
scalar_loss = tf.reduce_mean(vector_loss) scalar_loss = tf.reduce_mean(vector_loss)
@ -99,7 +99,7 @@ def cnn_model_fn(features, labels, mode):
eval_metric_ops = { eval_metric_ops = {
'accuracy': 'accuracy':
tf.metrics.accuracy( tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1), labels=labels,
predictions=tf.argmax(input=logits, axis=1)) predictions=tf.argmax(input=logits, axis=1))
} }
return tf.estimator.EstimatorSpec(mode=mode, return tf.estimator.EstimatorSpec(mode=mode,
@ -116,15 +116,15 @@ def load_mnist():
train_data = np.array(train_data, dtype=np.float32) / 255 train_data = np.array(train_data, dtype=np.float32) / 255
test_data = np.array(test_data, dtype=np.float32) / 255 test_data = np.array(test_data, dtype=np.float32) / 255
train_labels = tf.keras.utils.to_categorical(train_labels) train_labels = np.array(train_labels, dtype=np.int32)
test_labels = tf.keras.utils.to_categorical(test_labels) test_labels = np.array(test_labels, dtype=np.int32)
assert train_data.min() == 0. assert train_data.min() == 0.
assert train_data.max() == 1. assert train_data.max() == 1.
assert test_data.min() == 0. assert test_data.min() == 0.
assert test_data.max() == 1. assert test_data.max() == 1.
assert train_labels.shape[1] == 10 assert len(train_labels.shape) == 1
assert test_labels.shape[1] == 10 assert len(test_labels.shape) == 1
return train_data, train_labels, test_data, test_labels return train_data, train_labels, test_data, test_labels