PiperOrigin-RevId: 229430188
This commit is contained in:
Nicolas Papernot 2019-01-15 13:32:27 -08:00 committed by A. Unique TensorFlower
parent 5ee12803f3
commit 4487099296

View file

@ -60,7 +60,7 @@ def cnn_model_fn(features, labels, mode):
logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
# Calculate loss as a vector (to support microbatches in DP-SGD).
vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
# Define mean of loss across minibatch (for reporting through tf.Estimator).
scalar_loss = tf.reduce_mean(vector_loss)
@ -99,7 +99,7 @@ def cnn_model_fn(features, labels, mode):
eval_metric_ops = {
'accuracy':
tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1),
labels=labels,
predictions=tf.argmax(input=logits, axis=1))
}
return tf.estimator.EstimatorSpec(mode=mode,
@ -116,15 +116,15 @@ def load_mnist():
train_data = np.array(train_data, dtype=np.float32) / 255
test_data = np.array(test_data, dtype=np.float32) / 255
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
train_labels = np.array(train_labels, dtype=np.int32)
test_labels = np.array(test_labels, dtype=np.int32)
assert train_data.min() == 0.
assert train_data.max() == 1.
assert test_data.min() == 0.
assert test_data.max() == 1.
assert train_labels.shape[1] == 10
assert test_labels.shape[1] == 10
assert len(train_labels.shape) == 1
assert len(test_labels.shape) == 1
return train_data, train_labels, test_data, test_labels