Project import generated by Copybara.
PiperOrigin-RevId: 226345615
This commit is contained in:
parent
1595ed3cd1
commit
b4188446e0
6 changed files with 351 additions and 44 deletions
|
@ -19,7 +19,7 @@ from __future__ import print_function
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
import privacy.optimizers.gaussian_average_query as ph
|
||||
from privacy.optimizers import gaussian_query
|
||||
|
||||
|
||||
def make_optimizer_class(cls):
|
||||
|
@ -40,9 +40,9 @@ def make_optimizer_class(cls):
|
|||
super(DPOptimizerClass, self).__init__(*args, **kwargs)
|
||||
stddev = l2_norm_clip * noise_multiplier
|
||||
self._num_microbatches = num_microbatches
|
||||
self._privacy_helper = ph.GaussianAverageQuery(l2_norm_clip, stddev,
|
||||
num_microbatches)
|
||||
self._ph_global_state = self._privacy_helper.initial_global_state()
|
||||
self._private_query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip, stddev, num_microbatches)
|
||||
self._global_state = self._private_query.initial_global_state()
|
||||
|
||||
def compute_gradients(self,
|
||||
loss,
|
||||
|
@ -58,7 +58,7 @@ def make_optimizer_class(cls):
|
|||
# sampling from the dataset without replacement.
|
||||
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
self._privacy_helper.derive_sample_params(self._ph_global_state))
|
||||
self._private_query.derive_sample_params(self._global_state))
|
||||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
|
@ -66,7 +66,7 @@ def make_optimizer_class(cls):
|
|||
tf.gather(microbatches_losses, [i]), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = list(grads)
|
||||
sample_state = self._privacy_helper.accumulate_record(
|
||||
sample_state = self._private_query.accumulate_record(
|
||||
sample_params, sample_state, grads_list)
|
||||
return [tf.add(i, 1), sample_state]
|
||||
|
||||
|
@ -76,8 +76,8 @@ def make_optimizer_class(cls):
|
|||
var_list = (
|
||||
tf.trainable_variables() + tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
sample_state = self._privacy_helper.initial_sample_state(
|
||||
self._ph_global_state, var_list)
|
||||
sample_state = self._private_query.initial_sample_state(
|
||||
self._global_state, var_list)
|
||||
|
||||
# Use of while_loop here requires that sample_state be a nested structure
|
||||
# of tensors. In general, we would prefer to allow it to be an arbitrary
|
||||
|
@ -85,9 +85,9 @@ def make_optimizer_class(cls):
|
|||
_, final_state = tf.while_loop(
|
||||
lambda i, _: tf.less(i, self._num_microbatches), process_microbatch,
|
||||
[i, sample_state])
|
||||
final_grads, self._ph_global_state = (
|
||||
self._privacy_helper.get_noised_average(final_state,
|
||||
self._ph_global_state))
|
||||
final_grads, self._global_state = (
|
||||
self._private_query.get_noised_average(final_state,
|
||||
self._global_state))
|
||||
|
||||
return zip(final_grads, var_list)
|
||||
|
||||
|
|
|
@ -18,32 +18,42 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.optimizers import gaussian_average_query
|
||||
from privacy.optimizers import gaussian_query
|
||||
|
||||
|
||||
class GaussianAverageQueryTest(tf.test.TestCase):
|
||||
def _run_query(query, records):
|
||||
"""Executes query on the given set of records as a single sample.
|
||||
|
||||
def _run_query(self, query, *records):
|
||||
"""Executes query on the given set of records and returns the result."""
|
||||
Args:
|
||||
query: A PrivateQuery to run.
|
||||
records: An iterable containing records to pass to the query.
|
||||
|
||||
Returns:
|
||||
The result of the query.
|
||||
"""
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
sample_state = query.initial_sample_state(global_state, records[0])
|
||||
sample_state = query.initial_sample_state(global_state, next(iter(records)))
|
||||
for record in records:
|
||||
sample_state = query.accumulate_record(params, sample_state, record)
|
||||
result, _ = query.get_query_result(sample_state, global_state)
|
||||
return result
|
||||
|
||||
|
||||
class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_gaussian_sum_no_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
record1 = tf.constant([2.0, 0.0])
|
||||
record2 = tf.constant([-1.0, 1.0])
|
||||
|
||||
query = gaussian_average_query.GaussianSumQuery(
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query_result = self._run_query(query, record1, record2)
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -53,9 +63,9 @@ class GaussianAverageQueryTest(tf.test.TestCase):
|
|||
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
|
||||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||
|
||||
query = gaussian_average_query.GaussianSumQuery(
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=0.0)
|
||||
query_result = self._run_query(query, record1, record2)
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
@ -65,9 +75,9 @@ class GaussianAverageQueryTest(tf.test.TestCase):
|
|||
record1, record2 = 2.71828, 3.14159
|
||||
stddev = 1.0
|
||||
|
||||
query = gaussian_average_query.GaussianSumQuery(
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=5.0, stddev=stddev)
|
||||
query_result = self._run_query(query, record1, record2)
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
|
||||
noised_sums = []
|
||||
for _ in xrange(1000):
|
||||
|
@ -81,9 +91,9 @@ class GaussianAverageQueryTest(tf.test.TestCase):
|
|||
record1 = tf.constant([5.0, 0.0]) # Clipped to [3.0, 0.0].
|
||||
record2 = tf.constant([-1.0, 2.0]) # Not clipped.
|
||||
|
||||
query = gaussian_average_query.GaussianAverageQuery(
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=3.0, sum_stddev=0.0, denominator=2.0)
|
||||
query_result = self._run_query(query, record1, record2)
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected_average = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected_average)
|
||||
|
@ -94,9 +104,9 @@ class GaussianAverageQueryTest(tf.test.TestCase):
|
|||
sum_stddev = 1.0
|
||||
denominator = 2.0
|
||||
|
||||
query = gaussian_average_query.GaussianAverageQuery(
|
||||
query = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query_result = self._run_query(query, record1, record2)
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
|
||||
noised_averages = []
|
||||
for _ in xrange(1000):
|
||||
|
@ -106,6 +116,15 @@ class GaussianAverageQueryTest(tf.test.TestCase):
|
|||
avg_stddev = sum_stddev / denominator
|
||||
self.assertNear(result_stddev, avg_stddev, 0.1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('type_mismatch', [1.0], (1.0,), TypeError),
|
||||
('too_few_on_left', [1.0], [1.0, 1.0], ValueError),
|
||||
('too_few_on_right', [1.0, 1.0], [1.0], ValueError))
|
||||
def test_incompatible_records(self, record1, record2, error_type):
|
||||
query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||
with self.assertRaises(error_type):
|
||||
_run_query(query, [record1, record2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
123
privacy/optimizers/nested_query.py
Normal file
123
privacy/optimizers/nested_query.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implements PrivateQuery interface for queries over nested structures.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.optimizers import private_queries
|
||||
|
||||
nest = tf.contrib.framework.nest
|
||||
|
||||
|
||||
class NestedQuery(private_queries.PrivateQuery):
|
||||
"""Implements PrivateQuery interface for structured queries.
|
||||
|
||||
NestedQuery evaluates arbitrary nested structures of queries. Records must be
|
||||
nested structures of tensors that are compatible (in type and arity) with the
|
||||
query structure, but are allowed to have deeper structure within each leaf of
|
||||
the query structure. For example, the nested query [q1, q2] is compatible with
|
||||
the record [t1, t2] or [t1, (t2, t3)], but not with (t1, t2), [t1] or
|
||||
[t1, t2, t3]. The entire substructure of each record corresponding to a leaf
|
||||
node of the query structure is routed to the corresponding query. If the same
|
||||
tensor should be consumed by multiple sub-queries, it can be replicated in the
|
||||
record, for example [t1, t1].
|
||||
|
||||
NestedQuery is intended to allow privacy mechanisms for groups as described in
|
||||
[McMahan & Andrew, 2018: "A General Approach to Adding Differential Privacy to
|
||||
Iterative Training Procedures" (https://arxiv.org/abs/1812.06210)].
|
||||
"""
|
||||
|
||||
def __init__(self, queries):
|
||||
"""Initializes the NestedQuery.
|
||||
|
||||
Args:
|
||||
queries: A nested structure of queries.
|
||||
"""
|
||||
self._queries = queries
|
||||
|
||||
def _map_to_queries(self, fn, *inputs):
|
||||
def caller(query, *args):
|
||||
return getattr(query, fn)(*args)
|
||||
return nest.map_structure_up_to(
|
||||
self._queries, caller, self._queries, *inputs)
|
||||
|
||||
def initial_global_state(self):
|
||||
"""Returns the initial global state for the NestedQuery."""
|
||||
return self._map_to_queries('initial_global_state')
|
||||
|
||||
def derive_sample_params(self, global_state):
|
||||
"""Given the global state, derives parameters to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
|
||||
Returns:
|
||||
Parameters to use to process records in the next sample.
|
||||
"""
|
||||
return self._map_to_queries('derive_sample_params', global_state)
|
||||
|
||||
def initial_sample_state(self, global_state, tensors):
|
||||
"""Returns an initial state to use for the next sample.
|
||||
|
||||
Args:
|
||||
global_state: The current global state.
|
||||
tensors: A structure of tensors used as a template to create the initial
|
||||
sample state.
|
||||
|
||||
Returns: An initial sample state.
|
||||
"""
|
||||
return self._map_to_queries('initial_sample_state', global_state, tensors)
|
||||
|
||||
def accumulate_record(self, params, sample_state, record):
|
||||
"""Accumulates a single record into the sample state.
|
||||
|
||||
Args:
|
||||
params: The parameters for the sample.
|
||||
sample_state: The current sample state.
|
||||
record: The record to accumulate.
|
||||
|
||||
Returns:
|
||||
The updated sample state.
|
||||
"""
|
||||
return self._map_to_queries(
|
||||
'accumulate_record', params, sample_state, record)
|
||||
|
||||
def get_query_result(self, sample_state, global_state):
|
||||
"""Gets query result after all records of sample have been accumulated.
|
||||
|
||||
Args:
|
||||
sample_state: The sample state after all records have been accumulated.
|
||||
global_state: The global state.
|
||||
|
||||
Returns:
|
||||
A tuple (result, new_global_state) where "result" is a structure matching
|
||||
the query structure containing the results of the subqueries and
|
||||
"new_global_state" is a structure containing the updated global states
|
||||
for the subqueries.
|
||||
"""
|
||||
estimates_and_new_global_states = self._map_to_queries(
|
||||
'get_query_result', sample_state, global_state)
|
||||
|
||||
flat_estimates, flat_new_global_states = zip(
|
||||
*nest.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||
return (
|
||||
nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
nest.pack_sequence_as(self._queries, flat_new_global_states))
|
164
privacy/optimizers/nested_query_test.py
Normal file
164
privacy/optimizers/nested_query_test.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for NestedQuery."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from privacy.optimizers import gaussian_query
|
||||
from privacy.optimizers import nested_query
|
||||
|
||||
nest = tf.contrib.framework.nest
|
||||
|
||||
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||
|
||||
|
||||
def _run_query(query, records):
|
||||
"""Executes query on the given set of records as a single sample.
|
||||
|
||||
Args:
|
||||
query: A PrivateQuery to run.
|
||||
records: An iterable containing records to pass to the query.
|
||||
|
||||
Returns:
|
||||
The result of the query.
|
||||
"""
|
||||
global_state = query.initial_global_state()
|
||||
params = query.derive_sample_params(global_state)
|
||||
sample_state = query.initial_sample_state(global_state, next(iter(records)))
|
||||
for record in records:
|
||||
sample_state = query.accumulate_record(params, sample_state, record)
|
||||
result, _ = query.get_query_result(sample_state, global_state)
|
||||
return result
|
||||
|
||||
|
||||
class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_nested_gaussian_sum_no_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
query2 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [5.0, [5.0, 5.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_nested_gaussian_average_no_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=5.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [2.0, 3.0]]
|
||||
record2 = [4.0, [3.0, 2.0]]
|
||||
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, [1.0, 1.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_nested_gaussian_average_with_clip_no_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
query1 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=4.0, sum_stddev=0.0, denominator=5.0)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=5.0, sum_stddev=0.0, denominator=5.0)
|
||||
|
||||
query = nested_query.NestedQuery([query1, query2])
|
||||
|
||||
record1 = [1.0, [12.0, 9.0]] # Clipped to [1.0, [4.0, 3.0]]
|
||||
record2 = [5.0, [1.0, 2.0]] # Clipped to [4.0, [1.0, 2.0]]
|
||||
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, [1.0, 1.0]]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_complex_nested_query(self):
|
||||
with self.cached_session() as sess:
|
||||
query_ab = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=1.0, stddev=0.0)
|
||||
query_c = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=10.0, sum_stddev=0.0, denominator=2.0)
|
||||
query_d = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=10.0, stddev=0.0)
|
||||
|
||||
query = nested_query.NestedQuery(
|
||||
[query_ab, {'c': query_c, 'd': [query_d]}])
|
||||
|
||||
record1 = [{'a': 0.0, 'b': 2.71828}, {'c': (-4.0, 6.0), 'd': [-4.0]}]
|
||||
record2 = [{'a': 3.14159, 'b': 0.0}, {'c': (6.0, -4.0), 'd': [5.0]}]
|
||||
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
result = sess.run(query_result)
|
||||
expected = [{'a': 1.0, 'b': 1.0}, {'c': (1.0, 1.0), 'd': [1.0]}]
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
def test_nested_query_with_noise(self):
|
||||
with self.cached_session() as sess:
|
||||
sum_stddev = 2.71828
|
||||
denominator = 3.14159
|
||||
|
||||
query1 = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=1.5, stddev=sum_stddev)
|
||||
query2 = gaussian_query.GaussianAverageQuery(
|
||||
l2_norm_clip=0.5, sum_stddev=sum_stddev, denominator=denominator)
|
||||
query = nested_query.NestedQuery((query1, query2))
|
||||
|
||||
record1 = (3.0, [2.0, 1.5])
|
||||
record2 = (0.0, [-1.0, -3.5])
|
||||
|
||||
query_result = _run_query(query, [record1, record2])
|
||||
|
||||
noised_averages = []
|
||||
for _ in xrange(1000):
|
||||
noised_averages.append(nest.flatten(sess.run(query_result)))
|
||||
|
||||
result_stddev = np.std(noised_averages, 0)
|
||||
avg_stddev = sum_stddev / denominator
|
||||
expected_stddev = [sum_stddev, avg_stddev, avg_stddev]
|
||||
self.assertArrayNear(result_stddev, expected_stddev, 0.1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('type_mismatch', [_basic_query], (1.0,), TypeError),
|
||||
('too_many_queries', [_basic_query, _basic_query], [1.0], ValueError),
|
||||
('too_many_records', [_basic_query, _basic_query],
|
||||
[1.0, 2.0, 3.0], ValueError),
|
||||
('query_too_deep', [_basic_query, [_basic_query]], [1.0, 1.0], TypeError))
|
||||
def test_record_incompatible_with_query(
|
||||
self, queries, record, error_type):
|
||||
with self.assertRaises(error_type):
|
||||
_run_query(nested_query.NestedQuery(queries), [record])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training a CNN on MNIST with differentially private Adam optimizer."""
|
||||
"""Training a CNN on MNIST with differentially private SGD optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -25,14 +25,14 @@ from privacy.analysis.rdp_accountant import compute_rdp
|
|||
from privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from privacy.optimizers import dp_optimizer
|
||||
|
||||
tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-Adam. If False,'
|
||||
'train with vanilla Adam.')
|
||||
tf.flags.DEFINE_float('learning_rate', 0.0015, 'Learning rate for training')
|
||||
tf.flags.DEFINE_float('noise_multiplier', 1.0,
|
||||
tf.flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False,'
|
||||
'train with vanilla SGD.')
|
||||
tf.flags.DEFINE_float('learning_rate', 0.08, 'Learning rate for training')
|
||||
tf.flags.DEFINE_float('noise_multiplier', 1.12,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
tf.flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||
tf.flags.DEFINE_integer('epochs', 15, 'Number of epochs')
|
||||
tf.flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||
tf.flags.DEFINE_integer('microbatches', 256,
|
||||
'Number of microbatches (must evenly divide batch_size')
|
||||
tf.flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||
|
@ -69,18 +69,19 @@ def cnn_model_fn(features, labels, mode):
|
|||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
|
||||
if FLAGS.dpsgd:
|
||||
# Use DP version of AdamOptimizer. For illustration purposes, we do that
|
||||
# here by calling make_optimizer_class() explicitly, though DP versions
|
||||
# of standard optimizers are available in dp_optimizer.
|
||||
# Use DP version of GradientDescentOptimizer. For illustration purposes,
|
||||
# we do that here by calling make_optimizer_class() explicitly, though DP
|
||||
# versions of standard optimizers are available in dp_optimizer.
|
||||
dp_optimizer_class = dp_optimizer.make_optimizer_class(
|
||||
tf.train.AdamOptimizer)
|
||||
tf.train.GradientDescentOptimizer)
|
||||
optimizer = dp_optimizer_class(
|
||||
learning_rate=FLAGS.learning_rate,
|
||||
noise_multiplier=FLAGS.noise_multiplier,
|
||||
l2_norm_clip=FLAGS.l2_norm_clip,
|
||||
num_microbatches=FLAGS.microbatches)
|
||||
else:
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
optimizer = tf.train.GradientDescentOptimizer(
|
||||
learning_rate=FLAGS.learning_rate)
|
||||
global_step = tf.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
|
||||
return tf.estimator.EstimatorSpec(mode=mode,
|
||||
|
@ -177,7 +178,7 @@ def main(unused_argv):
|
|||
eps = compute_epsilon(epoch * steps_per_epoch)
|
||||
print('For delta=1e-5, the current epsilon is: %.2f' % eps)
|
||||
else:
|
||||
print('Trained with vanilla non-private Adam optimizer')
|
||||
print('Trained with vanilla non-private SGD optimizer')
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
||||
|
|
Loading…
Reference in a new issue