forked from 626_privacy/tensorflow_privacy
parent
5ee12803f3
commit
4487099296
1 changed files with 6 additions and 6 deletions
|
@ -60,7 +60,7 @@ def cnn_model_fn(features, labels, mode):
|
||||||
logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
|
logits = tf.keras.layers.Dense(10, kernel_initializer='he_normal').apply(y)
|
||||||
|
|
||||||
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
# Calculate loss as a vector (to support microbatches in DP-SGD).
|
||||||
vector_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
|
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||||
labels=labels, logits=logits)
|
labels=labels, logits=logits)
|
||||||
# Define mean of loss across minibatch (for reporting through tf.Estimator).
|
# Define mean of loss across minibatch (for reporting through tf.Estimator).
|
||||||
scalar_loss = tf.reduce_mean(vector_loss)
|
scalar_loss = tf.reduce_mean(vector_loss)
|
||||||
|
@ -99,7 +99,7 @@ def cnn_model_fn(features, labels, mode):
|
||||||
eval_metric_ops = {
|
eval_metric_ops = {
|
||||||
'accuracy':
|
'accuracy':
|
||||||
tf.metrics.accuracy(
|
tf.metrics.accuracy(
|
||||||
labels=tf.argmax(labels, axis=1),
|
labels=labels,
|
||||||
predictions=tf.argmax(input=logits, axis=1))
|
predictions=tf.argmax(input=logits, axis=1))
|
||||||
}
|
}
|
||||||
return tf.estimator.EstimatorSpec(mode=mode,
|
return tf.estimator.EstimatorSpec(mode=mode,
|
||||||
|
@ -116,15 +116,15 @@ def load_mnist():
|
||||||
train_data = np.array(train_data, dtype=np.float32) / 255
|
train_data = np.array(train_data, dtype=np.float32) / 255
|
||||||
test_data = np.array(test_data, dtype=np.float32) / 255
|
test_data = np.array(test_data, dtype=np.float32) / 255
|
||||||
|
|
||||||
train_labels = tf.keras.utils.to_categorical(train_labels)
|
train_labels = np.array(train_labels, dtype=np.int32)
|
||||||
test_labels = tf.keras.utils.to_categorical(test_labels)
|
test_labels = np.array(test_labels, dtype=np.int32)
|
||||||
|
|
||||||
assert train_data.min() == 0.
|
assert train_data.min() == 0.
|
||||||
assert train_data.max() == 1.
|
assert train_data.max() == 1.
|
||||||
assert test_data.min() == 0.
|
assert test_data.min() == 0.
|
||||||
assert test_data.max() == 1.
|
assert test_data.max() == 1.
|
||||||
assert train_labels.shape[1] == 10
|
assert len(train_labels.shape) == 1
|
||||||
assert test_labels.shape[1] == 10
|
assert len(test_labels.shape) == 1
|
||||||
|
|
||||||
return train_data, train_labels, test_data, test_labels
|
return train_data, train_labels, test_data, test_labels
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue